From 0afd4e1375d168cdca797e45eb3a83b8e1314c4b Mon Sep 17 00:00:00 2001
From: byi8220
Date: Mon, 25 Nov 2024 05:47:19 -0500
Subject: [PATCH] Replace `using_pjrt()` with xla runtime `device_type()`
 check in xla.py for `torch-xla>=2.5` (#20442)

* Replace `using_pjrt()` with xla runtime `device_type()` check in xla.py

Fixes https://github.com/Lightning-AI/pytorch-lightning/issues/20419

`torch_xla.runtime.using_pjrt()` was removed in https://github.com/pytorch/xla/pull/7787.

This PR replaces references to that function with a check on
[`device_type()`](https://github.com/pytorch/xla/blob/master/torch_xla/runtime.py#L83)
to recreate the behavior of that function, minus the manual initialization.

* Added tests / refactored for version compatibility

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* precommit

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 src/lightning/fabric/accelerators/xla.py    | 7 +++++++
 tests/tests_fabric/accelerators/test_xla.py | 5 +++++
 2 files changed, 12 insertions(+)

diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py
index 4a1f25a91062b..d438197329939 100644
--- a/src/lightning/fabric/accelerators/xla.py
+++ b/src/lightning/fabric/accelerators/xla.py
@@ -102,14 +102,21 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
 # PJRT support requires this minimum version
 _XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
 _XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
+_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
 
 
 def _using_pjrt() -> bool:
+    # `using_pjrt` is removed in torch_xla 2.5
+    if _XLA_GREATER_EQUAL_2_5:
+        from torch_xla import runtime as xr
+
+        return xr.device_type() is not None
     # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
     if _XLA_GREATER_EQUAL_2_1:
         from torch_xla import runtime as xr
 
         return xr.using_pjrt()
+
     from torch_xla.experimental import pjrt
 
     return pjrt.using_pjrt()
diff --git a/tests/tests_fabric/accelerators/test_xla.py b/tests/tests_fabric/accelerators/test_xla.py
index 1af7d7e1e7206..7a906c8ae0c54 100644
--- a/tests/tests_fabric/accelerators/test_xla.py
+++ b/tests/tests_fabric/accelerators/test_xla.py
@@ -44,3 +44,8 @@ def test_get_parallel_devices_raises(tpu_available):
         XLAAccelerator.get_parallel_devices(5)
     with pytest.raises(ValueError, match="Could not parse.*anything-else'"):
         XLAAccelerator.get_parallel_devices("anything-else")
+
+
+@pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present")
+def test_instantiate_xla_accelerator():
+    _ = XLAAccelerator()
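
Note: as a minimal standalone sketch (not part of the patch above), the version-gated fallback behaves as follows. The helper name `_pjrt_available` is hypothetical; the real logic lives in `_using_pjrt()` in src/lightning/fabric/accelerators/xla.py, and the sketch assumes `torch_xla` and `lightning_utilities` are installed.

from lightning_utilities.core.imports import RequirementCache

_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")


def _pjrt_available() -> bool:  # hypothetical name for illustration only
    if _XLA_GREATER_EQUAL_2_5:
        # torch_xla 2.5 dropped `using_pjrt()`; PJRT is the only runtime there, so a
        # resolvable device type (e.g. "TPU", "CPU", "CUDA") stands in for the old check.
        from torch_xla import runtime as xr

        return xr.device_type() is not None

    # Older torch_xla (>=2.1, <2.5) still exposes the original API.
    from torch_xla import runtime as xr

    return xr.using_pjrt()

The added test can be run with `pytest tests/tests_fabric/accelerators/test_xla.py -k test_instantiate_xla_accelerator` in an environment where torch_xla is installed.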