Skip to content

Commit

Permalink
Remove List[int] as input type for Trainer when accelerator="cpu" (
Browse files Browse the repository at this point in the history
…#20399)

Co-authored-by: Alan Chu <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
3 people authored Nov 13, 2024
1 parent e1b172c commit 20d19d2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions src/lightning/fabric/accelerators/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ def teardown(self) -> None:

@staticmethod
@override
def parse_devices(devices: Union[int, str, List[int]]) -> int:
def parse_devices(devices: Union[int, str]) -> int:
"""Accelerator device parsing logic."""
return _parse_cpu_cores(devices)

@staticmethod
@override
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
devices = _parse_cpu_cores(devices)
return [torch.device("cpu")] * devices
Expand All @@ -72,12 +72,12 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
)


def _parse_cpu_cores(cpu_cores: Union[int, str, List[int]]) -> int:
def _parse_cpu_cores(cpu_cores: Union[int, str]) -> int:
"""Parses the cpu_cores given in the format as accepted by the ``devices`` argument in the
:class:`~lightning.pytorch.trainer.trainer.Trainer`.
Args:
cpu_cores: An int > 0.
cpu_cores: An int > 0 or a string that can be converted to an int > 0.
Returns:
An int representing the number of processes
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/pytorch/accelerators/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ def teardown(self) -> None:

@staticmethod
@override
def parse_devices(devices: Union[int, str, List[int]]) -> int:
def parse_devices(devices: Union[int, str]) -> int:
"""Accelerator device parsing logic."""
return _parse_cpu_cores(devices)

@staticmethod
@override
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
devices = _parse_cpu_cores(devices)
return [torch.device("cpu")] * devices
Expand Down

0 comments on commit 20d19d2

Please sign in to comment.