2025-01-29 nightly release (c6f41aa)
pytorchbot committed Jan 29, 2025
1 parent fbf0175 commit 1db2788
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions torchrec/distributed/model_parallel.py
@@ -709,7 +709,9 @@ def __init__(
         )
 
         self._remap_sharding_plan(
-            plan, self._global_rank, world_size // sharding_group_size
+            plan=plan,
+            rank=self._global_rank,
+            num_nodes=world_size // sharding_group_size,
         )
         super().__init__(
             module,
@@ -733,7 +735,7 @@ def sync(self, include_optimizer_state: bool = True) -> None:
         """
         Syncs the DMP weights across the allreduce (inter) process group
-        This method is called after each forward pass to synchronize the weights of the sharded modules.
+        This method is called after each train step to synchronize the weights of the sharded modules.
         It uses the `dist.AllreduceCoalescedOptions` to perform an all-reduce operation on the weights,
         which averages the weights across all processes in the inter-process group.
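
The docstring change above pins down when the sync runs. As a rough illustration of the averaging it describes, here is a simplified, standalone sketch (not the actual coalesced implementation, which uses `dist.AllreduceCoalescedOptions`); the helper name, the `sharded_weights` list, and the `replica_pg` group are assumptions for the example:

import torch
import torch.distributed as dist


def sync_replica_weights(sharded_weights: list[torch.Tensor], replica_pg: dist.ProcessGroup) -> None:
    # Simplified stand-in for the sync described above: average each shard's
    # weights across the replica (inter) process group after a train step.
    num_replicas = dist.get_world_size(group=replica_pg)
    for weight in sharded_weights:
        # Sum this shard's weights across all of its replicas, then divide to average.
        dist.all_reduce(weight, op=dist.ReduceOp.SUM, group=replica_pg)
        weight.div_(num_replicas)
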
@@ -782,10 +784,10 @@ def _create_process_groups(
            replication process group, and allreduce process group.
         """
         peer_matrix = []
-        step = world_size // local_size
+        num_nodes = world_size // local_size
 
         for group_rank in range(world_size // local_size):
-            peers = [step * r + group_rank for r in range(local_size)]
+            peers = [num_nodes * r + group_rank for r in range(local_size)]
             peer_matrix.append(peers)
 
         mesh = DeviceMesh(
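
For intuition on the renamed `num_nodes` variable, here is a minimal standalone sketch of the peer-matrix construction, under an assumed topology of world_size = 8 and local_size = 4 (so num_nodes = 2); the numbers are illustrative only:

world_size, local_size = 8, 4          # assumed topology, not from the commit
num_nodes = world_size // local_size   # 2 nodes

peer_matrix = []
for group_rank in range(world_size // local_size):
    # Each row collects the ranks spaced num_nodes apart, starting at group_rank;
    # the rows are then stacked into the DeviceMesh.
    peers = [num_nodes * r + group_rank for r in range(local_size)]
    peer_matrix.append(peers)

print(peer_matrix)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
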
@@ -805,7 +807,9 @@ def _create_process_groups(
 
         return mesh, sharding_pg, replica_pg
 
-    def _remap_sharding_plan(self, plan: ShardingPlan, rank: int, step: int) -> None:
+    def _remap_sharding_plan(
+        self, plan: ShardingPlan, rank: int, num_nodes: int
+    ) -> None:
         """
         Remaps the sharding plan to the local replica process group ranks
         ShardingPlan is remapped inplace.
@@ -816,22 +820,22 @@ def _remap_sharding_plan(self, plan: ShardingPlan, rank: int, step: int) -> None
         Args:
             plan (ShardingPlan): The original sharding plan.
             global_rank (int): The global rank of the current process.
-            step (int): The number of nodes.
+            num_nodes (int): The number of nodes.
         """
 
-        group_start = rank % step
+        group_start = rank % num_nodes
         for key in plan.plan:
             # pyre-ignore[16]
             for _, param_sharding in plan.plan[key].items():
                 new_ranks = []
                 for shard_rank in param_sharding.ranks:
-                    new_ranks.append(shard_rank * step + group_start)
+                    new_ranks.append(shard_rank * num_nodes + group_start)
                 param_sharding.ranks = new_ranks
                 if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec):
                     shards = param_sharding.sharding_spec.shards
                     if shards is not None:
                         for shard in shards:
-                            shard_rank = shard.placement._rank * step + group_start
+                            shard_rank = shard.placement._rank * num_nodes + group_start
                             shard.placement = _remote_device(
                                 f"rank:{shard_rank}/cuda:{shard_rank % get_local_size()}"
                             )
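
And a matching sketch of the remapping formula `shard_rank * num_nodes + group_start`, again with assumed values (num_nodes = 2, caller on global rank 1) and hypothetical plan ranks 0 through 3:

num_nodes = 2                    # assumed, matching the sketch above
rank = 1                         # hypothetical global rank of the calling process
group_start = rank % num_nodes   # 1

plan_ranks = [0, 1, 2, 3]        # hypothetical ranks from a per-group sharding plan
new_ranks = [shard_rank * num_nodes + group_start for shard_rank in plan_ranks]
print(new_ranks)  # [1, 3, 5, 7], i.e. the peer row that global rank 1 falls into
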
