Commit

2025-02-09 nightly release (1afbf08)
pytorchbot committed Feb 9, 2025
1 parent dc6796a commit 8e11829
Showing 6 changed files with 100 additions and 16 deletions.
13 changes: 11 additions & 2 deletions torchrec/distributed/embedding.py
@@ -26,6 +26,7 @@
)

import torch
from tensordict import TensorDict
from torch import distributed as dist, nn
from torch.autograd.profiler import record_function
from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
@@ -90,6 +91,7 @@
from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor
from torchrec.sparse.tensor_dict import maybe_td_to_kjt

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -1198,8 +1200,15 @@ def _compute_sequence_vbe_context(
def input_dist(
self,
ctx: EmbeddingCollectionContext,
features: KeyedJaggedTensor,
features: TypeUnion[KeyedJaggedTensor, TensorDict],
) -> Awaitable[Awaitable[KJTList]]:
need_permute: bool = True
if isinstance(features, TensorDict):
feature_keys = list(features.keys()) # pyre-ignore[6]
if self._features_order:
feature_keys = [feature_keys[i] for i in self._features_order]
need_permute = False
features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6]
if self._has_uninitialized_input_dist:
self._create_input_dist(input_feature_names=features.keys())
self._has_uninitialized_input_dist = False
Expand All @@ -1209,7 +1218,7 @@ def input_dist(
unpadded_features = features
features = pad_vbe_kjt_lengths(unpadded_features)

if self._features_order:
if need_permute and self._features_order:
features = features.permute(
self._features_order,
# pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
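The new TensorDict branch in embedding.py applies self._features_order to the key list before the conversion, which is why need_permute is set to False and the later features.permute call is skipped on that path. A tiny standalone illustration of the equivalence, with made-up feature names and order:

# Hypothetical keys and order, for illustration only.
feature_keys = ["f2", "f0", "f1"]   # keys as they arrive in the TensorDict
features_order = [1, 2, 0]          # module's expected order: f0, f1, f2

# Reordering the key list up front (as the new TensorDict branch does) ...
reordered_keys = [feature_keys[i] for i in features_order]

# ... yields the same key order a later KJT permute would have produced,
# which is why no separate permute is needed on this path.
assert reordered_keys == ["f0", "f1", "f2"]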
16 changes: 12 additions & 4 deletions torchrec/distributed/embeddingbag.py
@@ -27,6 +27,7 @@

import torch
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
from tensordict import TensorDict
from torch import distributed as dist, nn, Tensor
from torch.autograd.profiler import record_function
from torch.distributed._shard.sharded_tensor import TensorProperties
Expand Down Expand Up @@ -94,6 +95,7 @@
from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
from torchrec.sparse.tensor_dict import maybe_td_to_kjt

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -656,9 +658,7 @@ def __init__(
self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
# to support mean pooling callback hook
self._has_mean_pooling_callback: bool = (
True
if PoolingType.MEAN.value in self._pooling_type_to_rs_features
else False
PoolingType.MEAN.value in self._pooling_type_to_rs_features
)
self._dim_per_key: Optional[torch.Tensor] = None
self._kjt_key_indices: Dict[str, int] = {}
@@ -1189,8 +1189,16 @@ def _create_inverse_indices_permute_indices(

# pyre-ignore [14]
def input_dist(
self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
self,
ctx: EmbeddingBagCollectionContext,
features: Union[KeyedJaggedTensor, TensorDict],
) -> Awaitable[Awaitable[KJTList]]:
if isinstance(features, TensorDict):
feature_keys = list(features.keys()) # pyre-ignore[6]
if len(self._features_order) > 0:
feature_keys = [feature_keys[i] for i in self._features_order]
self._has_features_permute = False # feature_keys are in order
features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6]
ctx.variable_batch_per_feature = features.variable_stride_per_key()
ctx.inverse_indices = features.inverse_indices_or_none()
if self._has_uninitialized_input_dist:
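The input_dist changes in embedding.py and embeddingbag.py follow the same pattern: accept either a KeyedJaggedTensor or a TensorDict and normalize the latter into a KJT, with keys already in the module's feature order, before the pre-existing KJT-only logic runs. A minimal sketch of that dispatch using only the calls visible in this diff; the standalone helper name normalize_features is hypothetical, not part of the commit:

from typing import List, Optional, Union

from tensordict import TensorDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.sparse.tensor_dict import maybe_td_to_kjt


def normalize_features(
    features: Union[KeyedJaggedTensor, TensorDict],
    features_order: Optional[List[int]] = None,
) -> KeyedJaggedTensor:
    # TensorDict inputs are converted to a KJT whose keys already follow the
    # module's feature order, so no separate permute is needed afterwards.
    if isinstance(features, TensorDict):
        keys = list(features.keys())
        if features_order:
            keys = [keys[i] for i in features_order]
        return maybe_td_to_kjt(features, keys)
    # KJT inputs pass through unchanged and keep the pre-existing permute path.
    return features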
32 changes: 26 additions & 6 deletions torchrec/distributed/test_utils/test_sharding.py
@@ -147,6 +147,7 @@ def gen_model_and_input(
long_indices: bool = True,
global_constant_batch: bool = False,
num_inputs: int = 1,
input_type: str = "kjt", # "kjt" or "td"
) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
torch.manual_seed(0)
if dedup_feature_names:
@@ -177,9 +178,9 @@
feature_processor_modules=feature_processor_modules,
)
inputs = []
for _ in range(num_inputs):
inputs.append(
(
if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input:
for _ in range(num_inputs):
inputs.append(
cast(VariableBatchModelInputCallable, generate)(
average_batch_size=batch_size,
world_size=world_size,
@@ -188,8 +189,26 @@
weighted_tables=weighted_tables or [],
global_constant_batch=global_constant_batch,
)
if generate == ModelInput.generate_variable_batch_input
else cast(ModelInputCallable, generate)(
)
elif generate == ModelInput.generate:
for _ in range(num_inputs):
inputs.append(
ModelInput.generate(
world_size=world_size,
tables=tables,
dedup_tables=dedup_tables,
weighted_tables=weighted_tables or [],
num_float_features=num_float_features,
variable_batch_size=variable_batch_size,
batch_size=batch_size,
long_indices=long_indices,
input_type=input_type,
)
)
else:
for _ in range(num_inputs):
inputs.append(
cast(ModelInputCallable, generate)(
world_size=world_size,
tables=tables,
dedup_tables=dedup_tables,
@@ -200,7 +219,6 @@
long_indices=long_indices,
)
)
)
return (model, inputs)


@@ -297,6 +315,7 @@ def sharding_single_rank_test(
global_constant_batch: bool = False,
world_size_2D: Optional[int] = None,
node_group_size: Optional[int] = None,
input_type: str = "kjt", # "kjt" or "td"
) -> None:
with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
# Generate model & inputs.
@@ -319,6 +338,7 @@
batch_size=batch_size,
feature_processor_modules=feature_processor_modules,
global_constant_batch=global_constant_batch,
input_type=input_type,
)
global_model = global_model.to(ctx.device)
global_input = inputs[0][0].to(ctx.device)
41 changes: 41 additions & 0 deletions torchrec/distributed/tests/test_sequence_model_parallel.py
@@ -376,3 +376,44 @@ def _test_sharding(
variable_batch_per_feature=variable_batch_per_feature,
global_constant_batch=True,
)


@skip_if_asan_class
class TDSequenceModelParallelTest(SequenceModelParallelTest):

def test_sharding_variable_batch(self) -> None:
pass

def _test_sharding(
self,
sharders: List[TestEmbeddingCollectionSharder],
backend: str = "gloo",
world_size: int = 2,
local_size: Optional[int] = None,
constraints: Optional[Dict[str, ParameterConstraints]] = None,
model_class: Type[TestSparseNNBase] = TestSequenceSparseNN,
qcomms_config: Optional[QCommsConfig] = None,
apply_optimizer_in_backward_config: Optional[
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
] = None,
variable_batch_size: bool = False,
variable_batch_per_feature: bool = False,
) -> None:
self._run_multi_process_test(
callable=sharding_single_rank_test,
world_size=world_size,
local_size=local_size,
model_class=model_class,
tables=self.tables,
embedding_groups=self.embedding_groups,
sharders=sharders,
optim=EmbOptimType.EXACT_SGD,
backend=backend,
constraints=constraints,
qcomms_config=qcomms_config,
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
variable_batch_size=variable_batch_size,
variable_batch_per_feature=variable_batch_per_feature,
global_constant_batch=True,
input_type="td",
)
4 changes: 2 additions & 2 deletions
@@ -160,7 +160,7 @@ def main(

tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 1000,
num_embeddings=max(i + 1, 100) * 1000,
embedding_dim=dim_emb,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
@@ -169,7 +169,7 @@
]
weighted_tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 1000,
num_embeddings=max(i + 1, 100) * 1000,
embedding_dim=dim_emb,
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
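The table sizing above changes from (i + 1) * 1000 to max(i + 1, 100) * 1000: every table now has at least 100,000 embeddings, and the size only starts growing with the index once i exceeds 99. A quick check of the arithmetic:

# Old vs. new table sizes for a few indices.
for i in (0, 1, 50, 99, 150):
    old = (i + 1) * 1000
    new = max(i + 1, 100) * 1000
    print(i, old, new)
# i=0:   old=1000     new=100000
# i=1:   old=2000     new=100000
# i=50:  old=51000    new=100000
# i=99:  old=100000   new=100000
# i=150: old=151000   new=151000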
10 changes: 8 additions & 2 deletions torchrec/modules/embedding_modules.py
@@ -19,6 +19,7 @@
pooling_type_to_str,
)
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
from torchrec.sparse.tensor_dict import maybe_td_to_kjt


@torch.fx.wrap
@@ -218,7 +219,10 @@ def __init__(
self._feature_names: List[List[str]] = [table.feature_names for table in tables]
self.reset_parameters()

def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
def forward(
self,
features: KeyedJaggedTensor, # can also take TensorDict as input
) -> KeyedTensor:
"""
Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
@@ -229,6 +233,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
KeyedTensor
"""
flat_feature_names: List[str] = []
features = maybe_td_to_kjt(features, None)
for names in self._feature_names:
flat_feature_names.extend(names)
inverse_indices = reorder_inverse_indices(
@@ -448,7 +453,7 @@ def __init__( # noqa C901

def forward(
self,
features: KeyedJaggedTensor,
features: KeyedJaggedTensor, # can also take TensorDict as input
) -> Dict[str, JaggedTensor]:
"""
Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
Expand All @@ -461,6 +466,7 @@ def forward(
Dict[str, JaggedTensor]
"""

features = maybe_td_to_kjt(features, None)
feature_embeddings: Dict[str, JaggedTensor] = {}
jt_dict: Dict[str, JaggedTensor] = features.to_dict()
for i, emb_module in enumerate(self.embeddings.values()):
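In the unsharded modules, forward now routes through maybe_td_to_kjt, so a TensorDict can stand in where a KeyedJaggedTensor was required, while a plain KJT passes through unchanged. A minimal runnable sketch of the KJT path that this diff leaves intact; table and feature names are made up, and the TensorDict variant is only noted in the trailing comment because the exact TensorDict layout expected by maybe_td_to_kjt is not shown in this diff:

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# One small table with a single feature.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0",
            num_embeddings=100,
            embedding_dim=8,
            feature_names=["feature_0"],
        )
    ]
)

# Two samples: the first has ids [1, 2], the second has id [3].
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0"],
    values=torch.tensor([1, 2, 3]),
    lengths=torch.tensor([2, 1]),
)

pooled = ebc(features)  # KeyedTensor: one pooled 8-dim embedding per sample
# After this change, forward would also accept an equivalent tensordict.TensorDict,
# which maybe_td_to_kjt converts to a KJT before the pooling logic runs.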
