From 1ec47dce441357cfc5d95b156faea28883a9d1fc Mon Sep 17 00:00:00 2001
From: Basil Wong
Date: Thu, 30 Jan 2025 17:33:28 -0800
Subject: [PATCH] Update batched_embedding_kernel (#2702)

Summary:

After this diff stack, EmbeddingKernelConfig supports adding
embedding_table_int32_index_type and embedding_table_int32_offset_type to the
fused_params. These are used downstream to determine the dtypes of the indices
and offsets passed to split_table_batched_embeddings_ops_training.py.

Differential Revision: D66919716
---
 .../distributed/batched_embedding_kernel.py  | 50 ++++++++++++++++---
 1 file changed, 44 insertions(+), 6 deletions(-)

diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index b4db21da4..1a1509b3b 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -764,6 +764,13 @@ def __init__(
         self._feature_table_map: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
         #  `ShardedEmbeddingTable`.
@@ -805,8 +812,16 @@ def init_parameters(self) -> None:
 
     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         return self.emb_module(
-            indices=features.values().long(),
-            offsets=features.offsets().long(),
+            indices=(
+                features.values()
+                if self._embedding_table_index_type == torch.int32
+                else features.values().long()
+            ),
+            offsets=(
+                features.offsets().type(dtype=features.values().dtype)
+                if self._embedding_table_offset_type == torch.int32
+                else features.offsets().long()
+            ),
         )
 
     # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
@@ -1218,6 +1233,13 @@ def __init__(
         self._lengths_per_emb: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
         #  `ShardedEmbeddingTable`.
@@ -1270,15 +1292,31 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
             ),
         ):
             return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
+                indices=(
+                    features.values()
+                    if self._embedding_table_index_type == torch.int32
+                    else features.values().long()
+                ),
+                offsets=(
+                    features.offsets().type(dtype=features.values().dtype)
+                    if self._embedding_table_offset_type == torch.int32
+                    else features.offsets().long()
+                ),
                 per_sample_weights=weights,
                 batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
             )
         else:
             return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
+                indices=(
+                    features.values()
+                    if self._embedding_table_index_type == torch.int32
+                    else features.values().long()
+                ),
+                offsets=(
+                    features.offsets().type(dtype=features.values().dtype)
+                    if self._embedding_table_offset_type == torch.int32
+                    else features.offsets().long()
+                ),
                 per_sample_weights=weights,
             )
 
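
For reference, a minimal sketch of how the new fused_params keys drive the dtype
selection added in the patched forward() methods. The tensors below are
hypothetical stand-ins for features.values() / features.offsets() from a
KeyedJaggedTensor; the keys and int64 defaults mirror what the patched
__init__ reads via self._fused_params.get(...).

import torch

# Hypothetical fused_params carrying the new keys (assumed values for
# illustration only).
fused_params = {
    "embedding_table_index_type": torch.int32,
    "embedding_table_offset_type": torch.int32,
}
index_type = fused_params.get("embedding_table_index_type", torch.int64)
offset_type = fused_params.get("embedding_table_offset_type", torch.int64)

# Stand-ins for features.values() / features.offsets().
values = torch.tensor([3, 7, 7, 9], dtype=torch.int32)
offsets = torch.tensor([0, 2, 4], dtype=torch.int64)

# Same selection logic as the patched forward(): int32 indices are passed
# through unchanged, otherwise upcast to int64; int32 offsets are cast to
# match the values dtype, otherwise upcast to int64.
indices = values if index_type == torch.int32 else values.long()
offsets = (
    offsets.type(dtype=values.dtype)
    if offset_type == torch.int32
    else offsets.long()
)

print(indices.dtype, offsets.dtype)  # torch.int32 torch.int32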