Revert D66521351 (#2701)
Summary:
Pull Request resolved: #2701

This diff reverts D66521351
This revert is needed to fix a lowering import error that is breaking APS tests.

Reviewed By: PoojaAg18

Differential Revision: D68528333

fbshipit-source-id: d20da94f0c1b37944f1e985f3398e0684092388b
kausv authored and facebook-github-bot committed Jan 23, 2025
1 parent f3d34fc commit d5a991b
Showing 4 changed files with 10 additions and 84 deletions.
13 changes: 2 additions & 11 deletions torchrec/distributed/embedding.py
@@ -26,7 +26,6 @@
 )

 import torch
-from tensordict import TensorDict
 from torch import distributed as dist, nn
 from torch.autograd.profiler import record_function
 from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
@@ -91,7 +90,6 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor
-from torchrec.sparse.tensor_dict import maybe_td_to_kjt

 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -1200,15 +1198,8 @@ def _compute_sequence_vbe_context(
     def input_dist(
         self,
         ctx: EmbeddingCollectionContext,
-        features: TypeUnion[KeyedJaggedTensor, TensorDict],
+        features: KeyedJaggedTensor,
     ) -> Awaitable[Awaitable[KJTList]]:
-        need_permute: bool = True
-        if isinstance(features, TensorDict):
-            feature_keys = list(features.keys())  # pyre-ignore[6]
-            if self._features_order:
-                feature_keys = [feature_keys[i] for i in self._features_order]
-            need_permute = False
-            features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         if self._has_uninitialized_input_dist:
             self._create_input_dist(input_feature_names=features.keys())
             self._has_uninitialized_input_dist = False
@@ -1218,7 +1209,7 @@ def input_dist
             unpadded_features = features
             features = pad_vbe_kjt_lengths(unpadded_features)

-        if need_permute and self._features_order:
+        if self._features_order:
             features = features.permute(
                 self._features_order,
                 # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
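Note: with the TensorDict branch removed, `input_dist` above (and the module `forward` methods further below) accept only a `KeyedJaggedTensor`. For reference, a minimal sketch of building such an input directly via the standard torchrec factory; the keys, values, and lengths are illustrative only and not taken from this diff:

import torch

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features over a batch of 3 samples: `lengths` holds the number of ids
# per (feature, sample) pair and `values` holds the concatenated ids.
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6, 7]),
    lengths=torch.tensor([2, 0, 1, 1, 1, 2]),
)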
32 changes: 6 additions & 26 deletions torchrec/distributed/test_utils/test_sharding.py
@@ -147,7 +147,6 @@ def gen_model_and_input(
     long_indices: bool = True,
     global_constant_batch: bool = False,
     num_inputs: int = 1,
-    input_type: str = "kjt",  # "kjt" or "td"
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
     torch.manual_seed(0)
     if dedup_feature_names:
@@ -178,9 +177,9 @@
         feature_processor_modules=feature_processor_modules,
     )
     inputs = []
-    if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input:
-        for _ in range(num_inputs):
-            inputs.append(
+    for _ in range(num_inputs):
+        inputs.append(
+            (
                 cast(VariableBatchModelInputCallable, generate)(
                     average_batch_size=batch_size,
                     world_size=world_size,
@@ -189,26 +188,8 @@
                     weighted_tables=weighted_tables or [],
                     global_constant_batch=global_constant_batch,
                 )
-            )
-    elif generate == ModelInput.generate:
-        for _ in range(num_inputs):
-            inputs.append(
-                ModelInput.generate(
-                    world_size=world_size,
-                    tables=tables,
-                    dedup_tables=dedup_tables,
-                    weighted_tables=weighted_tables or [],
-                    num_float_features=num_float_features,
-                    variable_batch_size=variable_batch_size,
-                    batch_size=batch_size,
-                    long_indices=long_indices,
-                    input_type=input_type,
-                )
-            )
-    else:
-        for _ in range(num_inputs):
-            inputs.append(
-                cast(ModelInputCallable, generate)(
+                if generate == ModelInput.generate_variable_batch_input
+                else cast(ModelInputCallable, generate)(
                     world_size=world_size,
                     tables=tables,
                     dedup_tables=dedup_tables,
@@ -219,6 +200,7 @@
                     long_indices=long_indices,
                 )
             )
+        )
     return (model, inputs)


@@ -315,7 +297,6 @@ def sharding_single_rank_test(
     global_constant_batch: bool = False,
     world_size_2D: Optional[int] = None,
     node_group_size: Optional[int] = None,
-    input_type: str = "kjt",  # "kjt" or "td"
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         # Generate model & inputs.
@@ -338,7 +319,6 @@
             batch_size=batch_size,
             feature_processor_modules=feature_processor_modules,
             global_constant_batch=global_constant_batch,
-            input_type=input_type,
         )
         global_model = global_model.to(ctx.device)
         global_input = inputs[0][0].to(ctx.device)
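Note: applying the removals and additions above, the post-revert input-generation block of `gen_model_and_input` reads roughly as follows (a reconstruction from the hunks in this file; keyword arguments hidden behind the collapsed context between hunks are marked with `# ...`):

    inputs = []
    for _ in range(num_inputs):
        inputs.append(
            (
                cast(VariableBatchModelInputCallable, generate)(
                    average_batch_size=batch_size,
                    world_size=world_size,
                    # ... kwargs hidden in the collapsed context ...
                    weighted_tables=weighted_tables or [],
                    global_constant_batch=global_constant_batch,
                )
                if generate == ModelInput.generate_variable_batch_input
                else cast(ModelInputCallable, generate)(
                    world_size=world_size,
                    tables=tables,
                    dedup_tables=dedup_tables,
                    # ... kwargs hidden in the collapsed context ...
                    long_indices=long_indices,
                )
            )
        )
    return (model, inputs)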
41 changes: 0 additions & 41 deletions torchrec/distributed/tests/test_sequence_model_parallel.py
@@ -376,44 +376,3 @@ def _test_sharding(
             variable_batch_per_feature=variable_batch_per_feature,
             global_constant_batch=True,
         )
-
-
-@skip_if_asan_class
-class TDSequenceModelParallelTest(SequenceModelParallelTest):
-
-    def test_sharding_variable_batch(self) -> None:
-        pass
-
-    def _test_sharding(
-        self,
-        sharders: List[TestEmbeddingCollectionSharder],
-        backend: str = "gloo",
-        world_size: int = 2,
-        local_size: Optional[int] = None,
-        constraints: Optional[Dict[str, ParameterConstraints]] = None,
-        model_class: Type[TestSparseNNBase] = TestSequenceSparseNN,
-        qcomms_config: Optional[QCommsConfig] = None,
-        apply_optimizer_in_backward_config: Optional[
-            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
-        ] = None,
-        variable_batch_size: bool = False,
-        variable_batch_per_feature: bool = False,
-    ) -> None:
-        self._run_multi_process_test(
-            callable=sharding_single_rank_test,
-            world_size=world_size,
-            local_size=local_size,
-            model_class=model_class,
-            tables=self.tables,
-            embedding_groups=self.embedding_groups,
-            sharders=sharders,
-            optim=EmbOptimType.EXACT_SGD,
-            backend=backend,
-            constraints=constraints,
-            qcomms_config=qcomms_config,
-            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
-            variable_batch_size=variable_batch_size,
-            variable_batch_per_feature=variable_batch_per_feature,
-            global_constant_batch=True,
-            input_type="td",
-        )
8 changes: 2 additions & 6 deletions torchrec/modules/embedding_modules.py
@@ -219,10 +219,7 @@ def __init__(
         self._feature_names: List[List[str]] = [table.feature_names for table in tables]
         self.reset_parameters()

-    def forward(
-        self,
-        features: KeyedJaggedTensor,  # can also take TensorDict as input
-    ) -> KeyedTensor:
+    def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
         """
         Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
         and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
@@ -453,7 +450,7 @@ def __init__( # noqa C901

     def forward(
         self,
-        features: KeyedJaggedTensor,  # can also take TensorDict as input
+        features: KeyedJaggedTensor,
     ) -> Dict[str, JaggedTensor]:
         """
         Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
@@ -466,7 +463,6 @@ def forward(
             Dict[str, JaggedTensor]
         """

-        features = maybe_td_to_kjt(features, None)
         feature_embeddings: Dict[str, JaggedTensor] = {}
         jt_dict: Dict[str, JaggedTensor] = features.to_dict()
         for i, emb_module in enumerate(self.embeddings.values()):
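Note: since `forward` no longer converts a `TensorDict` internally, a caller holding one would need to convert it before invoking the module. A hedged sketch, assuming `torchrec.sparse.tensor_dict` (whose use is removed above) is still importable, and with `td_features` (a TensorDict) and `embedding_collection` (an EmbeddingCollection instance) as placeholder names:

from torchrec.sparse.tensor_dict import maybe_td_to_kjt

# Convert up front; after this revert the module accepts only a KeyedJaggedTensor.
kjt = maybe_td_to_kjt(td_features, None)
jt_dict = embedding_collection(kjt)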
