This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

rename config.enable_fsdp_fp8_all_gather to use float8 (#332)
Summary:
Pull Request resolved: #332

old: `enable_fsdp_fp8_all_gather`
new: `enable_fsdp_float8_all_gather`

This matches the `float8` naming used elsewhere in the codebase.
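
For downstream users, the change is a one-line rename at each call site. A minimal sketch (assuming `Float8LinearConfig` is importable from `float8_experimental.config`, matching the file path in this diff):

```python
# Minimal before/after sketch for call sites of this config.
# The import path is inferred from float8_experimental/config.py in this diff.
from float8_experimental.config import Float8LinearConfig

# old (before this commit):
# config = Float8LinearConfig(enable_fsdp_fp8_all_gather=True)

# new (after this commit):
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
```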

Reviewed By: weifengpy

Differential Revision: D60252072

fbshipit-source-id: 5e240f0a97b647aa4f43a63dab3f03f68fd3b405
vkuzo authored and facebook-github-bot committed Jul 25, 2024
1 parent 701647b commit eff4ba6
Showing 3 changed files with 24 additions and 25 deletions.
float8_experimental/config.py: 3 additions, 4 deletions

@@ -58,10 +58,9 @@ class Float8LinearConfig:
# option is useful for safety, but not strictly necessary.
enable_pre_and_post_forward: bool = True

- # If True, then uses a tensor subclass for the fp8 linear module's weight that
- # implements pre/post-all-gather methods to do fp8 all-gather with FSDP2.
- # Only dynamic scaling is supported for now.
- enable_fsdp_fp8_all_gather: bool = False
+ # If True, then uses a tensor subclass for the float8 linear module's weight that
+ # implements pre/post-all-gather methods to do float8 all-gather with FSDP2.
+ enable_fsdp_float8_all_gather: bool = False

# If True, then prior to performing the fp8 scaled mamtmul we will pad the
# inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls
float8_experimental/float8_linear.py: 1 addition, 1 deletion

@@ -467,7 +467,7 @@ def from_float(
# 1. weight needs to be on the correct device to create the buffers
# 2. buffers need to be already created for the delayed scaling version
# of the weight wrapper to be initialized
- if config.enable_fsdp_fp8_all_gather:
+ if config.enable_fsdp_float8_all_gather:
if config.cast_config_weight.scaling_type is TensorScalingType.DYNAMIC:
new_mod.weight = torch.nn.Parameter(
WeightWithDynamicFloat8CastTensor(
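As the branch above shows, the rename does not change behavior: with the flag set and dynamic weight scaling (the default), each converted linear's weight is wrapped in `WeightWithDynamicFloat8CastTensor`, which provides the pre/post-all-gather hooks FSDP2 needs for float8 all-gather. A hedged sketch of what that produces, mirroring `test_weight_subclass_dynamic` below (import paths other than `float8_experimental.config` are assumptions):

```python
# Hedged sketch: with enable_fsdp_float8_all_gather=True and dynamic weight
# scaling, converted linears carry a WeightWithDynamicFloat8CastTensor weight.
# The convert_to_float8_training(model, config=...) call and the emulate=True
# kwarg are taken from the diff; the non-config import paths are assumptions.
import torch.nn as nn

from float8_experimental.config import Float8LinearConfig
from float8_experimental.float8_linear_utils import convert_to_float8_training  # assumed path
from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor  # assumed path

config = Float8LinearConfig(enable_fsdp_float8_all_gather=True, emulate=True)
model = convert_to_float8_training(nn.Sequential(nn.Linear(16, 16)), config=config)
assert isinstance(model[0].weight, WeightWithDynamicFloat8CastTensor)
```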
test/test_fsdp2/test_fsdp2.py: 20 additions, 20 deletions

@@ -83,7 +83,7 @@ def world_size(self) -> int:
def test_transformer_parity(self):
self.run_subtests(
{
"enable_fsdp_fp8_all_gather": [False, True],
"enable_fsdp_float8_all_gather": [False, True],
"precompute": [False, True],
"scaling_type_weight": [
TensorScalingType.DYNAMIC,
@@ -96,12 +96,12 @@ def test_transformer_parity(self):

def _test_transformer_parity(
self,
- enable_fsdp_fp8_all_gather: bool,
+ enable_fsdp_float8_all_gather: bool,
precompute: bool,
scaling_type_weight: TensorScalingType,
compile_transformer_block: bool,
):
- if not enable_fsdp_fp8_all_gather and precompute:
+ if not enable_fsdp_float8_all_gather and precompute:
return
elif scaling_type_weight is TensorScalingType.DELAYED and precompute:
return
@@ -110,7 +110,7 @@ def _test_transformer_parity(
# embedding weight and output linear weight are tied but only the
# latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
# fp8 for that tied weight, incorrectly using fp8 for the embedding.
- weight_tying = not enable_fsdp_fp8_all_gather
+ weight_tying = not enable_fsdp_float8_all_gather
module = self.init_transformer(weight_tying=weight_tying).cuda()
ref_module = copy.deepcopy(module)
float8_linear_config1 = Float8LinearConfig(
@@ -125,7 +125,7 @@ def _test_transformer_parity(
transformer_block = torch.compile(transformer_block, dynamic=False)
ref_module.layers.register_module(layer_id, transformer_block)
float8_linear_config2 = Float8LinearConfig(
- enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+ enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
)
convert_to_float8_training(
@@ -158,10 +158,10 @@ def _test_transformer_parity(
@skip_if_lt_x_gpu(2)
def test_transformer_memory(self):
"""Tests peak active memory in the forward and backward passes."""
- for enable_fsdp_fp8_all_gather in [False, True]:
-     self._test_transformer_memory(enable_fsdp_fp8_all_gather)
+ for enable_fsdp_float8_all_gather in [False, True]:
+     self._test_transformer_memory(enable_fsdp_float8_all_gather)

- def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
+ def _test_transformer_memory(self, enable_fsdp_float8_all_gather: bool):
torch.manual_seed(42)
# Pre-run a linear forward (gemm and bias) and backward (gemm) to
# allocate the cuBLAS workspaces before measuring the memory usage
@@ -184,7 +184,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
# Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
# requirement to use a smaller activation size
float8_linear_config = Float8LinearConfig(
- enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+ enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
emulate=True,
)
convert_to_float8_training(model, config=float8_linear_config)
@@ -231,7 +231,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
# number is kept much smaller than the actual memory usage, which is on
# the order of 100-200+ MB)
buffer_mb = 16
- if enable_fsdp_fp8_all_gather:
+ if enable_fsdp_float8_all_gather:
# Non-block parameters (fp32), 3x block non-linear-weight
# parameters (fp32) and block linear-weight parameters (fp8)
# (current all-gather, copy-out, and next all-gather), and other
@@ -255,7 +255,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
# Backward:
loss.sum().backward()
mem_mb = self._get_peak_active_memory_mb()
- if enable_fsdp_fp8_all_gather:
+ if enable_fsdp_float8_all_gather:
# Non-block parameters (fp32), 2x block non-linear weight
# parameters (fp32) and block linear-weight parameters (fp8)
# (current copy-out and next all-gather), 1x block gradients (fp32)
@@ -294,7 +294,7 @@ def test_weight_subclass_dynamic(self):
# Check for a single FSDP paramter group
module_fp32 = self.init_single_module()
float8_linear_config = Float8LinearConfig(
- enable_fsdp_fp8_all_gather=True,
+ enable_fsdp_float8_all_gather=True,
emulate=True,
)
module = convert_to_float8_training(
@@ -360,7 +360,7 @@ def get_expected_all_gather_size(module: nn.Module):
module_fp32 = self.init_single_module()
ref_module = copy.deepcopy(module_fp32)
float8_linear_config = Float8LinearConfig(
- enable_fsdp_fp8_all_gather=True,
+ enable_fsdp_float8_all_gather=True,
)
module_fp32 = convert_to_float8_training(
module_fp32, config=float8_linear_config
@@ -418,15 +418,15 @@ def test_fp32_fp8_single_module_parity(self):
[False, True],
[TensorScalingType.DYNAMIC, TensorScalingType.DELAYED],
)
- for enable_fsdp_fp8_all_gather, scaling_type_weight in choices:
+ for enable_fsdp_float8_all_gather, scaling_type_weight in choices:
float8_linear_config1 = Float8LinearConfig(
- enable_fsdp_fp8_all_gather=False,
+ enable_fsdp_float8_all_gather=False,
cast_config_weight=Float8TensorCastConfig(
scaling_type=scaling_type_weight
),
)
float8_linear_config2 = Float8LinearConfig(
- enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+ enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
cast_config_weight=Float8TensorCastConfig(
scaling_type=scaling_type_weight
),
@@ -466,15 +466,15 @@ def test_fp32_fp8_multi_module_parity(self):
[False, True],
[TensorScalingType.DYNAMIC, TensorScalingType.DELAYED],
)
- for enable_fsdp_fp8_all_gather, scaling_type_weight in choices:
+ for enable_fsdp_float8_all_gather, scaling_type_weight in choices:
float8_linear_config1 = Float8LinearConfig(
- enable_fsdp_fp8_all_gather=False,
+ enable_fsdp_float8_all_gather=False,
cast_config_weight=Float8TensorCastConfig(
scaling_type=scaling_type_weight
),
)
float8_linear_config2 = Float8LinearConfig(
- enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather,
+ enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
cast_config_weight=Float8TensorCastConfig(
scaling_type=scaling_type_weight
),
@@ -545,7 +545,7 @@ def test_delayed_scaling_inplace_update(self):
"""
module = self.init_single_module()
float8_linear_config = Float8LinearConfig(
- enable_fsdp_fp8_all_gather=True,
+ enable_fsdp_float8_all_gather=True,
cast_config_weight=Float8TensorCastConfig(
scaling_type=TensorScalingType.DELAYED
),
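Taken together, the tests above exercise the end-to-end flow: convert the model to float8 training with the renamed flag set, then apply FSDP2's `fully_shard`. A hedged sketch of that flow, assuming an initialized process group, CUDA devices, and the import paths shown (all assumptions beyond what this diff confirms):

```python
# Hedged sketch of the flow exercised by test_fsdp2.py above. Requires an
# initialized distributed process group and CUDA devices; the import paths for
# fully_shard and convert_to_float8_training are assumptions.
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # FSDP2 prototype API (assumed path)

from float8_experimental.config import Float8LinearConfig
from float8_experimental.float8_linear_utils import convert_to_float8_training  # assumed path

model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(3)]).cuda()
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
model = convert_to_float8_training(model, config=config)

# Shard submodules first, then the root, so each becomes its own FSDP parameter
# group; with the flag enabled, weight all-gathers run in float8 instead of the
# original precision.
for layer in model:
    fully_shard(layer)
fully_shard(model)
```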
