Proportional Uneven RW Inference Sharding (pytorch#2734)
Summary:

Support bucketization-aware inference sharding in TGIF for ZCH bucket boundaries coming from training.
A "best effort" sharding is performed across bucket boundaries, proportional to the per-shard memory list.

* Added bucketization awareness to RW sharding.
* TGIF sharding now ensures at most a 1-bucket difference across equal-memory uneven shards, as opposed to the previous logic of assigning all remainder rows to the last shard (see the sketch below).
* InferRWSparseDist now checks for customized embedding_shard_metadata for uneven shards before falling back to dividing evenly.

Differential Revision: D69057627
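As an illustration of the proportional "best effort" split, here is a minimal sketch under assumptions: `split_buckets_proportionally` and its arguments are hypothetical, not TorchRec APIs. Bucket counts per shard are derived from cumulative memory proportions, so every shard boundary lands on a bucket edge, and equal-memory shards differ by at most one bucket instead of piling the remainder onto the last shard.

```python
# Minimal sketch (hypothetical helper, not TorchRec code): split `total_buckets`
# ZCH buckets across shards proportionally to a per-shard memory list, keeping
# every shard boundary on a bucket boundary.
from typing import List


def split_buckets_proportionally(total_buckets: int, memory: List[int]) -> List[int]:
    total_mem = sum(memory)
    # Cumulative memory proportions, snapped to integer bucket boundaries.
    boundaries = [
        round(total_buckets * sum(memory[: i + 1]) / total_mem)
        for i in range(len(memory))
    ]
    prev = 0
    buckets_per_shard = []
    for b in boundaries:
        buckets_per_shard.append(b - prev)
        prev = b
    return buckets_per_shard


# Equal memory: 10 buckets over 4 shards differ by at most one bucket,
# rather than [2, 2, 2, 4] from the old "remainder rows to the last shard" logic.
print(split_buckets_proportionally(10, [1, 1, 1, 1]))  # -> [2, 3, 3, 2]
# Uneven memory: shard sizes follow the memory proportions.
print(split_buckets_proportionally(10, [1, 3, 3, 1]))  # -> [1, 4, 4, 1]
```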
kausv authored and facebook-github-bot committed Feb 13, 2025
1 parent fd45bdc commit 1b6932d
Showing 2 changed files with 13 additions and 8 deletions.
1 change: 1 addition & 0 deletions torchrec/distributed/quant_state.py
@@ -441,6 +441,7 @@ def sharded_tbes_weights_spec(
         shard_sizes: List[int] = [table.local_rows, table.local_cols]
         shard_offsets: List[int] = table_metadata.shard_offsets
         s: str = "embedding_bags" if is_sqebc else "embeddings"
+        s = ("_embedding_module." if is_sqmcec else "") + s
         unsharded_fqn_weight: str = f"{module_fqn}.{s}.{table_name}.weight"
 
         sharded_fqn_weight: str = (
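For illustration, the added line prefixes the embedding module path when the table comes from a sharded quantized managed-collision embedding collection (`is_sqmcec`). The example values below are made up; only the two lines building `s` mirror the diff:

```python
# Hypothetical example values; only the two lines building `s` mirror the diff.
module_fqn, table_name = "sparse.mc_ec", "table_0"
is_sqebc, is_sqmcec = False, True

s = "embedding_bags" if is_sqebc else "embeddings"
s = ("_embedding_module." if is_sqmcec else "") + s

# For the managed-collision case, the weight FQN now resolves under
# `_embedding_module`:
print(f"{module_fqn}.{s}.{table_name}.weight")
# -> sparse.mc_ec._embedding_module.embeddings.table_0.weight
```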
20 changes: 12 additions & 8 deletions torchrec/distributed/sharding/rw_sharding.py
@@ -661,14 +661,18 @@ def __init__(
         self._feature_total_num_buckets: Optional[List[int]] = feature_total_num_buckets
 
         self.feature_block_sizes: List[int] = []
-        for i, hash_size in enumerate(feature_hash_sizes):
-            block_divisor = self._world_size
-            if feature_total_num_buckets is not None:
-                assert feature_total_num_buckets[i] % self._world_size == 0
-                block_divisor = feature_total_num_buckets[i]
-            self.feature_block_sizes.append(
-                (hash_size + block_divisor - 1) // block_divisor
-            )
+        if embedding_shard_metadata is not None:
+            assert len(embedding_shard_metadata) == len(feature_hash_sizes)
+            self.feature_block_sizes = [0] * len(feature_hash_sizes)
+        else:
+            for i, hash_size in enumerate(feature_hash_sizes):
+                block_divisor = self._world_size
+                if feature_total_num_buckets is not None:
+                    assert feature_total_num_buckets[i] % self._world_size == 0
+                    block_divisor = feature_total_num_buckets[i]
+                self.feature_block_sizes.append(
+                    (hash_size + block_divisor - 1) // block_divisor
+                )
         self.tensor_cache: Dict[
             str, Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
         ] = {}
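A rough sketch of the routing behavior this branch enables (illustrative only; `rank_for_id` is a hypothetical helper, and the actual bucketization happens inside the sparse feature distribution ops): when `embedding_shard_metadata` supplies explicit uneven shard offsets, the per-feature block sizes are left at 0 and ids are routed by those offsets, whereas the `else` branch keeps the even, fixed-block-size split.

```python
# Hedged sketch, not TorchRec code: route an embedding id either by explicit
# uneven shard offsets (from embedding_shard_metadata) or by an even block size.
import bisect
from typing import List, Optional


def rank_for_id(
    feature_id: int,
    block_size: int,
    shard_offsets: Optional[List[int]] = None,
) -> int:
    """Toy router: explicit uneven shard offsets win over an even block size."""
    if shard_offsets is not None:
        # Uneven shards: shard_offsets hold each shard's starting row,
        # e.g. [0, 300, 500, 900]; the id goes to the last shard whose
        # starting row is <= id.
        return bisect.bisect_right(shard_offsets, feature_id) - 1
    # Even shards: the `else` branch in the diff derives one block size
    # per feature and routes by integer division.
    return feature_id // block_size


print(rank_for_id(450, block_size=0, shard_offsets=[0, 300, 500, 900]))  # -> 1
print(rank_for_id(450, block_size=250))                                  # -> 1
```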
