From f52fd32e127c59a29994bf21ea506131020d5b88 Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Tue, 21 Jan 2025 14:39:30 -0800
Subject: [PATCH] change dtype of block_bucketize_row_pos and fix flaky
 test_kjt_bucketize_before_all2all_cpu (#2689)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2689

# context
* found a test failure in the OSS [test run](https://github.com/pytorch/torchrec/actions/runs/12816026713/job/35736016089): P1714445461
* the issue is that a recent change (D65912888) incorrectly calls the `_fx_wrap_tensor_to_device_dtype` function:
```
block_bucketize_pos=(
    _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
    if block_bucketize_row_pos is not None
    else None
),
```
where `block_bucketize_row_pos: List[torch.Tensor]` is a list, but the function only accepts a single `torch.Tensor`:
```
@torch.fx.wrap
def _fx_wrap_tensor_to_device_dtype(
    t: torch.Tensor, tensor_device_dtype: torch.Tensor
) -> torch.Tensor:
    return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)
```
* the fix seemed straightforward: apply a list comprehension over the function (a standalone repro sketch follows this snippet)
```
block_bucketize_pos=(
    [
        _fx_wrap_tensor_to_device_dtype(pos, kjt.lengths())  # <---- pay attention here: kjt.lengths()
        for pos in block_bucketize_row_pos
    ]
    if block_bucketize_row_pos is not None
    else None
),
```
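* as a standalone illustration of the failure mode and the per-tensor fix, here is a minimal sketch; `to_device_dtype` is a hypothetical stand-in for `_fx_wrap_tensor_to_device_dtype`, and all tensors are toy data
```
from typing import List

import torch


def to_device_dtype(t: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    # same contract as _fx_wrap_tensor_to_device_dtype: move/cast one tensor
    return t.to(device=ref.device, dtype=ref.dtype)


ref = torch.tensor([1, 2, 3], dtype=torch.long)
row_pos: List[torch.Tensor] = [
    torch.tensor([0, 2], dtype=torch.int),
    torch.tensor([0, 1], dtype=torch.int),
]

# passing the whole list reproduces the bug: a Python list has no `.to()`
try:
    to_device_dtype(row_pos, ref)  # type: ignore[arg-type]
except AttributeError as err:
    print(f"list input fails: {err}")

# the fix converts each tensor individually via a list comprehension
converted = [to_device_dtype(pos, ref) for pos in row_pos]
assert all(t.dtype == ref.dtype for t in converted)
```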
* according to the previous comments, `block_bucketize_pos`'s `dtype` should be the same as `kjt._lengths`; however, that triggers the following error: {F1974430883}
* according to the operator implementation ([code pointer](https://fburl.com/code/9gyyl8h4)), `block_bucketize_pos` should instead have the same `dtype` as `kjt._values`: in the kernel, lengths are typed as `offset_t`, while values are typed as `index_t`, which is also the type of `block_bucketize_pos` (see the toy sketch after this list)
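* a toy sketch of this dtype alignment (all tensors below are made up; `lengths` and `values` mimic `kjt._lengths` and `kjt._values`, which may carry different integer dtypes)
```
import torch

# hypothetical KJT-like tensors: lengths (offset_t in the kernel) and
# values (index_t in the kernel) can have different integer dtypes
lengths = torch.tensor([2, 1, 3], dtype=torch.int)
values = torch.tensor([5, 9, 4, 7, 1, 0], dtype=torch.long)

# per-feature bucket boundary positions, as a caller might pass them in
block_bucketize_row_pos = [
    torch.tensor([0, 4, 8], dtype=torch.int),
    torch.tensor([0, 6, 12], dtype=torch.int),
]

# align each tensor with values (index_t), not with lengths (offset_t)
block_bucketize_pos = [
    pos.to(device=values.device, dtype=values.dtype)
    for pos in block_bucketize_row_pos
]
assert all(p.dtype == values.dtype for p in block_bucketize_pos)
```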
Reviewed By: dstaay-fb

Differential Revision: D68358894

fbshipit-source-id: 13303c54288c99c6cf58d550365f8d3c698c34b1
---
 torchrec/distributed/embedding_sharding.py |   5 +-
 torchrec/distributed/tests/test_utils.py   | 113 +++------------------
 2 files changed, 18 insertions(+), 100 deletions(-)

diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py
index 38bb0dd4b..0f37e71a1 100644
--- a/torchrec/distributed/embedding_sharding.py
+++ b/torchrec/distributed/embedding_sharding.py
@@ -274,7 +274,10 @@ def bucketize_kjt_before_all2all(
         batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt),
         max_B=_fx_wrap_max_B(kjt),
         block_bucketize_pos=(
-            _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
+            [
+                _fx_wrap_tensor_to_device_dtype(pos, kjt.values())
+                for pos in block_bucketize_row_pos
+            ]
             if block_bucketize_row_pos is not None
             else None
         ),
diff --git a/torchrec/distributed/tests/test_utils.py b/torchrec/distributed/tests/test_utils.py
index 3e299192e..bdffcf7a0 100644
--- a/torchrec/distributed/tests/test_utils.py
+++ b/torchrec/distributed/tests/test_utils.py
@@ -263,98 +263,6 @@ def block_bucketize_ref(
 
 
 class KJTBucketizeTest(unittest.TestCase):
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 0,
-        "CUDA is not available",
-    )
-    # pyre-ignore[56]
-    @given(
-        index_type=st.sampled_from([torch.int, torch.long]),
-        offset_type=st.sampled_from([torch.int, torch.long]),
-        world_size=st.integers(1, 129),
-        num_features=st.integers(1, 15),
-        batch_size=st.integers(1, 15),
-    )
-    @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
-    def test_kjt_bucketize_before_all2all(
-        self,
-        index_type: torch.dtype,
-        offset_type: torch.dtype,
-        world_size: int,
-        num_features: int,
-        batch_size: int,
-    ) -> None:
-        MAX_BATCH_SIZE = 15
-        MAX_LENGTH = 10
-        # max number of rows needed for a given feature to have unique row index
-        MAX_ROW_COUNT = MAX_LENGTH * MAX_BATCH_SIZE
-
-        lengths_list = [
-            random.randrange(MAX_LENGTH + 1) for _ in range(num_features * batch_size)
-        ]
-        keys_list = [f"feature_{i}" for i in range(num_features)]
-        # for each feature, generate unrepeated row indices
-        indices_lists = [
-            random.sample(
-                range(MAX_ROW_COUNT),
-                # number of indices needed is the length sum of all batches for a feature
-                sum(
-                    lengths_list[
-                        feature_offset * batch_size : (feature_offset + 1) * batch_size
-                    ]
-                ),
-            )
-            for feature_offset in range(num_features)
-        ]
-        indices_list = list(itertools.chain(*indices_lists))
-
-        weights_list = [random.randint(1, 100) for _ in range(len(indices_list))]
-
-        # for each feature, calculate the minimum block size needed to
-        # distribute all rows to the available trainers
-        block_sizes_list = [
-            (
-                math.ceil((max(feature_indices_list) + 1) / world_size)
-                if feature_indices_list
-                else 1
-            )
-            for feature_indices_list in indices_lists
-        ]
-
-        kjt = KeyedJaggedTensor(
-            keys=keys_list,
-            lengths=torch.tensor(lengths_list, dtype=offset_type)
-            .view(num_features * batch_size)
-            .cuda(),
-            values=torch.tensor(indices_list, dtype=index_type).cuda(),
-            weights=torch.tensor(weights_list, dtype=torch.float).cuda(),
-        )
-        """
-        each entry in block_sizes identifies how many hashes for each feature goes
-        to every rank; we have three featues in `self.features`
-        """
-        block_sizes = torch.tensor(block_sizes_list, dtype=index_type).cuda()
-
-        block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
-            kjt=kjt,
-            num_buckets=world_size,
-            block_sizes=block_sizes,
-        )
-
-        expected_block_bucketized_kjt = block_bucketize_ref(
-            kjt,
-            world_size,
-            block_sizes,
-        )
-
-        self.assertTrue(
-            keyed_jagged_tensor_equals(
-                block_bucketized_kjt,
-                expected_block_bucketized_kjt,
-                is_pooled_features=True,
-            )
-        )
-
     # pyre-ignore[56]
     @given(
         index_type=st.sampled_from([torch.int, torch.long]),
@@ -363,9 +271,12 @@ def test_kjt_bucketize_before_all2all(
         num_features=st.integers(1, 15),
         batch_size=st.integers(1, 15),
         variable_bucket_pos=st.booleans(),
+        device=st.sampled_from(
+            ["cpu"] + (["cuda"] if torch.cuda.device_count() > 0 else [])
+        ),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
-    def test_kjt_bucketize_before_all2all_cpu(
+    @settings(verbosity=Verbosity.verbose, max_examples=50, deadline=None)
+    def test_kjt_bucketize_before_all2all(
         self,
         index_type: torch.dtype,
         offset_type: torch.dtype,
@@ -373,6 +284,7 @@ def test_kjt_bucketize_before_all2all_cpu(
         num_features: int,
         batch_size: int,
         variable_bucket_pos: bool,
+        device: str,
     ) -> None:
         MAX_BATCH_SIZE = 15
         MAX_LENGTH = 10
@@ -423,17 +335,17 @@ def test_kjt_bucketize_before_all2all_cpu(
         kjt = KeyedJaggedTensor(
             keys=keys_list,
-            lengths=torch.tensor(lengths_list, dtype=offset_type).view(
+            lengths=torch.tensor(lengths_list, dtype=offset_type, device=device).view(
                 num_features * batch_size
             ),
-            values=torch.tensor(indices_list, dtype=index_type),
-            weights=torch.tensor(weights_list, dtype=torch.float),
+            values=torch.tensor(indices_list, dtype=index_type, device=device),
+            weights=torch.tensor(weights_list, dtype=torch.float, device=device),
         )
         """
         each entry in block_sizes identifies how many hashes for each feature goes
         to every rank; we have three featues in `self.features`
         """
-        block_sizes = torch.tensor(block_sizes_list, dtype=index_type)
+        block_sizes = torch.tensor(block_sizes_list, dtype=index_type, device=device)
         block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
             kjt=kjt,
             num_buckets=world_size,
@@ -442,7 +354,10 @@ def test_kjt_bucketize_before_all2all_cpu(
         )
 
         expected_block_bucketized_kjt = block_bucketize_ref(
-            kjt, world_size, block_sizes, "cpu"
+            kjt,
+            world_size,
+            block_sizes,
+            device,
         )
 
         self.assertTrue(