Proportional Uneven RW Inference Sharding (#2734)
Summary:

Support bucketization-aware inference sharding in TGIF for ZCH bucket boundaries produced during training.
A "best effort" sharding is performed across bucket boundaries, proportional to the memory list.

* Added bucketization awareness to RW sharding.
* TGIF sharding now ensures at most a 1-bucket difference across equal-memory uneven shards, replacing the previous logic that assigned all remainder rows to the last shard.
* InferRWSparseDist checks for a customized embedding_shard_metadata (uneven shards) before dividing evenly.
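The "at most 1 bucket difference" behavior can be illustrated with a minimal sketch. This is a hypothetical helper (not TorchRec's actual implementation): it distributes buckets across shards proportionally to a memory list, then spreads the remainder over the largest fractional allocations so that equal-memory shards never differ by more than one bucket.

```python
# Hypothetical sketch of proportional uneven bucket sharding: buckets are
# allocated proportionally to each shard's memory, and leftover buckets are
# handed out one at a time (largest fractional remainder first) instead of
# dumping the whole remainder on the last shard.
from typing import List


def proportional_bucket_split(total_buckets: int, memory: List[int]) -> List[int]:
    total_mem = sum(memory)
    # Initial proportional allocation, rounded down.
    buckets = [total_buckets * m // total_mem for m in memory]
    # Distribute the remaining buckets by descending fractional remainder,
    # so no shard receives more than one extra bucket.
    by_remainder = sorted(
        range(len(memory)),
        key=lambda i: (total_buckets * memory[i]) % total_mem,
        reverse=True,
    )
    for i in by_remainder[: total_buckets - sum(buckets)]:
        buckets[i] += 1
    return buckets
```

For equal memory, e.g. `proportional_bucket_split(10, [1, 1, 1])`, the result is `[4, 3, 3]`: shard sizes differ by at most one bucket rather than the last shard absorbing the remainder.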

Differential Revision: D69057627
kausv authored and facebook-github-bot committed Feb 14, 2025
1 parent 7d161d9 commit dde07fb
Showing 2 changed files with 5 additions and 2 deletions.
1 change: 1 addition & 0 deletions torchrec/distributed/quant_state.py
@@ -441,6 +441,7 @@ def sharded_tbes_weights_spec(
shard_sizes: List[int] = [table.local_rows, table.local_cols]
shard_offsets: List[int] = table_metadata.shard_offsets
s: str = "embedding_bags" if is_sqebc else "embeddings"
s = ("_embedding_module." if is_sqmcec else "") + s
unsharded_fqn_weight: str = f"{module_fqn}.{s}.{table_name}.weight"

sharded_fqn_weight: str = (
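The quant_state.py change above only affects FQN construction. A standalone sketch of that logic (using assumed boolean flags mirroring the diff's `is_sqebc`/`is_sqmcec` variables) shows how the `_embedding_module.` prefix is inserted for managed-collision modules:

```python
# Sketch of the FQN construction from the diff above: the attribute name is
# "embedding_bags" or "embeddings", and for sqmcec modules the
# "_embedding_module." prefix is prepended before building the weight FQN.
def unsharded_fqn(
    module_fqn: str, table_name: str, is_sqebc: bool, is_sqmcec: bool
) -> str:
    s = "embedding_bags" if is_sqebc else "embeddings"
    s = ("_embedding_module." if is_sqmcec else "") + s
    return f"{module_fqn}.{s}.{table_name}.weight"
```

For example, a sqmcec module yields `"<module_fqn>._embedding_module.embeddings.<table_name>.weight"`, while a sqebc module yields `"<module_fqn>.embedding_bags.<table_name>.weight"`.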
6 changes: 4 additions & 2 deletions torchrec/distributed/sharding/rw_sharding.py
@@ -659,11 +659,13 @@ def __init__(
self._world_size: int = world_size
self._num_features = num_features
self._feature_total_num_buckets: Optional[List[int]] = feature_total_num_buckets

self.feature_block_sizes: List[int] = []
for i, hash_size in enumerate(feature_hash_sizes):
block_divisor = self._world_size
if feature_total_num_buckets is not None:
if (
feature_total_num_buckets is not None
and embedding_shard_metadata is None
):
assert feature_total_num_buckets[i] % self._world_size == 0
block_divisor = feature_total_num_buckets[i]
self.feature_block_sizes.append(
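The effect of the rw_sharding.py change can be summarized in a small sketch (an assumed standalone helper, not the actual class method): the block divisor switches from world size to the per-feature bucket count only when bucket counts are provided and no custom shard metadata overrides them.

```python
# Sketch of the block-size logic in the diff above: with bucket counts and
# no custom embedding_shard_metadata, block size is derived from buckets;
# otherwise it falls back to an even split across ranks.
from typing import Optional


def feature_block_size(
    hash_size: int,
    world_size: int,
    total_num_buckets: Optional[int] = None,
    has_custom_shard_metadata: bool = False,
) -> int:
    block_divisor = world_size
    if total_num_buckets is not None and not has_custom_shard_metadata:
        # Buckets must divide evenly across ranks, as asserted in the diff.
        assert total_num_buckets % world_size == 0
        block_divisor = total_num_buckets
    # Ceiling division, so the last block absorbs any shortfall.
    return (hash_size + block_divisor - 1) // block_divisor
```

With `hash_size=100` and `world_size=4`, the block size is 25; supplying `total_num_buckets=8` shrinks it to 13 (one block per bucket), and custom shard metadata restores the even split.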
