
Update batched_embedding_kernel
Summary:
After this diff stack:

EmbeddingKernelConfig now supports adding embedding_table_int32_index_type and embedding_table_int32_offset_type to the fused_params.

These are used downstream to determine whether the index and offset dtypes passed to split_table_batched_embeddings_ops_training.py should be int32 or int64 (see the usage sketch below).

Differential Revision: D66919716
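
For reference, a minimal usage sketch. It assumes the keys are set directly in the fused_params dict consumed in this file (embedding_table_index_type / embedding_table_offset_type, defaulting to torch.int64 when absent); the EmbeddingKernelConfig plumbing lands elsewhere in the stack.

import torch

# Hypothetical fused_params handed down to the batched embedding kernels.
# With both keys set to torch.int32, forward() passes indices through
# without the .long() upcast and casts offsets to the indices' dtype.
fused_params = {
    "embedding_table_index_type": torch.int32,
    "embedding_table_offset_type": torch.int32,
}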
basilwong authored and facebook-github-bot committed Jan 24, 2025
1 parent 4d7b7ff commit 294a6fc
Showing 1 changed file with 44 additions and 6 deletions.
50 changes: 44 additions & 6 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -759,6 +759,13 @@ def __init__(
self._feature_table_map: List[int] = []
self.table_name_to_count: Dict[str, int] = {}
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )

# pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
# `ShardedEmbeddingTable`.
@@ -800,8 +807,16 @@ def init_parameters(self) -> None:

def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
return self.emb_module(
-            indices=features.values().long(),
-            offsets=features.offsets().long(),
+            indices=(
+                features.values()
+                if self._embedding_table_index_type == torch.int32
+                else features.values().long()
+            ),
+            offsets=(
+                features.offsets().type(dtype=features.values().dtype)
+                if self._embedding_table_offset_type == torch.int32
+                else features.offsets().long()
+            ),
)

# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
@@ -1213,6 +1228,13 @@ def __init__(
self._lengths_per_emb: List[int] = []
self.table_name_to_count: Dict[str, int] = {}
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )

# pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
# `ShardedEmbeddingTable`.
@@ -1265,15 +1287,31 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
),
):
return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
+                indices=(
+                    features.values()
+                    if self._embedding_table_index_type == torch.int32
+                    else features.values().long()
+                ),
+                offsets=(
+                    features.offsets().type(dtype=features.values().dtype)
+                    if self._embedding_table_offset_type == torch.int32
+                    else features.offsets().long()
+                ),
per_sample_weights=weights,
batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
)
else:
return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
+                indices=(
+                    features.values()
+                    if self._embedding_table_index_type == torch.int32
+                    else features.values().long()
+                ),
+                offsets=(
+                    features.offsets().type(dtype=features.values().dtype)
+                    if self._embedding_table_offset_type == torch.int32
+                    else features.offsets().long()
+                ),
per_sample_weights=weights,
)
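
For clarity, a standalone sketch of the dtype-selection rule the updated forward() applies, written against plain tensors rather than KeyedJaggedTensor; the helper name and signature are illustrative only, not part of this diff.

import torch
from typing import Tuple

def select_index_offset_dtypes(
    values: torch.Tensor,
    offsets: torch.Tensor,
    index_type: torch.dtype = torch.int64,
    offset_type: torch.dtype = torch.int64,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Indices: pass through unchanged when int32 is requested, else upcast to int64.
    indices = values if index_type == torch.int32 else values.long()
    # Offsets: follow the (possibly int32) indices dtype when int32 is requested,
    # else upcast to int64.
    offsets = (
        offsets.type(dtype=values.dtype)
        if offset_type == torch.int32
        else offsets.long()
    )
    return indices, offsets

# Example: int32 inputs stay int32 end to end.
values = torch.tensor([1, 2, 3, 4], dtype=torch.int32)
offsets = torch.tensor([0, 2, 4], dtype=torch.int32)
idx, offs = select_index_offset_dtypes(values, offsets, torch.int32, torch.int32)
assert idx.dtype == torch.int32 and offs.dtype == torch.int32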

