From 6f423d0bb284190cf1b12d8a943a334e57b4df28 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Thu, 13 Feb 2025 09:43:56 -0800
Subject: [PATCH] update pin (#8677)

---
 WORKSPACE                                         |  3 ++-
 openxla_patches/count_down.diff                   | 14 ++++++++++++++
 setup.py                                          |  7 ++-----
 test/test_core_aten_ops.py                        |  5 ++++-
 torch_xla/csrc/dl_convertor.cpp                   |  2 +-
 torch_xla/csrc/runtime/BUILD                      |  1 +
 torch_xla/csrc/runtime/pjrt_computation_client.cc |  6 ++++--
 7 files changed, 28 insertions(+), 10 deletions(-)
 create mode 100644 openxla_patches/count_down.diff

diff --git a/WORKSPACE b/WORKSPACE
index dfd9f4b3221..43f3cc1ce9f 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -46,7 +46,7 @@ new_local_repository(
 
 # To build PyTorch/XLA with OpenXLA to a new revision, update following xla_hash to
 # the openxla git commit hash.
-xla_hash = '6e91ff19dad528ab7d2025a9bb46150618a3bc7d'
+xla_hash = '52d5ccaf00fdbc32956c457eae415c09f56f0208'
 
 http_archive(
     name = "xla",
@@ -57,6 +57,7 @@ http_archive(
     patch_tool = "patch",
     patches = [
         "//openxla_patches:gpu_race_condition.diff",
+        "//openxla_patches:count_down.diff",
     ],
     strip_prefix = "xla-" + xla_hash,
     urls = [
diff --git a/openxla_patches/count_down.diff b/openxla_patches/count_down.diff
new file mode 100644
index 00000000000..b46d3907752
--- /dev/null
+++ b/openxla_patches/count_down.diff
@@ -0,0 +1,14 @@
+diff --git a/xla/backends/cpu/runtime/convolution_thunk_internal.h b/xla/backends/cpu/runtime/convolution_thunk_internal.h
+index 84fed6bb78..9835f12e4e 100644
+--- a/xla/backends/cpu/runtime/convolution_thunk_internal.h
++++ b/xla/backends/cpu/runtime/convolution_thunk_internal.h
+@@ -342,7 +342,8 @@ void EigenGenericConv2D(
+       Eigen::Index start = task_index * task_size;
+       Eigen::Index end = std::min(start + task_size, feature_group_count);
+       for (Eigen::Index i = start; i < end; ++i) {
+-        auto on_done = [count_down]() mutable { count_down.CountDown(); };
++        // auto on_done = [count_down]() mutable { count_down.CountDown(); };
++        auto on_done = [count_down]() mutable { const_cast<std::decay_t<decltype(count_down)>*>(&count_down)->CountDown(); };
+         auto [output, convolved] = convolve_group(i);
+         output.device(device, std::move(on_done)) = convolved;
+       }
diff --git a/setup.py b/setup.py
index 66e7f151183..ed20fb67748 100644
--- a/setup.py
+++ b/setup.py
@@ -66,10 +66,10 @@
 
 USE_NIGHTLY = True  # whether to use nightly or stable libtpu and jax
 
-_date = '20250131'
+_date = '20250210'
 
 # Note: jax/jaxlib 20250115 build will fail. Check https://github.com/pytorch/xla/pull/8621#issuecomment-2616564634 for more details.
-_libtpu_version = '0.0.9'
+_libtpu_version = '0.0.10'
 _jax_version = '0.5.1'
 _jaxlib_version = '0.5.1'
 
@@ -332,9 +332,6 @@ def run(self):
         'tpu': [
             f'libtpu=={_libtpu_version}',
             'tpu-info',
-            # This special version removes `libtpu.so` from any `libtpu-nightly` installations,
-            # since we have migrated to using the `libtpu.so` from the `libtpu` package.
-            "libtpu-nightly==0.1.dev20241010+nightly.cleanup"
         ],
         # pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
         'pallas': [f'jaxlib=={_jaxlib_version}', f'jax=={_jax_version}'],
diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py
index 591d2d18a4c..6e2ac67e4f5 100644
--- a/test/test_core_aten_ops.py
+++ b/test/test_core_aten_ops.py
@@ -993,7 +993,10 @@ def test_aten_convolution_1(self):
         1,
     )
     kwargs = dict()
-    run_export_and_compare(self, torch.ops.aten.convolution, args, kwargs)
+    # With xla pin to 52d5ccaf00fdbc32956c457eae415c09f56f0208,
+    # the rtol needs to be raised to 1e-3 on CPU.
+    run_export_and_compare(
+        self, torch.ops.aten.convolution, args, kwargs, rtol=1e-3)
 
   def test_aten_convolution_2(self):
     args = (
diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp
index 2f174a4af22..efb8121784f 100644
--- a/torch_xla/csrc/dl_convertor.cpp
+++ b/torch_xla/csrc/dl_convertor.cpp
@@ -329,7 +329,7 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
       device->client()->CreateViewOfDeviceBuffer(
           static_cast<char*>(dlmt->dl_tensor.data) +
               dlmt->dl_tensor.byte_offset,
-          shape, device, on_delete_callback);
+          shape, *device->default_memory_space(), on_delete_callback);
   XLA_CHECK_OK(pjrt_buffer.status()) << "Failed to create a pjrt buffer.";
   XLA_CHECK(pjrt_buffer.value() != nullptr) << "pjrt buffer is null.";
 
diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index 0e5b714cd1e..9f89156d864 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -165,6 +165,7 @@ cc_library(
         ":tf_logging",
         "@tsl//tsl/platform:stacktrace",
         "@tsl//tsl/platform:statusor",
+        "@tsl//tsl/platform:macros",
     ],
 )
 
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 8caad6d230f..749419f66cd 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -274,7 +274,9 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToDevice(
             tensor->dimensions(), tensor->byte_strides(),
             xla::PjRtClient::HostBufferSemantics::
                 kImmutableUntilTransferCompletes,
-            [tensor]() { /* frees tensor */ }, pjrt_device)
+            [tensor]() { /* frees tensor */ },
+            *pjrt_device->default_memory_space(),
+            /*device_layout=*/nullptr)
             .value());
 
     ComputationClient::DataPtr data =
@@ -321,7 +323,7 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice(
 
   // Returns error if the buffer is already on `dst_device`.
   absl::StatusOr<std::unique_ptr<xla::PjRtBuffer>> status_or =
-      pjrt_data->buffer->CopyToDevice(dst_device);
+      pjrt_data->buffer->CopyToMemorySpace(*dst_device->default_memory_space());
   if (!status_or.ok()) {
     return data;
   }