Added eager fp8 all-gather for dynamic scaling (subclass) (#247)
Summary:
Pull Request resolved: #247

This PR requires pytorch/pytorch#122908.

This PR adds an option to run the FSDP2 all-gather in fp8, focusing on correctness first.
- We add a config option `enable_fsdp_fp8_all_gather: bool = False` in `float8_experimental.config` that can be set to `True` to enable the new behavior (a short usage sketch follows this list).
- We only support `Float8DynamicLinear` for now.
- For the fp8 all-gather, we make `Float8DynamicLinear.weight` a tensor subclass. This can be thought of as a `torch.Tensor` weight that knows how to cast itself to fp8. Since the cast logic is specific to dynamic scaling (i.e. it would not be exactly the same for delayed scaling), the subclass name calls out the dynamic cast: `WeightWithDynamicFloat8CastTensor`.
- The subclass is a `__torch_dispatch__` wrapper subclass since that is the kind of subclass most compatible with `torch.compile`.
- The subclass defines `fsdp_pre_all_gather()` and `fsdp_post_all_gather()` to implement the fp8 all-gather.
- We include unit testing that covers eager-mode correctness and memory usage.
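
As a rough illustration of how these pieces fit together, here is a hedged usage sketch (not taken from this PR's test plan). Assumptions: a CUDA build of PyTorch where FSDP2's `fully_shard` lives under `torch.distributed._composable.fsdp`, `torch.distributed` already initialized (e.g. via `torchrun`), and arbitrary layer sizes.

```python
# Hedged usage sketch (not from this PR). Assumes torch.distributed is already
# initialized (e.g. via torchrun), one CUDA device per rank, and that FSDP2's
# fully_shard is importable from torch.distributed._composable.fsdp.
import torch.nn as nn

import float8_experimental.config as config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from torch.distributed._composable.fsdp import fully_shard

# Enable the new behavior *before* converting modules so that `from_float`
# wraps each weight in the fp8-cast tensor subclass.
config.enable_fsdp_fp8_all_gather = True

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))
for i, layer in enumerate(model):
    model[i] = Float8DynamicLinear.from_float(layer)

# With the flag set, FSDP2 all-gathers the fp8 data plus scale for each weight
# instead of the original-precision parameter.
for layer in model:
    fully_shard(layer)
fully_shard(model)
```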

imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: vkuzo

Differential Revision: D56192107

Pulled By: awgu

fbshipit-source-id: b9e36fde0e39ca59658ddef86932a0b1ed2ca461
awgu authored and facebook-github-bot committed Apr 16, 2024
1 parent 31877bb commit ac065d0
Showing 8 changed files with 691 additions and 34 deletions.
5 changes: 5 additions & 0 deletions float8_experimental/config.py
@@ -14,3 +14,8 @@
# this doesn't work with autocast + torch.compile + FSDP. Enabling this
# option is useful for safety, but not strictly necessary.
enable_pre_and_post_forward = True

# If True, then uses a tensor subclass for the fp8 linear module's weight that
# implements pre/post-all-gather methods to do fp8 all-gather with FSDP2.
# Only dynamic scaling is supported for now.
enable_fsdp_fp8_all_gather = False
159 changes: 134 additions & 25 deletions float8_experimental/float8_dynamic_linear.py
@@ -6,15 +6,24 @@
"""
A wrapper around a `torch.nn.Linear` module which does fp8 compute.
"""

from typing import Any, Optional, Tuple

import float8_experimental.config as config

import torch
import torch.nn as nn
import torch.utils._pytree as pytree

from float8_experimental.float8_tensor import (
Float8Tensor,
merge_mm_configs,
ScaledMMConfig,
tensor_already_casted_to_fp8,
to_fp8_no_autograd,
)
from float8_experimental.float8_utils import tensor_to_scale
from torch._prims_common import suggest_memory_format


@torch._dynamo.allow_in_graph
@@ -36,7 +45,6 @@ def forward(
@staticmethod
def backward(ctx, gradY):
if tensor_already_casted_to_fp8(gradY):
# check to early return if already casted to float8
return gradY, None
gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
fp8_tensor = to_fp8_no_autograd(
@@ -55,31 +63,15 @@ def __init__(self, **super_kwargs):
super().__init__(**super_kwargs)

def forward(self, x):
# cast x to float8_e4m3fn if not using activation hooks
x_fp8 = self.cast_to_float8_e4m3fn(x)

# cast w to float8_e4m3fn
w_fp8 = self.cast_to_float8_e4m3fn(self.weight)

x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
w_fp8 = cast_to_float8_e4m3fn(self.weight, self.forward_config)
y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

# Cast gradY to float8_e5m2 during backward if not using activation hooks
y = self.cast_to_float8_e5m2_bw(y)

y = cast_to_float8_e5m2_bw(y, self.backward_config)
return y

def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor:
if tensor_already_casted_to_fp8(inpt_tensor):
# check to early return if already casted to float8
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
return Float8Tensor.to_float8(
inpt_tensor, scale, torch.float8_e4m3fn, mm_config=self.forward_config
)

def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:
return NoopFwToFloat8E5M2Bw.apply(gradY, self.backward_config)

@classmethod
def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
"""
@@ -96,8 +88,125 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
"bias": False,
}
new_mod = cls(**super_kwargs)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
new_mod.backward_config = ScaledMMConfig(emulate, False)
if config.enable_fsdp_fp8_all_gather:
new_mod.weight = nn.Parameter(
WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
)
else:
new_mod.weight = mod.weight
new_mod.bias = mod.bias
return new_mod


def cast_to_float8_e4m3fn(
inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
) -> Float8Tensor:
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn, reduce_amax)
return Float8Tensor.to_float8(
inpt_tensor, scale, torch.float8_e4m3fn, mm_config=mm_config
)


def cast_to_float8_e5m2_bw(
gradY: torch.Tensor, mm_config: ScaledMMConfig
) -> torch.Tensor:
return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)


# FSDP pads its local tensor on dim-0. The subclass should be preserved so that
# the padded local tensor (and any transformation of it, like copying to GPU)
# is also an instance of the subclass.
_ops_to_preserve_subclass = {
torch.ops.aten.empty_like.default,
torch.ops.aten.new_zeros.default,
torch.ops.aten.slice.Tensor,
torch.ops.aten.copy_.default,
torch.ops.aten.view.default,
torch.ops.aten.as_strided.default,
torch.ops.aten._to_copy.default,
torch.ops.aten._pin_memory.default,
}


class WeightWithDynamicFloat8CastTensor(torch.Tensor):
@staticmethod
def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
return torch.Tensor._make_wrapper_subclass(
cls,
tensor.size(),
strides=tensor.stride(),
storage_offset=tensor.storage_offset(),
memory_format=suggest_memory_format(tensor),
dtype=tensor.dtype,
layout=tensor.layout,
device=tensor.device,
pin_memory=tensor.is_pinned(),
requires_grad=tensor.requires_grad,
)

def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
self._tensor = tensor
self._mm_config = mm_config

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
if func == torch.ops.aten.detach.default:
return WeightWithDynamicFloat8CastTensor(
args[0]._tensor, args[0]._mm_config
)
mm_config: Optional[ScaledMMConfig] = None

def unwrap(t):
nonlocal mm_config
if mm_config is None:
mm_config = t._mm_config
else:
mm_config = merge_mm_configs(mm_config, t._mm_config)
return t._tensor

args, kwargs = pytree.tree_map_only(
WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {})
)
out = func(*args, **kwargs)
if func not in _ops_to_preserve_subclass:
return out
return pytree.tree_map_only(
torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out
)

def __tensor_flatten__(self):
return ["_tensor"], self._mm_config

@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
mm_config = flatten_spec
return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)

def __repr__(self):
return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"

def fsdp_pre_all_gather(self, mesh):
float8_tensor = cast_to_float8_e4m3fn(
self._tensor, self._mm_config, reduce_amax=True
)
return (float8_tensor._data,), (float8_tensor._scale,)

def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[torch.Tensor] = None,
):
(data,) = all_gather_outputs
(scale,) = metadata
if out is not None:
assert isinstance(out, Float8Tensor), f"{type(out)}"
out._scale = scale
return
return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
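
The two hooks above implement a simple contract: `fsdp_pre_all_gather()` hands FSDP the raw fp8 data to all-gather plus the scale as metadata, and `fsdp_post_all_gather()` rebuilds a `Float8Tensor` from the gathered output. Below is a minimal single-process sketch of that round trip (FSDP2 normally drives these calls itself); the toy sizes and the emulated `ScaledMMConfig` are assumptions for illustration, and fp8 casting is assumed to be supported on the current device.

```python
# Minimal single-process sketch of the hook contract (FSDP2 normally drives
# these calls during its all-gather). Sizes and the emulated ScaledMMConfig
# are illustrative assumptions.
import torch

from float8_experimental.float8_dynamic_linear import WeightWithDynamicFloat8CastTensor
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig

weight = WeightWithDynamicFloat8CastTensor(
    torch.randn(16, 16), ScaledMMConfig(True, False)  # emulate=True
)

# Pre-all-gather: the fp8 data to all-gather, plus the scale as metadata
# (the amax is all-reduced across ranks when a process group is initialized).
(fp8_data,), (scale,) = weight.fsdp_pre_all_gather(None)  # FSDP passes a DeviceMesh

# Post-all-gather: rebuild a Float8Tensor view over the gathered fp8 data.
fp8_weight, to_free = weight.fsdp_post_all_gather(
    (fp8_data,), (scale,), param_dtype=torch.float32
)
assert isinstance(fp8_weight, Float8Tensor)
```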
2 changes: 1 addition & 1 deletion float8_experimental/float8_linear.py
@@ -52,7 +52,7 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
with torch.no_grad():
# Note: we need to enable distributed reduction here in order
# to match numerics between single GPU and multi GPU code
new_amax = tensor_to_amax(x, distributed_reduction=True)
new_amax = tensor_to_amax(x, reduce_amax=True)
cur_amax.fill_(new_amax)
amax_history[0] = new_amax
new_scale = amax_history_to_scale(
16 changes: 12 additions & 4 deletions float8_experimental/float8_tensor_parallel.py
@@ -1,4 +1,8 @@
import torch.nn as nn
from float8_experimental.float8_dynamic_linear import (
cast_to_float8_e4m3fn,
cast_to_float8_e5m2_bw,
)
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
@@ -25,7 +29,9 @@ def _prepare_input_fn(
input_tensor, device_mesh, input_layouts, run_check=False
)

input_tensor = mod.cast_to_float8_e4m3fn(input_tensor) # DTensor(Float8Tensor)
input_tensor = cast_to_float8_e4m3fn(
input_tensor, mod.forward_config
) # DTensor(Float8Tensor)

# transform the input layouts to the desired layouts of ColwiseParallel
if input_layouts != desired_input_layouts:
@@ -43,7 +49,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me
) # DTensor(torch.Tensor)

# fwd noop bwd cast to DTensor(Float8Tensor)
outputs = mod.cast_to_float8_e5m2_bw(outputs)
outputs = cast_to_float8_e5m2_bw(outputs, mod.backward_config)

# back to local tensor
return outputs.to_local() if use_local_output else outputs
@@ -70,7 +76,9 @@ def _prepare_input_fn(
input_tensor, device_mesh, input_layouts, run_check=False
)

input_tensor = mod.cast_to_float8_e4m3fn(input_tensor) # DTensor(Float8Tensor)
input_tensor = cast_to_float8_e4m3fn(
input_tensor, mod.forward_config
) # DTensor(Float8Tensor)

if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(
@@ -87,7 +95,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me
outputs = outputs.redistribute(placements=output_layouts, async_op=True)

# fwd noop bwd cast to DTensor(Float8Tensor)
outputs = mod.cast_to_float8_e5m2_bw(outputs)
outputs = cast_to_float8_e5m2_bw(outputs, mod.backward_config)

# back to local tensor if use_local_output is True
return outputs.to_local() if use_local_output else outputs
10 changes: 6 additions & 4 deletions float8_experimental/float8_utils.py
@@ -70,21 +70,23 @@ def amax_history_to_scale_stack(


@torch.no_grad()
def tensor_to_amax(x, distributed_reduction=False):
def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
amax = torch.max(torch.abs(x))

# If the user asked for distributed reduction, do it.
# If the user did not ask for it, assume that it will
# happen elsewhere.
if distributed_reduction and dist.is_initialized():
if reduce_amax and dist.is_initialized():
dist.all_reduce(amax, op=dist.ReduceOp.MAX)

return amax


@torch.no_grad()
def tensor_to_scale(x, float8_dtype):
amax = tensor_to_amax(x)
def tensor_to_scale(
x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False
) -> torch.Tensor:
amax = tensor_to_amax(x, reduce_amax=reduce_amax)
return amax_to_scale(amax, float8_dtype, x.dtype)


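The new `reduce_amax` flag exists because, under FSDP, each rank only holds a shard of the weight: the local amax must be max-reduced across ranks so every rank derives the same scale before casting and all-gathering. A hedged sketch follows (the shard shape is arbitrary; a process group is assumed to be initialized already, and on a single process the reduction is a no-op).

```python
# Sketch: why the fp8 all-gather path uses reduce_amax=True. The shard shape
# is arbitrary; a process group is assumed to be initialized (on a single
# process the reduction is a no-op).
import torch
import torch.distributed as dist

from float8_experimental.float8_utils import tensor_to_scale

local_shard = torch.randn(512, 1024)  # this rank's slice of the weight

# Local-only amax: may differ across ranks, giving mismatched scales.
local_scale = tensor_to_scale(local_shard, torch.float8_e4m3fn)

# Globally reduced amax: identical scale on every rank.
global_scale = tensor_to_scale(local_shard, torch.float8_e4m3fn, reduce_amax=True)

if dist.is_initialized():
    # Sanity check: every rank now holds the same scale.
    gathered = [torch.empty_like(global_scale) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, global_scale)
    assert all(torch.equal(s, gathered[0]) for s in gathered)
```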
1 change: 1 addition & 0 deletions test/test_everything.sh
@@ -9,5 +9,6 @@ pytest test/test_compile.py
./test/test_fsdp.sh
./test/test_fsdp_compile.sh
./test/test_dtensor.sh
pytest test/test_fsdp2/test_fsdp2_eager.py

echo "all tests successful"
84 changes: 84 additions & 0 deletions test/test_fsdp2/test_fsdp2_common.py
@@ -0,0 +1,84 @@
import contextlib
from typing import List, Type

import float8_experimental.config as config

import torch
import torch.distributed as dist
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history


def check_parity_no_mp(
test_cls,
ref_model: nn.Module,
ref_optim: torch.optim.Optimizer,
fsdp_model: nn.Module,
fsdp_optim: torch.optim.Optimizer,
local_inp: torch.Tensor,
module_cls: Type,
):
for iter_idx in range(10):
losses: List[torch.Tensor] = []
for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)):
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(model(local_inp).sum())
losses[-1].backward()
if model is ref_model:
for param in model.parameters():
dist.all_reduce(param.grad)
param.grad.div_(dist.get_world_size())
if module_cls is Float8Linear:
sync_float8_amax_and_scale_history(model)
optim.step()
test_cls.assertEqual(losses[0], losses[1])


def check_parity_bf16_mp(
test_cls,
ref_model: nn.Module,
ref_model_bf16: nn.Module,
ref_optim: torch.optim.Optimizer,
fsdp_model: nn.Module,
fsdp_optim: torch.optim.Optimizer,
local_inp: torch.Tensor,
module_cls: Type,
):
for iter_idx in range(10):
losses: List[torch.Tensor] = []
for model, optim in (
(ref_model_bf16, ref_optim),
(fsdp_model, fsdp_optim),
):
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(model(local_inp).sum())
losses[-1].backward()
if model is ref_model_bf16:
for param_bf16, param_fp32 in zip(
ref_model_bf16.parameters(), ref_model.parameters()
):
dist.all_reduce(param_bf16.grad)
param_bf16.grad.div_(dist.get_world_size())
param_fp32.grad = param_bf16.grad.float()
param_bf16.grad = None
if module_cls is Float8Linear:
sync_float8_amax_and_scale_history(model)
optim.step()
for param_fp32, param_bf16 in zip(
ref_model.parameters(), ref_model_bf16.parameters()
):
param_bf16.detach().copy_(param_fp32)
test_cls.assertEqual(losses[0], losses[1])


@contextlib.contextmanager
def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
prev = config.enable_fsdp_fp8_all_gather
dist.barrier()
config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
try:
yield
finally:
dist.barrier()
config.enable_fsdp_fp8_all_gather = prev
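
A hedged sketch of using the helper above in a test body; the module conversion shown is hypothetical and not from this PR, and it assumes the default process group is initialized since the context manager calls `dist.barrier()`.

```python
# Hypothetical usage of set_enable_fsdp_fp8_all_gather (requires an
# initialized default process group because of the dist.barrier() calls).
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

with set_enable_fsdp_fp8_all_gather(True):
    # Modules converted inside the context pick up the fp8 all-gather weight
    # subclass; the previous flag value is restored on exit, even on error.
    fp8_linear = Float8DynamicLinear.from_float(nn.Linear(64, 64))
```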