2D for embeddingcollection (#2737)
Summary:
Pull Request resolved: #2737

Adds support for EmbeddingCollection modules in 2D parallelism. This covers all sharding types supported for EC. Also fixes the TWRW DTensor placement in the 2D case.
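For readers unfamiliar with the module, EmbeddingCollection is torchrec's unpooled (sequence) embedding module, the counterpart of EmbeddingBagCollection; this diff makes it usable under 2D parallelism. A minimal sketch of constructing one follows; the table names, sizes, and feature names are illustrative only and not taken from this diff.

import torch

from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection

# Two small illustrative tables; real tables would come from the model config.
ec = EmbeddingCollection(
    tables=[
        EmbeddingConfig(
            name="t1",
            embedding_dim=16,
            num_embeddings=1_000,
            feature_names=["f1"],
        ),
        EmbeddingConfig(
            name="t2",
            embedding_dim=16,
            num_embeddings=1_000,
            feature_names=["f2"],
        ),
    ],
    device=torch.device("meta"),  # typically materialized later by the sharder
)
# forward() takes a KeyedJaggedTensor and returns a dict of per-feature JaggedTensors.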

Reviewed By: kausv

Differential Revision: D68980589

fbshipit-source-id: 3cb7b060a91327102d64099b9a5490bbd1ed6dc3
iamzainhuda authored and facebook-github-bot committed Feb 12, 2025
1 parent a48d0ff commit 5bbae48
Showing 3 changed files with 321 additions and 13 deletions.
torchrec/distributed/embedding.py: 27 changes (21 additions, 6 deletions)
@@ -32,6 +32,7 @@
 from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
 from torch.distributed._tensor import DTensor
 from torch.nn.parallel import DistributedDataParallel
+from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_sharding import (
     EmbeddingSharding,
     EmbeddingShardingInfo,
Expand Down Expand Up @@ -69,13 +70,16 @@
QuantizedCommCodecs,
ShardedTensor,
ShardingEnv,
ShardingEnv2D,
ShardMetadata,
)
from torchrec.distributed.utils import (
add_params_from_parameter_sharding,
convert_to_fbgemm_types,
create_global_tensor_shape_stride_from_metadata,
maybe_annotate_embedding_event,
merge_fused_params,
none_throws,
optimizer_type_to_emb_opt_type,
)
from torchrec.modules.embedding_configs import (
Expand Down Expand Up @@ -534,12 +538,9 @@ def __init__(
if table_name in self._table_names
},
)
# output parameters as DTensor in state dict
self._output_dtensor: bool = (
fused_params.get("output_dtensor", False) if fused_params else False
)

self._env = env
# output parameters as DTensor in state dict
self._output_dtensor: bool = env.output_dtensor
# TODO get rid of get_ec_index_dedup global flag
self._use_index_dedup: bool = use_index_dedup or get_ec_index_dedup()
sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
Expand Down Expand Up @@ -842,6 +843,14 @@ def _initialize_torch_state(self) -> None: # noqa
)
)
else:
shape, stride = create_global_tensor_shape_stride_from_metadata(
none_throws(self.module_sharding_plan[table_name]),
(
self._env.node_group_size
if isinstance(self._env, ShardingEnv2D)
else get_local_size(self._env.world_size)
),
)
# empty shard case
self._model_parallel_name_to_dtensor[table_name] = (
DTensor.from_local(
@@ -851,6 +860,8 @@
                            ),
                            device_mesh=self._env.device_mesh,
                            run_check=False,
+                            shape=shape,
+                            stride=stride,
                        )
                    )
            else:
@@ -861,7 +872,11 @@ def _initialize_torch_state(self) -> None: # noqa
                    ShardedTensor._init_from_local_shards(
                        local_shards,
                        self._name_to_table_size[table_name],
-                        process_group=self._env.process_group,
+                        process_group=(
+                            self._env.sharding_pg
+                            if isinstance(self._env, ShardingEnv2D)
+                            else self._env.process_group
+                        ),
                    )
                )

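The empty-shard changes above matter because a rank that holds no shard of a table cannot infer the table's global size from its local tensor, so the global shape and stride are now computed from the sharding plan and passed to DTensor.from_local explicitly (and, on the ShardedTensor path, the 2D sharding process group is used instead of the global one). Below is a minimal, self-contained sketch of the from_local pattern with an explicit shape and stride; the single-rank gloo setup and the 8x4 size are illustrative and not part of this diff.

import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import init_device_mesh

# Illustrative single-process setup so the sketch runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

# Pretend this is the local piece of an 8 x 4 embedding table. When some
# ranks hold empty shards, the global shape/stride cannot be inferred from
# the local tensor alone, so they are passed in explicitly.
local = torch.zeros(8, 4)
dt = DTensor.from_local(
    local,
    device_mesh=mesh,
    run_check=False,
    shape=torch.Size([8, 4]),
    stride=(4, 1),
)
print(dt.shape)  # torch.Size([8, 4])
dist.destroy_process_group()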
torchrec/distributed/sharding/twrw_sharding.py: 7 changes (4 additions, 3 deletions)
@@ -13,7 +13,7 @@

 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import Shard
+from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.distributed_c10d import get_process_group_ranks
 from torchrec.distributed.comm import (
     get_local_size,
Expand Down Expand Up @@ -165,10 +165,11 @@ def _shard(

dtensor_metadata = None
if self._env.output_dtensor:
placements = (Shard(0),)
dtensor_metadata = DTensorMetadata(
mesh=self._env.device_mesh,
placements=placements,
placements=(
(Replicate(), Shard(1)) if self._is_2D_parallel else (Shard(1),)
),
size=(
info.embedding_config.num_embeddings,
info.embedding_config.embedding_dim,
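The placement change above describes a TWRW-sharded table, under 2D parallelism, as Replicate() across the replication mesh dimension and Shard(1) (sharded along tensor dim 1) across the sharding dimension; the non-2D case keeps a single-dimension mesh with Shard(1). Below is a small standalone sketch of what (Replicate(), Shard(1)) means on a 2x2 device mesh; it assumes a 4-rank CPU launch (e.g. torchrun --nproc_per_node=4), the mesh dim names and tensor size are arbitrary, and it is not code from this diff.

import torch
import torch.distributed as dist
from torch.distributed._tensor import Replicate, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

# Assumes a 4-rank launch, e.g. `torchrun --nproc_per_node=4 sketch.py`.
dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("replicate", "shard"))

# Same tensor on every rank; 8 x 4 is an arbitrary illustrative size.
table = torch.arange(32, dtype=torch.float32).reshape(8, 4)

# Replicate() over the first mesh dim, Shard(1) over the second: each
# replica group sees the full tensor, each rank locally holds a dim-1 slice.
dt = distribute_tensor(table, mesh, placements=[Replicate(), Shard(1)])
print(dist.get_rank(), dt.to_local().shape)  # torch.Size([8, 2]) on every rank
dist.destroy_process_group()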