diff --git a/torchrec/sparse/tests/test_tensor_dict.py b/torchrec/sparse/tests/test_tensor_dict.py index d243fc255..2fbcc0a66 100644 --- a/torchrec/sparse/tests/test_tensor_dict.py +++ b/torchrec/sparse/tests/test_tensor_dict.py @@ -17,14 +17,14 @@ from torchrec.sparse.tensor_dict import maybe_td_to_kjt -class TestTensorDIct(unittest.TestCase): - @given(device_str=st.sampled_from(["cpu", "cuda", "meta"])) - @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) +class TestTensorDict(unittest.TestCase): # pyre-ignore[56] - @unittest.skipIf( - torch.cuda.device_count() <= 0, - "CUDA is not available", + @given( + device_str=st.sampled_from( + ["cpu", "meta"] + (["cuda"] if torch.cuda.device_count() > 0 else []) + ) ) + @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) def test_kjt_input(self, device_str: str) -> None: device = torch.device(device_str) values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device) @@ -36,13 +36,13 @@ def test_kjt_input(self, device_str: str) -> None: features = maybe_td_to_kjt(kjt) self.assertEqual(features, kjt) - @given(device_str=st.sampled_from(["cpu", "cuda", "meta"])) - @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) # pyre-ignore[56] - @unittest.skipIf( - torch.cuda.device_count() <= 0, - "CUDA is not available", + @given( + device_str=st.sampled_from( + ["cpu", "meta"] + (["cuda"] if torch.cuda.device_count() > 0 else []) + ) ) + @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) def test_td_kjt(self, device_str: str) -> None: device = torch.device(device_str) values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)