From 294a6fc3a253cd386dbd4365bb2f0170213dfd3c Mon Sep 17 00:00:00 2001
From: Basil Wong
Date: Fri, 24 Jan 2025 01:59:39 -0800
Subject: [PATCH] Update batched_embedding_kernel

Summary:
After this diff stack, EmbeddingKernelConfig supports adding
embedding_table_int32_index_type and embedding_table_int32_offset_type to the
fused_params. These are used downstream to determine the dtypes of the indices
and offsets passed to split_table_batched_embeddings_ops_training.py.

Differential Revision: D66919716
---
 .../distributed/batched_embedding_kernel.py | 50 ++++++++++++++++---
 1 file changed, 44 insertions(+), 6 deletions(-)

diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index c24b912d8..d139b352c 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -759,6 +759,13 @@ def __init__(
         self._feature_table_map: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
         #  `ShardedEmbeddingTable`.
@@ -800,8 +807,16 @@ def init_parameters(self) -> None:
 
     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         return self.emb_module(
-            indices=features.values().long(),
-            offsets=features.offsets().long(),
+            indices=(
+                features.values()
+                if self._embedding_table_index_type == torch.int32
+                else features.values().long()
+            ),
+            offsets=(
+                features.offsets().type(dtype=features.values().dtype)
+                if self._embedding_table_offset_type == torch.int32
+                else features.offsets().long()
+            ),
         )
 
     # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
@@ -1213,6 +1228,13 @@ def __init__(
         self._lengths_per_emb: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
         #  `ShardedEmbeddingTable`.
@@ -1265,15 +1287,31 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
             ),
         ):
             return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
+                indices=(
+                    features.values()
+                    if self._embedding_table_index_type == torch.int32
+                    else features.values().long()
+                ),
+                offsets=(
+                    features.offsets().type(dtype=features.values().dtype)
+                    if self._embedding_table_offset_type == torch.int32
+                    else features.offsets().long()
+                ),
                 per_sample_weights=weights,
                 batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
             )
         else:
             return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
+                indices=(
+                    features.values()
+                    if self._embedding_table_index_type == torch.int32
+                    else features.values().long()
+                ),
+                offsets=(
+                    features.offsets().type(dtype=features.values().dtype)
+                    if self._embedding_table_offset_type == torch.int32
+                    else features.offsets().long()
+                ),
                 per_sample_weights=weights,
             )
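
Example usage (a minimal sketch, not part of this patch): the snippet below shows how a
caller might opt in to int32 indices/offsets by placing the new keys in fused_params and
passing them to a sharder. The EmbeddingBagCollectionSharder entry point and the exact
plumbing of fused_params into GroupedEmbeddingConfig are assumptions here; only the two
dictionary keys and their torch.int64 defaults come from the diff above.

    import torch
    from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

    # Keys read by the batched embedding kernels in this patch; when a key is
    # unset, the dtype defaults to torch.int64 and the kernel upcasts the
    # corresponding tensor via .long().
    fused_params = {
        "embedding_table_index_type": torch.int32,
        "embedding_table_offset_type": torch.int32,
    }

    # Assumed wiring: the sharder forwards fused_params down to the grouped
    # embedding configs, where the kernels above consume the two keys.
    sharder = EmbeddingBagCollectionSharder(fused_params=fused_params)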