diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index c24b912d8..d139b352c 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -759,6 +759,13 @@ def __init__(
         self._feature_table_map: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
         #  `ShardedEmbeddingTable`.
@@ -800,8 +807,16 @@ def init_parameters(self) -> None:
 
     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         return self.emb_module(
-            indices=features.values().long(),
-            offsets=features.offsets().long(),
+            indices=(
+                features.values()
+                if self._embedding_table_index_type == torch.int32
+                else features.values().long()
+            ),
+            offsets=(
+                features.offsets().type(dtype=features.values().dtype)
+                if self._embedding_table_offset_type == torch.int32
+                else features.offsets().long()
+            ),
         )
 
     # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
@@ -1213,6 +1228,13 @@ def __init__(
         self._lengths_per_emb: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
         #  `ShardedEmbeddingTable`.
@@ -1265,15 +1287,31 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
             ),
         ):
             return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
+                indices=(
+                    features.values()
+                    if self._embedding_table_index_type == torch.int32
+                    else features.values().long()
+                ),
+                offsets=(
+                    features.offsets().type(dtype=features.values().dtype)
+                    if self._embedding_table_offset_type == torch.int32
+                    else features.offsets().long()
+                ),
                 per_sample_weights=weights,
                 batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
            )
         else:
             return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
+                indices=(
+                    features.values()
+                    if self._embedding_table_index_type == torch.int32
+                    else features.values().long()
+                ),
+                offsets=(
+                    features.offsets().type(dtype=features.values().dtype)
+                    if self._embedding_table_offset_type == torch.int32
+                    else features.offsets().long()
+                ),
                 per_sample_weights=weights,
             )
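
Usage sketch (not part of the patch): the kernels above read the new "embedding_table_index_type" and "embedding_table_offset_type" keys out of GroupedEmbeddingConfig.fused_params to decide whether to skip the .long() upcast. Assuming the usual path where a sharder's fused_params are propagated into that config, a caller could opt into int32 indices/offsets roughly like this; the exact wiring of fused_params to the grouped config is an assumption.

    import torch
    from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

    # Hypothetical opt-in: keep KJT values/offsets in int32 instead of upcasting
    # to int64 inside the batched embedding kernels patched above.
    fused_params = {
        "embedding_table_index_type": torch.int32,
        "embedding_table_offset_type": torch.int32,
    }

    # Assumption: the sharder forwards fused_params into GroupedEmbeddingConfig,
    # which is where the new __init__ code reads the two keys (default: int64).
    sharder = EmbeddingBagCollectionSharder(fused_params=fused_params)

With the defaults (torch.int64) the forward paths behave exactly as before, so existing callers are unaffected; note that when the offset type is set to int32, offsets are cast to the dtype of features.values() rather than to a fixed int32.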