update pin (#8677)
lsy323 authored Feb 13, 2025
1 parent 42edbe1 · commit 6f423d0
Showing 7 changed files with 28 additions and 10 deletions.
WORKSPACE (3 changes: 2 additions & 1 deletion)
@@ -46,7 +46,7 @@ new_local_repository(

# To build PyTorch/XLA with OpenXLA to a new revision, update following xla_hash to
# the openxla git commit hash.
-xla_hash = '6e91ff19dad528ab7d2025a9bb46150618a3bc7d'
+xla_hash = '52d5ccaf00fdbc32956c457eae415c09f56f0208'

http_archive(
name = "xla",
@@ -57,6 +57,7 @@ http_archive(
patch_tool = "patch",
patches = [
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:count_down.diff",
],
strip_prefix = "xla-" + xla_hash,
urls = [
openxla_patches/count_down.diff (14 changes: 14 additions & 0 deletions)
@@ -0,0 +1,14 @@
diff --git a/xla/backends/cpu/runtime/convolution_thunk_internal.h b/xla/backends/cpu/runtime/convolution_thunk_internal.h
index 84fed6bb78..9835f12e4e 100644
--- a/xla/backends/cpu/runtime/convolution_thunk_internal.h
+++ b/xla/backends/cpu/runtime/convolution_thunk_internal.h
@@ -342,7 +342,8 @@ void EigenGenericConv2D(
Eigen::Index start = task_index * task_size;
Eigen::Index end = std::min(start + task_size, feature_group_count);
for (Eigen::Index i = start; i < end; ++i) {
- auto on_done = [count_down]() mutable { count_down.CountDown(); };
+ // auto on_done = [count_down]() mutable { count_down.CountDown(); };
+ auto on_done = [count_down]() mutable { const_cast<decltype(count_down)*>(&count_down)->CountDown(); };
auto [output, convolved] = convolve_group(i);
output.device(device, std::move(on_done)) = convolved;
}
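
The patch swaps the direct CountDown() call for one routed through a const_cast on the lambda's by-value capture. The sketch below is a minimal, self-contained illustration of that C++ idiom, not the XLA code: CountDownCounter is a hypothetical stand-in, and unlike the patch (whose lambda stays mutable) it shows the simplest case, where by-value captures are const inside a non-mutable lambda and the cast is what makes the call compile.

// count_down_idiom.cc: illustrative stand-in, not XLA code.
#include <iostream>

struct CountDownCounter {
  int remaining;
  void CountDown() { --remaining; }  // non-const member function
};

int main() {
  CountDownCounter count_down{3};
  // Without `mutable`, by-value captures are const inside the lambda body,
  // so count_down.CountDown() would not compile. Casting away const on the
  // closure's own copy is well-defined here because the closure object
  // itself is not const.
  auto on_done = [count_down]() {
    const_cast<CountDownCounter*>(&count_down)->CountDown();
  };
  on_done();
  std::cout << count_down.remaining << "\n";  // prints 3: only the copy changed
  return 0;
}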
setup.py (7 changes: 2 additions & 5 deletions)
@@ -66,10 +66,10 @@

USE_NIGHTLY = True # whether to use nightly or stable libtpu and jax

-_date = '20250131'
+_date = '20250210'

# Note: jax/jaxlib 20250115 build will fail. Check https://github.com/pytorch/xla/pull/8621#issuecomment-2616564634 for more details.
-_libtpu_version = '0.0.9'
+_libtpu_version = '0.0.10'
_jax_version = '0.5.1'
_jaxlib_version = '0.5.1'

@@ -332,9 +332,6 @@ def run(self):
'tpu': [
f'libtpu=={_libtpu_version}',
'tpu-info',
-# This special version removes `libtpu.so` from any `libtpu-nightly` installations,
-# since we have migrated to using the `libtpu.so` from the `libtpu` package.
-"libtpu-nightly==0.1.dev20241010+nightly.cleanup"
],
# pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
'pallas': [f'jaxlib=={_jaxlib_version}', f'jax=={_jax_version}'],
test/test_core_aten_ops.py (5 changes: 4 additions & 1 deletion)
@@ -993,7 +993,10 @@ def test_aten_convolution_1(self):
1,
)
kwargs = dict()
-run_export_and_compare(self, torch.ops.aten.convolution, args, kwargs)
+# With the xla pin at 52d5ccaf00fdbc32956c457eae415c09f56f0208,
+# rtol needs to be raised to 1e-3 on CPU.
+run_export_and_compare(
+    self, torch.ops.aten.convolution, args, kwargs, rtol=1e-3)

def test_aten_convolution_2(self):
args = (
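To see what the looser tolerance admits, the hedged C++ sketch below applies the standard allclose-style rule, |a - b| <= atol + rtol * |b|; run_export_and_compare's exact comparison lives in the PyTorch/XLA test harness, so treat this as the rule in spirit, with made-up numbers.

// rtol_sketch.cc: hypothetical numbers, standard allclose-style rule.
#include <cmath>
#include <iostream>

bool allclose(double a, double b, double rtol, double atol = 1e-8) {
  return std::fabs(a - b) <= atol + rtol * std::fabs(b);
}

int main() {
  // A result that drifted by 0.05% after the pin update fails a tight
  // tolerance but passes at rtol=1e-3.
  std::cout << allclose(1.0005, 1.0, 1e-5) << "\n";  // prints 0
  std::cout << allclose(1.0005, 1.0, 1e-3) << "\n";  // prints 1
  return 0;
}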
torch_xla/csrc/dl_convertor.cpp (2 changes: 1 addition & 1 deletion)
@@ -329,7 +329,7 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
device->client()->CreateViewOfDeviceBuffer(
static_cast<char*>(dlmt->dl_tensor.data) +
dlmt->dl_tensor.byte_offset,
-shape, device, on_delete_callback);
+shape, *device->default_memory_space(), on_delete_callback);
XLA_CHECK_OK(pjrt_buffer.status()) << "Failed to create a pjrt buffer.";
XLA_CHECK(pjrt_buffer.value() != nullptr) << "pjrt buffer is null.";

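This change, like the two in pjrt_computation_client.cc below, moves a device-addressed PJRT call to a memory-space-addressed one. The leading `*` in `*device->default_memory_space()` unwraps an absl::StatusOr, since PJRT's default_memory_space() returns absl::StatusOr<PjRtMemorySpace*>. A hedged sketch with stand-in types (MemorySpace is hypothetical; only the StatusOr behavior is the point):

// statusor_unwrap.cc: stand-in types; requires Abseil.
#include <iostream>
#include "absl/status/statusor.h"

struct MemorySpace {};  // hypothetical stand-in for xla::PjRtMemorySpace

absl::StatusOr<MemorySpace*> default_memory_space() {
  static MemorySpace space;
  return &space;  // the real call may instead return an error status
}

int main() {
  // Unary * yields the contained value; it is undefined if the StatusOr
  // holds an error, so production code checks .ok() (or a CHECK macro) first.
  MemorySpace* space = *default_memory_space();
  std::cout << (space != nullptr) << "\n";  // prints 1
  return 0;
}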
torch_xla/csrc/runtime/BUILD (1 change: 1 addition & 0 deletions)
@@ -165,6 +165,7 @@ cc_library(
":tf_logging",
"@tsl//tsl/platform:stacktrace",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:macros",
],
)

torch_xla/csrc/runtime/pjrt_computation_client.cc (6 changes: 4 additions & 2 deletions)
@@ -274,7 +274,9 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToDevice(
tensor->dimensions(), tensor->byte_strides(),
xla::PjRtClient::HostBufferSemantics::
kImmutableUntilTransferCompletes,
-[tensor]() { /* frees tensor */ }, pjrt_device)
+[tensor]() { /* frees tensor */ },
+*pjrt_device->default_memory_space(),
+/*device_layout=*/nullptr)
.value());

ComputationClient::DataPtr data =
@@ -321,7 +323,7 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice(

// Returns error if the buffer is already on `dst_device`.
absl::StatusOr<std::unique_ptr<xla::PjRtBuffer>> status_or =
-pjrt_data->buffer->CopyToDevice(dst_device);
+pjrt_data->buffer->CopyToMemorySpace(*dst_device->default_memory_space());
if (!status_or.ok()) {
return data;
}
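The CopyToDevice-to-CopyToMemorySpace switch keeps the surrounding fallback intact: per the comment above it, the copy errors when the buffer already lives on dst_device, and the function then returns the original data. A hedged stand-in sketch of that pattern (Buffer and CopyToMemorySpace here are illustrative, not the PJRT types):

// copy_fallback.cc: stand-in types; requires Abseil.
#include <iostream>
#include <memory>
#include <utility>
#include "absl/status/status.h"
#include "absl/status/statusor.h"

struct Buffer {};  // hypothetical stand-in for xla::PjRtBuffer

absl::StatusOr<std::unique_ptr<Buffer>> CopyToMemorySpace(bool already_there) {
  if (already_there) return absl::FailedPreconditionError("already on device");
  return std::make_unique<Buffer>();
}

std::shared_ptr<Buffer> CopyOrKeep(std::shared_ptr<Buffer> src, bool already) {
  auto status_or = CopyToMemorySpace(already);
  if (!status_or.ok()) return src;  // fall back to the existing buffer
  return std::shared_ptr<Buffer>(std::move(status_or).value());
}

int main() {
  auto src = std::make_shared<Buffer>();
  std::cout << (CopyOrKeep(src, true) == src) << "\n";   // prints 1
  std::cout << (CopyOrKeep(src, false) == src) << "\n";  // prints 0
  return 0;
}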
