diff --git a/float8_experimental/config.py b/float8_experimental/config.py
index e190b2f..2e9eacf 100644
--- a/float8_experimental/config.py
+++ b/float8_experimental/config.py
@@ -58,10 +58,9 @@ class Float8LinearConfig:
     # option is useful for safety, but not strictly necessary.
     enable_pre_and_post_forward: bool = True
 
-    # If True, then uses a tensor subclass for the fp8 linear module's weight that
-    # implements pre/post-all-gather methods to do fp8 all-gather with FSDP2.
-    # Only dynamic scaling is supported for now.
-    enable_fsdp_fp8_all_gather: bool = False
+    # If True, then uses a tensor subclass for the float8 linear module's weight that
+    # implements pre/post-all-gather methods to do float8 all-gather with FSDP2.
+    enable_fsdp_float8_all_gather: bool = False
 
     # If True, then prior to performing the fp8 scaled mamtmul we will pad the
     # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 38e10e5..1d8519e 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -467,7 +467,7 @@ def from_float(
         # 1. weight needs to be on the correct device to create the buffers
         # 2. buffers need to be already created for the delayed scaling version
         # of the weight wrapper to be initialized
-        if config.enable_fsdp_fp8_all_gather:
+        if config.enable_fsdp_float8_all_gather:
             if config.cast_config_weight.scaling_type is TensorScalingType.DYNAMIC:
                 new_mod.weight = torch.nn.Parameter(
                     WeightWithDynamicFloat8CastTensor(
diff --git a/test/test_fsdp2/test_fsdp2.py b/test/test_fsdp2/test_fsdp2.py
index f40ad25..6d5719a 100644
--- a/test/test_fsdp2/test_fsdp2.py
+++ b/test/test_fsdp2/test_fsdp2.py
@@ -83,7 +83,7 @@ def world_size(self) -> int:
     def test_transformer_parity(self):
         self.run_subtests(
             {
-                "enable_fsdp_fp8_all_gather": [False, True],
+                "enable_fsdp_float8_all_gather": [False, True],
                 "precompute": [False, True],
                 "scaling_type_weight": [
                     TensorScalingType.DYNAMIC,
@@ -96,12 +96,12 @@ def test_transformer_parity(self):
 
     def _test_transformer_parity(
         self,
-        enable_fsdp_fp8_all_gather: bool,
+        enable_fsdp_float8_all_gather: bool,
         precompute: bool,
         scaling_type_weight: TensorScalingType,
         compile_transformer_block: bool,
     ):
-        if not enable_fsdp_fp8_all_gather and precompute:
+        if not enable_fsdp_float8_all_gather and precompute:
             return
         elif scaling_type_weight is TensorScalingType.DELAYED and precompute:
             return
@@ -110,7 +110,7 @@ def _test_transformer_parity(
         # embedding weight and output linear weight are tied but only the
        # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
        # fp8 for that tied weight, incorrectly using fp8 for the embedding.
-        weight_tying = not enable_fsdp_fp8_all_gather
+        weight_tying = not enable_fsdp_float8_all_gather
         module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
         float8_linear_config1 = Float8LinearConfig(
@@ -125,7 +125,7 @@ def _test_transformer_parity(
                 transformer_block = torch.compile(transformer_block, dynamic=False)
                 ref_module.layers.register_module(layer_id, transformer_block)
         float8_linear_config2 = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
             cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
         )
         convert_to_float8_training(
@@ -158,10 +158,10 @@ def _test_transformer_parity(
     @skip_if_lt_x_gpu(2)
     def test_transformer_memory(self):
         """Tests peak active memory in the forward and backward passes."""
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_memory(enable_fsdp_fp8_all_gather)
+        for enable_fsdp_float8_all_gather in [False, True]:
+            self._test_transformer_memory(enable_fsdp_float8_all_gather)
 
-    def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
+    def _test_transformer_memory(self, enable_fsdp_float8_all_gather: bool):
         torch.manual_seed(42)
         # Pre-run a linear forward (gemm and bias) and backward (gemm) to
         # allocate the cuBLAS workspaces before measuring the memory usage
@@ -184,7 +184,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         float8_linear_config = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
             emulate=True,
         )
         convert_to_float8_training(model, config=float8_linear_config)
@@ -231,7 +231,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # number is kept much smaller than the actual memory usage, which is on
         # the order of 100-200+ MB)
         buffer_mb = 16
-        if enable_fsdp_fp8_all_gather:
+        if enable_fsdp_float8_all_gather:
             # Non-block parameters (fp32), 3x block non-linear-weight
             # parameters (fp32) and block linear-weight parameters (fp8)
             # (current all-gather, copy-out, and next all-gather), and other
@@ -255,7 +255,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Backward:
         loss.sum().backward()
         mem_mb = self._get_peak_active_memory_mb()
-        if enable_fsdp_fp8_all_gather:
+        if enable_fsdp_float8_all_gather:
             # Non-block parameters (fp32), 2x block non-linear weight
             # parameters (fp32) and block linear-weight parameters (fp8)
             # (current copy-out and next all-gather), 1x block gradients (fp32)
@@ -294,7 +294,7 @@ def test_weight_subclass_dynamic(self):
         # Check for a single FSDP paramter group
         module_fp32 = self.init_single_module()
         float8_linear_config = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=True,
+            enable_fsdp_float8_all_gather=True,
             emulate=True,
         )
         module = convert_to_float8_training(
@@ -360,7 +360,7 @@ def get_expected_all_gather_size(module: nn.Module):
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
         float8_linear_config = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=True,
+            enable_fsdp_float8_all_gather=True,
         )
         module_fp32 = convert_to_float8_training(
             module_fp32, config=float8_linear_config
@@ -418,15 +418,15 @@ def test_fp32_fp8_single_module_parity(self):
             [False, True],
             [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED],
         )
-        for enable_fsdp_fp8_all_gather, scaling_type_weight in choices:
+        for enable_fsdp_float8_all_gather, scaling_type_weight in choices:
             float8_linear_config1 = Float8LinearConfig(
-                enable_fsdp_fp8_all_gather=False,
+                enable_fsdp_float8_all_gather=False,
                 cast_config_weight=Float8TensorCastConfig(
                     scaling_type=scaling_type_weight
                 ),
             )
             float8_linear_config2 = Float8LinearConfig(
-                enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
                 cast_config_weight=Float8TensorCastConfig(
                     scaling_type=scaling_type_weight
                 ),
@@ -466,15 +466,15 @@ def test_fp32_fp8_multi_module_parity(self):
             [False, True],
             [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED],
         )
-        for enable_fsdp_fp8_all_gather, scaling_type_weight in choices:
+        for enable_fsdp_float8_all_gather, scaling_type_weight in choices:
             float8_linear_config1 = Float8LinearConfig(
-                enable_fsdp_fp8_all_gather=False,
+                enable_fsdp_float8_all_gather=False,
                 cast_config_weight=Float8TensorCastConfig(
                     scaling_type=scaling_type_weight
                 ),
             )
             float8_linear_config2 = Float8LinearConfig(
-                enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
                 cast_config_weight=Float8TensorCastConfig(
                     scaling_type=scaling_type_weight
                 ),
@@ -545,7 +545,7 @@ def test_delayed_scaling_inplace_update(self):
         """
         module = self.init_single_module()
         float8_linear_config = Float8LinearConfig(
-            enable_fsdp_fp8_all_gather=True,
+            enable_fsdp_float8_all_gather=True,
            cast_config_weight=Float8TensorCastConfig(
                scaling_type=TensorScalingType.DELAYED
            ),
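
For reference, a minimal usage sketch of the renamed flag after this change, mirroring the test setup in the diff. The import paths, the toy model, and the surrounding setup are assumptions for illustration; only the `Float8LinearConfig` fields, `Float8TensorCastConfig`, `TensorScalingType`, and the `convert_to_float8_training(...)` call are taken from the patch above.

```python
# Hypothetical usage sketch (not part of the patch). Import paths are assumptions
# based on the files touched in this diff.
import torch.nn as nn

from float8_experimental.config import (
    Float8LinearConfig,
    Float8TensorCastConfig,
    TensorScalingType,
)
from float8_experimental.float8_linear_utils import convert_to_float8_training

# Toy model for illustration; real usage would wrap the converted model with
# FSDP2 (fully_shard) so the float8 all-gather path is actually exercised.
model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))

config = Float8LinearConfig(
    # Renamed flag: previously enable_fsdp_fp8_all_gather.
    enable_fsdp_float8_all_gather=True,
    cast_config_weight=Float8TensorCastConfig(
        scaling_type=TensorScalingType.DYNAMIC
    ),
)
convert_to_float8_training(model, config=config)
```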