Skip to content

Commit

Permalink
Add device property to lazy load functionality (#20183)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Aug 9, 2024
1 parent 828fd99 commit 1551a16
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/lightning/fabric/utilities/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ def __torch_function__(
loaded_args = [(arg._load_tensor() if isinstance(arg, _NotYetLoadedTensor) else arg) for arg in args]
return func(*loaded_args, **kwargs)

@property
def device(self) -> torch.device:
return torch.device(self.storageinfo[3])

def __getattr__(self, name: str) -> Any:
# These properties don't require materialization and can be accessed through the meta tensor directly
if name in {
Expand All @@ -160,7 +164,7 @@ def __getattr__(self, name: str) -> Any:
return getattr(self.metatensor, name)

# materializing these is needed for quantization (see lit-gpt)
if name in {"contiguous", "cuda", "half", "data"}:
if name in {"contiguous", "cuda", "half", "data", "to"}:
return getattr(self._load_tensor(), name)

raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
Expand Down
2 changes: 2 additions & 0 deletions tests/tests_fabric/utilities/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def test_lazy_load_module(tmp_path):
model1.load_state_dict(checkpoint)

assert isinstance(checkpoint["weight"], _NotYetLoadedTensor)
assert checkpoint["weight"].device == torch.device("cpu")
assert type(checkpoint["weight"].to("cpu")) is torch.Tensor
assert type(model0.weight.data) is torch.Tensor
assert torch.equal(model0.weight, model1.weight)
assert torch.equal(model0.bias, model1.bias)
Expand Down

0 comments on commit 1551a16

Please sign in to comment.