Replace `using_pjrt()` with xla runtime `device_type()` check in xla.py for `torch-xla>=2.5` (#20442)

* Replace `using_pjrt()` with xla runtime `device_type()` check in xla.py

Fixes #20419

`torch_xla.runtime.using_pjrt()` was removed in pytorch/xla#7787

This PR replaces references to that function with a check against [`device_type()`](https://github.com/pytorch/xla/blob/master/torch_xla/runtime.py#L83) to recreate that function's behavior, minus the manual initialization.
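
For readers unfamiliar with the torch_xla runtime API, a minimal sketch of the equivalence this change relies on (illustrative only, assuming `torch_xla>=2.5` is installed):

    from torch_xla import runtime as xr

    # `device_type()` returns the configured PJRT device string (e.g. "CPU", "TPU", "CUDA"),
    # or None when no PJRT device is set. The removed `using_pjrt()` amounted to this check,
    # plus selecting a default device first (the "manual initialization" skipped here).
    pjrt_in_use = xr.device_type() is not None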

* Added tests and refactored for version compatibility

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* precommit

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
byi8220 and pre-commit-ci[bot] authored Nov 25, 2024
1 parent 1e88899 commit 0afd4e1
Showing 2 changed files with 12 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/lightning/fabric/accelerators/xla.py
@@ -102,14 +102,21 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
# PJRT support requires this minimum version
_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")


def _using_pjrt() -> bool:
    # `using_pjrt` is removed in torch_xla 2.5
    if _XLA_GREATER_EQUAL_2_5:
        from torch_xla import runtime as xr

        return xr.device_type() is not None
    # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
    if _XLA_GREATER_EQUAL_2_1:
        from torch_xla import runtime as xr

        return xr.using_pjrt()

    from torch_xla.experimental import pjrt

    return pjrt.using_pjrt()
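
For context, the version dispatch above leans on `RequirementCache` from `lightning_utilities`, which evaluates truthy only when the pinned requirement is installed. A minimal, illustrative sketch of that pattern (not part of this diff):

    from lightning_utilities.core.imports import RequirementCache

    # Truthy iff the installed torch_xla satisfies the specifier; the result is cached.
    _XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")

    if _XLA_GREATER_EQUAL_2_5:
        print("torch_xla>=2.5 found; `using_pjrt()` is gone, use `device_type()` instead")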
5 changes: 5 additions & 0 deletions tests/tests_fabric/accelerators/test_xla.py
@@ -44,3 +44,8 @@ def test_get_parallel_devices_raises(tpu_available):
        XLAAccelerator.get_parallel_devices(5)
    with pytest.raises(ValueError, match="Could not parse.*anything-else'"):
        XLAAccelerator.get_parallel_devices("anything-else")


@pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present")
def test_instantiate_xla_accelerator():
    _ = XLAAccelerator()
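
A follow-on unit test for the new branch could look like the sketch below (not part of this PR; assumes `torch_xla>=2.5` and monkeypatches `device_type` directly):

    import pytest

    from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_5, _using_pjrt


    @pytest.mark.skipif(not _XLA_GREATER_EQUAL_2_5, reason="tests the torch_xla>=2.5 branch")
    def test_using_pjrt_follows_device_type(monkeypatch):
        import torch_xla.runtime as xr

        # `_using_pjrt()` re-imports the same module object, so patching the
        # attribute here is visible inside the helper.
        monkeypatch.setattr(xr, "device_type", lambda: "CPU")
        assert _using_pjrt()

        monkeypatch.setattr(xr, "device_type", lambda: None)
        assert not _using_pjrt()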
