From 8f2e7debc52c71bdf14a9b8c759f0ca1cb3b3a96 Mon Sep 17 00:00:00 2001 From: Liangbei Xu Date: Tue, 14 Jan 2025 11:41:19 -0800 Subject: [PATCH] refactor sharding plan log (#2676) Summary: Refactor sharding plan stats logging for better readability and less function complexity Reviewed By: ge0405 Differential Revision: D67376920 --- torchrec/distributed/planner/stats.py | 575 +++++++++++++++----------- 1 file changed, 332 insertions(+), 243 deletions(-) diff --git a/torchrec/distributed/planner/stats.py b/torchrec/distributed/planner/stats.py index 1e4df9214..649035327 100644 --- a/torchrec/distributed/planner/stats.py +++ b/torchrec/distributed/planner/stats.py @@ -12,7 +12,18 @@ import math import statistics from collections import defaultdict -from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) from torch import nn @@ -183,38 +194,9 @@ def log( compute_kernels_to_count = defaultdict(int) compute_kernels_to_storage = defaultdict(lambda: Storage(0, 0)) - reserved_hbm_percent = ( - storage_reservation._percentage - if isinstance( - storage_reservation, - ( - FixedPercentageStorageReservation, - HeuristicalStorageReservation, - InferenceStorageReservation, - ), - ) - else 0.0 - ) - dense_storage = ( - storage_reservation._dense_storage - if isinstance( - storage_reservation, - (HeuristicalStorageReservation, InferenceStorageReservation), - ) - and storage_reservation._dense_storage is not None - else Storage(0, 0) + reserved_hbm_percent, dense_storage, kjt_storage = _compute_storage( + storage_reservation=storage_reservation ) - assert dense_storage - kjt_storage = ( - storage_reservation._kjt_storage - if isinstance( - storage_reservation, - (HeuristicalStorageReservation, InferenceStorageReservation), - ) - and storage_reservation._kjt_storage - else Storage(0, 0) - ) - assert kjt_storage for sharding_option in best_plan: fqn = sharding_option.fqn @@ -247,220 +229,30 @@ def log( stats[rank]["input_sizes"] += input_sizes[i] stats[rank]["output_sizes"] += output_sizes[i] - used_hbm = [0] * topology.world_size - used_ddr = [0] * topology.world_size - perf = [ - Perf(fwd_compute=0, fwd_comms=0, bwd_compute=0, bwd_comms=0) - for _ in range(topology.world_size) - ] - for sharding_option in best_plan: - for shard in sharding_option.shards: - shard_storage = cast(Storage, shard.storage) - rank = cast(int, shard.rank) - used_hbm[rank] += shard_storage.hbm - used_ddr[rank] += shard_storage.ddr - perf[rank] += cast(Perf, shard.perf) - - used_hbm = [hbm + dense_storage.hbm + kjt_storage.hbm for hbm in used_hbm] - used_ddr = [ddr + dense_storage.ddr + kjt_storage.ddr for ddr in used_ddr] - - table: List[List[Union[str, int]]] = [ - [ - "Rank", - "HBM (GB)", - "DDR (GB)", - "Perf (ms)", - "Input (MB)", - "Output (MB)", - "Shards", - ], - [ - "------", - "----------", - "----------", - "-----------", - "------------", - "-------------", - "--------", - ], - ] - - for rank, device in enumerate(topology.devices): - used_hbm_gb = bytes_to_gb(used_hbm[rank]) - used_hbm_ratio = ( - used_hbm[rank] / ((1 - reserved_hbm_percent) * device.storage.hbm) - if topology.compute_device == "cuda" - and ((1 - reserved_hbm_percent) * device.storage.hbm) != 0 - else 0 - ) - used_ddr_gb = bytes_to_gb(used_ddr[rank]) - used_ddr_ratio = ( - used_ddr[rank] / device.storage.ddr if device.storage.ddr > 0 else 0 - ) - for sharding_type in used_sharding_types: - if sharding_type not 
in stats[rank]["type"]: - stats[rank]["type"][sharding_type] = 0 + used_hbm, used_ddr, perf = _compute_mem_usage_and_perf( + topology=topology, + best_plan=best_plan, + dense_storage=dense_storage, + kjt_storage=kjt_storage, + ) - rank_hbm = f"{round(used_hbm_gb, 3)} ({used_hbm_ratio:.0%})" - rank_ddr = f"{round(used_ddr_gb, 3)} ({used_ddr_ratio:.0%})" - rank_perf = _format_perf_breakdown(perf[rank]) - rank_input = f"{round(stats[rank]['input_sizes'], 2)}" - rank_output = f"{round(stats[rank]['output_sizes'], 2)}" - rank_shards = " ".join( - f"{sharding_type}: {num_tables}" - for sharding_type, num_tables in sorted(stats[rank]["type"].items()) - ) - table.append( - [ - rank, - rank_hbm, - rank_ddr, - rank_perf, - rank_input, - rank_output, - rank_shards, - ] - ) - formatted_table = _format_table(table) - self._width = max(self._width, len(formatted_table[0]) + 8) + formatted_table = self._log_rank_mem_usage_and_perf( + topology=topology, + used_hbm=used_hbm, + used_ddr=used_ddr, + perf=perf, + stats=stats, + used_sharding_types=used_sharding_types, + reserved_hbm_percent=reserved_hbm_percent, + ) if debug: - param_table: List[List[Union[str, int]]] = [ - [ - "FQN", - "Sharding", - "Compute Kernel", - "Perf (ms)", - "Storage (HBM, DDR)", - "Cache Load Factor", - "Sum Pooling Factor", - "Sum Num Poolings", - "Num Indices", - "Output", - "Weighted", - "Sharder", - "Features", - "Emb Dim (CW Dim)", - "Hash Size", - "Ranks", - ], - [ - "-----", # FQN - "----------", # Sharding - "----------------", # Compute Kernel - "-----------", # Perf (ms) - "--------------------", # Storage (HBM, DDR) - "-------------------", # Cache Load Factor - "--------------------", # Sum Pooling Factor - "------------------", # Sum Num Poolings - "-------------", # Num Indices - "--------", # Output - "----------", # Weighted - "---------", # Sharder - "----------", # Features - "------------------", # Emb Dim (CW Dim) - "-----------", # Hash Size - "-------", # Ranks - ], - ] - feat_batch_sizes = [ - ( - constraints[so.name].batch_sizes - if constraints and constraints.get(so.name) - else None - ) - for so in best_plan - ] - - sharder_map: Dict[str, ModuleSharder[nn.Module]] = { - get_sharder_name(sharder.module_type): sharder - # pyre-ignore - this is a ModuleSharder below - for sharder in sharders - if sharders - } - - if include_batch_sizes := any(feat_batch_sizes): - param_table[0].append("Batch Sizes") - param_table[1].append("-------------") - for i, so in enumerate(best_plan): - ranks = sorted([cast(int, shard.rank) for shard in so.shards]) - ranks = _collapse_consecutive_ranks(ranks) - - so_perf = Perf(fwd_compute=0, fwd_comms=0, bwd_compute=0, bwd_comms=0) - for shard in so.shards: - so_perf += cast(Perf, shard.perf) - - shard_perfs = _format_perf_breakdown(so_perf) - - so_storage = Storage(hbm=0, ddr=0) - for shard in so.shards: - so_storage += cast(Storage, shard.storage) - - shard_storages = _format_storage_breakdown(so_storage) - - pooling_factor = str(round(sum(so.input_lengths), 3)) - num_poolings = ( - cast(List[float], constraints[so.name].num_poolings) - if constraints - and constraints.get(so.name) - and constraints[so.name].num_poolings - else [NUM_POOLINGS] * len(so.input_lengths) - ) - num_indices = str( - round(sum(x * y for x, y in zip(so.input_lengths, num_poolings)), 3) - ) - num_poolings = str(round(sum(num_poolings), 3)) - output = "pooled" if so.is_pooled else "sequence" - weighted = "weighted" if so.is_weighted else "unweighted" - sharder = 
sharder_map.get(get_sharder_name(type(so.module[1])), None) - sharder_name = type(sharder).__name__ - num_features = len(so.input_lengths) - embedding_dim = ( - f"{so.tensor.shape[1]} ({so.shards[0].size[1]})" - if so.sharding_type == ShardingType.COLUMN_WISE.value - or so.sharding_type == ShardingType.TABLE_COLUMN_WISE.value - or so.sharding_type == ShardingType.GRID_SHARD.value - else f"{so.tensor.shape[1]}" - ) - sharder_cache_load_factor = ( - sharder.fused_params.get("cache_load_factor") # pyre-ignore[16] - if hasattr(sharder, "fused_params") and sharder.fused_params - else None - ) - cache_load_factor = "None" - # Surfacing cache load factor does not make sense if not using uvm caching. - if so.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value: - cache_load_factor = str( - so.cache_load_factor - if so.cache_load_factor is not None - else sharder_cache_load_factor - ) - hash_size = so.tensor.shape[0] - param_table.append( - [ - so.fqn, - _get_sharding_type_abbr(so.sharding_type), - so.compute_kernel, - shard_perfs, - shard_storages, - cache_load_factor, - pooling_factor, - num_poolings, - num_indices, - output, - weighted, - sharder_name, - num_features, - embedding_dim, - hash_size, - ",".join(ranks) if sharding_plan.plan else "None", - ] - ) - if include_batch_sizes: - bs = feat_batch_sizes[i] - param_table[-1].append(_reduce_int_list(bs) if bs else "n/a") - formatted_param_table = _format_table(param_table) - self._width = max(self._width, len(formatted_param_table[0]) + 6) + formatted_param_table = self._log_sharding_plan( + best_plan=best_plan, + sharding_plan=sharding_plan, + sharders=sharders, + constraints=constraints, + ) self._stats_table.clear() self._stats_table.append("#" * self._width) @@ -848,6 +640,239 @@ def _log_compute_kernel_stats( for compute_kernel_count in compute_kernels_count: self._stats_table.append(f"# {compute_kernel_count : <{self._width-6}}#") + def _log_rank_mem_usage_and_perf( + self, + topology: Topology, + used_ddr: List[int], + used_hbm: List[int], + perf: List[Perf], + stats: Dict[int, Dict[str, Any]], + used_sharding_types: Set[str], + reserved_hbm_percent: float, + ) -> List[str]: + table: List[List[Union[str, int]]] = [ + [ + "Rank", + "HBM (GB)", + "DDR (GB)", + "Perf (ms)", + "Input (MB)", + "Output (MB)", + "Shards", + ], + [ + "------", + "----------", + "----------", + "-----------", + "------------", + "-------------", + "--------", + ], + ] + + for rank, device in enumerate(topology.devices): + used_hbm_gb = bytes_to_gb(used_hbm[rank]) + used_hbm_ratio = ( + used_hbm[rank] / ((1 - reserved_hbm_percent) * device.storage.hbm) + if topology.compute_device == "cuda" + and ((1 - reserved_hbm_percent) * device.storage.hbm) != 0 + else 0 + ) + used_ddr_gb = bytes_to_gb(used_ddr[rank]) + used_ddr_ratio = ( + used_ddr[rank] / device.storage.ddr if device.storage.ddr > 0 else 0 + ) + for sharding_type in used_sharding_types: + if sharding_type not in stats[rank]["type"]: + stats[rank]["type"][sharding_type] = 0 + + rank_hbm = f"{round(used_hbm_gb, 3)} ({used_hbm_ratio:.0%})" + rank_ddr = f"{round(used_ddr_gb, 3)} ({used_ddr_ratio:.0%})" + rank_perf = _format_perf_breakdown(perf[rank]) + rank_input = f"{round(stats[rank]['input_sizes'], 2)}" + rank_output = f"{round(stats[rank]['output_sizes'], 2)}" + rank_shards = " ".join( + f"{sharding_type}: {num_tables}" + for sharding_type, num_tables in sorted(stats[rank]["type"].items()) + ) + table.append( + [ + rank, + rank_hbm, + rank_ddr, + rank_perf, + rank_input, + rank_output, + 
rank_shards,
+                ]
+            )
+        formatted_table = _format_table(table)
+        self._width = max(self._width, len(formatted_table[0]) + 8)
+        return formatted_table
+
+    def _log_sharding_plan(
+        self,
+        best_plan: List[ShardingOption],
+        sharding_plan: ShardingPlan,
+        constraints: Optional[Dict[str, ParameterConstraints]] = None,
+        sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
+    ) -> List[str]:
+        def _get_embedding_dim(so: ShardingOption) -> str:
+            embedding_dim = (
+                f"{so.tensor.shape[1]} ({so.shards[0].size[1]})"
+                if so.sharding_type == ShardingType.COLUMN_WISE.value
+                or so.sharding_type == ShardingType.TABLE_COLUMN_WISE.value
+                or so.sharding_type == ShardingType.GRID_SHARD.value
+                else f"{so.tensor.shape[1]}"
+            )
+            return embedding_dim
+
+        def _get_num_poolings(
+            constraints: Optional[Dict[str, ParameterConstraints]], so: ShardingOption
+        ) -> List[float]:
+            num_poolings = (
+                cast(List[float], constraints[so.name].num_poolings)
+                if constraints
+                and constraints.get(so.name)
+                and constraints[so.name].num_poolings
+                else [NUM_POOLINGS] * len(so.input_lengths)
+            )
+            return num_poolings
+
+        def _get_cache_load_factor(
+            sharder: Optional[ModuleSharder[nn.Module]], so: ShardingOption
+        ) -> str:
+            sharder_cache_load_factor = (
+                sharder.fused_params.get("cache_load_factor")  # pyre-ignore[16]
+                if hasattr(sharder, "fused_params") and sharder.fused_params
+                else None
+            )
+            cache_load_factor = "None"
+            # Surfacing cache load factor does not make sense if not using uvm caching.
+            if so.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value:
+                cache_load_factor = str(
+                    so.cache_load_factor
+                    if so.cache_load_factor is not None
+                    else sharder_cache_load_factor
+                )
+            return cache_load_factor
+
+        param_table = [
+            [
+                "FQN",
+                "Sharding",
+                "Compute Kernel",
+                "Perf (ms)",
+                "Storage (HBM, DDR)",
+                "Cache Load Factor",
+                "Sum Pooling Factor",
+                "Sum Num Poolings",
+                "Num Indices",
+                "Output",
+                "Weighted",
+                "Sharder",
+                "Features",
+                "Emb Dim (CW Dim)",
+                "Hash Size",
+                "Ranks",
+            ],
+            [
+                "-----",  # FQN
+                "----------",  # Sharding
+                "----------------",  # Compute Kernel
+                "-----------",  # Perf (ms)
+                "--------------------",  # Storage (HBM, DDR)
+                "-------------------",  # Cache Load Factor
+                "--------------------",  # Sum Pooling Factor
+                "------------------",  # Sum Num Poolings
+                "-------------",  # Num Indices
+                "--------",  # Output
+                "----------",  # Weighted
+                "---------",  # Sharder
+                "----------",  # Features
+                "------------------",  # Emb Dim (CW Dim)
+                "-----------",  # Hash Size
+                "-------",  # Ranks
+            ],
+        ]
+        feat_batch_sizes = [
+            (
+                constraints[so.name].batch_sizes
+                if constraints and constraints.get(so.name)
+                else None
+            )
+            for so in best_plan
+        ]
+
+        sharder_map: Dict[str, ModuleSharder[nn.Module]] = {
+            get_sharder_name(sharder.module_type): sharder
+            # pyre-ignore - this is a ModuleSharder below
+            for sharder in sharders
+            if sharders
+        }
+
+        if include_batch_sizes := any(feat_batch_sizes):
+            param_table[0].append("Batch Sizes")
+            param_table[1].append("-------------")
+        for i, so in enumerate(best_plan):
+            ranks = sorted([cast(int, shard.rank) for shard in so.shards])
+            ranks = _collapse_consecutive_ranks(ranks)
+
+            so_perf = Perf(fwd_compute=0, fwd_comms=0, bwd_compute=0, bwd_comms=0)
+            for shard in so.shards:
+                so_perf += cast(Perf, shard.perf)
+
+            shard_perfs = _format_perf_breakdown(so_perf)
+
+            so_storage = Storage(hbm=0, ddr=0)
+            for shard in so.shards:
+                so_storage += cast(Storage, shard.storage)
+
+            shard_storages = _format_storage_breakdown(so_storage)
+
+            
pooling_factor = str(round(sum(so.input_lengths), 3)) + num_poolings = _get_num_poolings(constraints, so) + num_indices = str( + round(sum(x * y for x, y in zip(so.input_lengths, num_poolings)), 3) + ) + num_poolings = str(round(sum(num_poolings), 3)) + output = "pooled" if so.is_pooled else "sequence" + weighted = "weighted" if so.is_weighted else "unweighted" + sharder = sharder_map.get(get_sharder_name(type(so.module[1])), None) + sharder_name = type(sharder).__name__ + num_features = len(so.input_lengths) + embedding_dim = _get_embedding_dim(so) + cache_load_factor = _get_cache_load_factor(sharder, so) + hash_size = so.tensor.shape[0] + param_table.append( + # pyre-ignore[6] + [ + so.fqn, + _get_sharding_type_abbr(so.sharding_type), + so.compute_kernel, + shard_perfs, + shard_storages, + cache_load_factor, + pooling_factor, + num_poolings, + num_indices, + output, + weighted, + sharder_name, + num_features, + embedding_dim, + hash_size, + ",".join(ranks) if sharding_plan.plan else "None", + ] + ) + if include_batch_sizes: + bs = feat_batch_sizes[i] + param_table[-1].append(_reduce_int_list(bs) if bs else "n/a") + formatted_param_table = _format_table(param_table) # pyre-ignore[6] + self._width = max(self._width, len(formatted_param_table[0]) + 6) + return formatted_param_table + def _generate_rank_hbm_stats( per_rank_hbm: List[int], func: Callable[[Iterable[float]], float] @@ -909,6 +934,70 @@ def _format_perf_breakdown(perf: Perf) -> str: return f"{str(round(perf.total, 3))} ({breakdown_string})" +def _compute_storage( + storage_reservation: StorageReservation, +) -> Tuple[float, Storage, Storage]: + reserved_hbm_percent = ( + storage_reservation._percentage + if isinstance( + storage_reservation, + ( + FixedPercentageStorageReservation, + HeuristicalStorageReservation, + InferenceStorageReservation, + ), + ) + else 0.0 + ) + + dense_storage = ( + storage_reservation._dense_storage + if isinstance( + storage_reservation, + (HeuristicalStorageReservation, InferenceStorageReservation), + ) + and storage_reservation._dense_storage is not None + else Storage(0, 0) + ) + assert dense_storage + kjt_storage = ( + storage_reservation._kjt_storage + if isinstance( + storage_reservation, + (HeuristicalStorageReservation, InferenceStorageReservation), + ) + and storage_reservation._kjt_storage + else Storage(0, 0) + ) + assert kjt_storage + return reserved_hbm_percent, dense_storage, kjt_storage + + +def _compute_mem_usage_and_perf( + topology: Topology, + best_plan: List[ShardingOption], + dense_storage: Storage, + kjt_storage: Storage, +) -> Tuple[List[int], List[int], List[Perf]]: + used_hbm = [0] * topology.world_size + used_ddr = [0] * topology.world_size + perf = [ + Perf(fwd_compute=0, fwd_comms=0, bwd_compute=0, bwd_comms=0) + for _ in range(topology.world_size) + ] + for sharding_option in best_plan: + for shard in sharding_option.shards: + shard_storage = cast(Storage, shard.storage) + rank = cast(int, shard.rank) + used_hbm[rank] += shard_storage.hbm + used_ddr[rank] += shard_storage.ddr + perf[rank] += cast(Perf, shard.perf) + + used_hbm = [hbm + dense_storage.hbm + kjt_storage.hbm for hbm in used_hbm] + used_ddr = [ddr + dense_storage.ddr + kjt_storage.ddr for ddr in used_ddr] + return used_hbm, used_ddr, perf + + def _format_storage_breakdown(storage: Storage) -> str: storage_hbm = round(bytes_to_gb(storage.hbm), 3) storage_ddr = round(bytes_to_gb(storage.ddr), 3)
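
A minimal sketch of how the two extracted helpers compose once this diff is applied. This is illustrative only, not part of the patch: Topology and HeuristicalStorageReservation are existing planner types, while world_size=2, percentage=0.15, and the empty best_plan are placeholder assumptions so the snippet runs without constructing real sharding options.

    # Illustrative sketch (assumes this diff is applied); values are placeholders.
    from torchrec.distributed.planner.stats import (
        _compute_mem_usage_and_perf,
        _compute_storage,
    )
    from torchrec.distributed.planner.storage_reservations import (
        HeuristicalStorageReservation,
    )
    from torchrec.distributed.planner.types import Topology

    topology = Topology(world_size=2, compute_device="cuda")
    storage_reservation = HeuristicalStorageReservation(percentage=0.15)

    # Reserved HBM fraction plus dense/KJT storage; both storages fall back to
    # Storage(0, 0) here because reserve() has not been called on the reservation.
    reserved_hbm_percent, dense_storage, kjt_storage = _compute_storage(
        storage_reservation=storage_reservation
    )

    # Per-rank HBM/DDR usage and Perf accumulated over the plan's shards; an
    # empty plan yields only the dense + KJT baseline on every rank.
    used_hbm, used_ddr, perf = _compute_mem_usage_and_perf(
        topology=topology,
        best_plan=[],
        dense_storage=dense_storage,
        kjt_storage=kjt_storage,
    )

In EmbeddingStats.log these return values feed _log_rank_mem_usage_and_perf and, when debug is set, _log_sharding_plan, which now own the table-building logic that previously lived inline in log().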