This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Update on "[wip] make all 3 gemms in Float8Linear configurable"
Summary:

not ready for review yet

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
vkuzo committed Jul 19, 2024
1 parent e3545db commit bfe17f9
Showing 10 changed files with 139 additions and 95 deletions.
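The gist of the change: the per-tensor `_mm_config` and the per-module `forward_config`/`backward_config` are replaced by a single `LinearMMConfig` that carries a `ScaledMMConfig` for each of the three GEMMs in a linear layer's forward and backward (output, grad_input, grad_weight), and `choose_scaled_mm_config` selects the right one from the `GemmInputRole` of the two mm operands. The sketch below is a simplified, standalone illustration of that selection idea; the field names, role names, and defaults are assumptions for illustration, not the library's actual definitions.

```python
# Simplified, standalone sketch of the LinearMMConfig / choose_scaled_mm_config idea.
# Field names, role names, and defaults are illustrative, not the library's definitions.
from collections import namedtuple
from enum import Enum

# Per-GEMM knobs.
ScaledMMConfig = namedtuple(
    "ScaledMMConfig",
    ["emulate", "use_fast_accum", "pad_inner_dim"],
    defaults=[False, False, False],
)

# One ScaledMMConfig per GEMM of a linear layer:
#   output:      y     = x @ w_t
#   grad_input:  dL_dx = dL_dy @ w
#   grad_weight: dL_dw = x_t @ dL_dy
LinearMMConfig = namedtuple(
    "LinearMMConfig",
    ["output", "grad_input", "grad_weight"],
    defaults=[ScaledMMConfig(), ScaledMMConfig(), ScaledMMConfig()],
)

class GemmInputRole(Enum):
    X = "x"          # activation
    W = "w"          # weight
    DL_DY = "dl_dy"  # incoming gradient

def choose_scaled_mm_config(a_role, a_config, b_role, b_config):
    """Pick the per-GEMM config from the roles of the two mm operands."""
    roles = {a_role, b_role}
    if roles == {GemmInputRole.X, GemmInputRole.W}:
        return a_config.output if a_role is GemmInputRole.X else b_config.output
    if roles == {GemmInputRole.DL_DY, GemmInputRole.W}:
        return a_config.grad_input if a_role is GemmInputRole.DL_DY else b_config.grad_input
    if roles == {GemmInputRole.X, GemmInputRole.DL_DY}:
        return a_config.grad_weight if a_role is GemmInputRole.X else b_config.grad_weight
    raise AssertionError(f"unexpected role combination: {a_role}, {b_role}")

cfg = LinearMMConfig(output=ScaledMMConfig(use_fast_accum=True))
# x @ w_t in the forward pass -> the "output" gemm config is selected
print(choose_scaled_mm_config(GemmInputRole.X, cfg, GemmInputRole.W, cfg))
```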
3 changes: 2 additions & 1 deletion float8_experimental/__init__.py
@@ -8,12 +8,13 @@
from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
LinearMMConfig,
ScaledMMConfig,
)

# Needed to load Float8Tensor with weights_only = True
from torch.serialization import add_safe_globals

add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole])
add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig])

__all__ = ["Float8Tensor", "Float8Linear"]
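The hunk above adds `LinearMMConfig` to the `weights_only` allow-list. A minimal sketch of what this enables, assuming the package is importable; the saved object is a toy stand-in for a state dict that carries `Float8Tensor` weights:

```python
# Sketch of what the allow-list above enables: importing the package registers
# Float8Tensor, ScaledMMConfig, GemmInputRole, and LinearMMConfig as safe globals,
# so checkpoints containing them load under torch.load(..., weights_only=True).
import torch

import float8_experimental  # noqa: F401  (import runs the add_safe_globals call)
from float8_experimental.float8_tensor import GemmInputRole

torch.save({"role": GemmInputRole.X}, "role.pt")
state = torch.load("role.pt", weights_only=True)  # would raise if GemmInputRole
assert state["role"] is GemmInputRole.X           # were not allow-listed
```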
5 changes: 2 additions & 3 deletions float8_experimental/float8_linear.py
@@ -471,10 +471,9 @@ def from_float(
if config.enable_fsdp_fp8_all_gather:
if scaling_type_w is TensorScalingType.DYNAMIC:
new_mod.weight = torch.nn.Parameter(
# TODO(this PR): change callsites below to linear_mm_config
WeightWithDynamicFloat8CastTensor(
new_mod.weight,
new_mod.forward_config,
new_mod.linear_mm_config,
)
)
else:
@@ -485,7 +484,7 @@
new_mod.fp8_amax_w,
new_mod.fp8_amax_history_w,
new_mod.fp8_scale_w,
new_mod.forward_config,
new_mod.linear_mm_config,
new_mod.is_amax_initialized,
)
)
42 changes: 17 additions & 25 deletions float8_experimental/float8_ops.py
@@ -8,11 +8,7 @@
import torch

from float8_experimental.float8_python_api import addmm_float8_unwrapped
from float8_experimental.float8_tensor import (
choose_scaled_mm_config,
Float8Tensor,
ScaledMMConfig,
)
from float8_experimental.float8_tensor import choose_scaled_mm_config, Float8Tensor
from float8_experimental.float8_utils import is_row_major, pad_tensor_for_matmul

from torch.utils._pytree import tree_map
@@ -53,7 +49,7 @@ def float8_desugar_op(aten_op, args, kwargs=None):
new_data,
args[0]._scale,
args[0]._orig_dtype,
args[0]._mm_config,
args[0]._linear_mm_config,
args[0]._gemm_input_role,
)

@@ -67,7 +63,7 @@ def make_float8(data):
data,
args[0]._scale,
args[0]._orig_dtype,
args[0]._mm_config,
args[0]._linear_mm_config,
args[0]._gemm_input_role,
)

@@ -82,7 +78,7 @@ def float8_cat(aten_op, args, kwargs=None):

orig_dtype = chunked_tensors[0]._orig_dtype
scale = chunked_tensors[0]._scale
mm_config = chunked_tensors[0]._mm_config
mm_config = chunked_tensors[0]._linear_mm_config
fp8_dtype = chunked_tensors[0]._data.dtype
gemm_input_role = chunked_tensors[0]._gemm_input_role
chunk_data = []
Expand All @@ -97,7 +93,7 @@ def float8_cat(aten_op, args, kwargs=None):
chunk._scale is scale
), "Expecting all chunks to have thee same scale as a result of a split"
assert (
chunk._mm_config is mm_config
chunk._linear_mm_config is mm_config
), "Expecting all chunks to have thee same mm config as a result of a split"
assert (
chunk._data.dtype == fp8_dtype
@@ -139,16 +135,12 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):

scaled_mm_config = choose_scaled_mm_config(
a._gemm_input_role,
a._mm_config,
a._linear_mm_config,
b._gemm_input_role,
b._mm_config,
b._linear_mm_config,
)

if scaled_mm_config.pad_inner_dim:
# TODO(before land): assert this when choosing config
# assert (
# b._mm_config.pad_inner_dim
# ), "Both mm configs must have pad_inner_dim set to True"
assert a._data.size(1) == b._data.size(
0
), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
@@ -177,9 +169,9 @@ def float8_mm(aten_op, args, kwargs=None):
output_dtype = a._orig_dtype
scaled_mm_config = choose_scaled_mm_config(
a._gemm_input_role,
a._mm_config,
a._linear_mm_config,
b._gemm_input_role,
b._mm_config,
b._linear_mm_config,
)
if scaled_mm_config.emulate:
return torch.ops.aten.mm_float8_emulated(
@@ -213,9 +205,9 @@ def float8_addmm(aten_op, args, kwargs=None):
assert bias.dtype == output_dtype, "bias dtype must match output dtype"
scaled_mm_config = choose_scaled_mm_config(
a._gemm_input_role,
a._mm_config,
a._linear_mm_config,
b._gemm_input_role,
b._mm_config,
b._linear_mm_config,
)
if scaled_mm_config.emulate:
out = torch.ops.aten.mm_float8_emulated(
@@ -258,7 +250,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
args[0]._data,
args[0]._scale,
kwargs["dtype"],
args[0]._mm_config,
args[0]._linear_mm_config,
args[0]._gemm_input_role,
)

@@ -285,7 +277,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
fp8_out,
fp8_input._scale,
fp8_input._orig_dtype,
fp8_input._mm_config,
fp8_input._linear_mm_config,
fp8_input._gemm_input_role,
)

@@ -301,7 +293,7 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
fp8_out,
fp8_input._scale,
fp8_input._orig_dtype,
fp8_input._mm_config,
fp8_input._linear_mm_config,
fp8_input._gemm_input_role,
)

@@ -323,7 +315,7 @@ def index_put_fp8(aten_op, args, kwargs=None):
fp8_out,
fp8_self._scale,
fp8_self._orig_dtype,
fp8_self._mm_config,
fp8_self._linear_mm_config,
fp8_self._gemm_input_role,
)

@@ -351,7 +343,7 @@ def copy_fp8(aten_op, args, kwargs=None):
self._scale == src._scale
), "Expecting both Float8Tensors to have thee same scale"
assert (
self._mm_config == src._mm_config
self._linear_mm_config == src._linear_mm_config
), "Expecting both Float8Tensors to have thee same mm config"
assert (
self._data.dtype == src._data.dtype
Expand All @@ -364,7 +356,7 @@ def copy_fp8(aten_op, args, kwargs=None):
fp8_out,
self._scale,
self._orig_dtype,
self._mm_config,
self._linear_mm_config,
self._gemm_input_role,
)
else:
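Several handlers above consult the chosen config's `pad_inner_dim` flag before the fp8 GEMM. Below is a small standalone illustration of why zero-padding the inner dimension is safe; it is not the library's `pad_tensor_for_matmul`, and the multiple of 16 reflects the usual `torch._scaled_mm` alignment requirement.

```python
# Standalone illustration of why zero-padding the inner dim is safe; this is not
# the library's pad_tensor_for_matmul. torch._scaled_mm typically requires the
# inner (K) dimension to be a multiple of 16, hence the pad_inner_dim option.
import torch

def pad_inner_dim_to_multiple(t: torch.Tensor, multiple: int = 16) -> torch.Tensor:
    cols = t.shape[-1]
    pad_cols = (-cols) % multiple
    # pad only the last dimension with zeros; zeros contribute nothing to the matmul
    return torch.nn.functional.pad(t, (0, pad_cols))

a = torch.randn(4, 30)
b = torch.randn(30, 8)
a_p = pad_inner_dim_to_multiple(a)           # (4, 32)
b_p = pad_inner_dim_to_multiple(b.t()).t()   # pad b's inner (first) dim -> (32, 8)
torch.testing.assert_close(a @ b, a_p @ b_p)
```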
26 changes: 14 additions & 12 deletions float8_experimental/float8_tensor.py
@@ -165,7 +165,7 @@ def to_fp8_no_autograd(
local_bits,
local_scale,
x.dtype,
mm_config=linear_mm_config,
linear_mm_config=linear_mm_config,
gemm_input_role=gemm_input_role,
)
return DTensor.from_local(
@@ -181,7 +181,7 @@
bits_fp8,
x_scale,
x.dtype,
mm_config=linear_mm_config,
linear_mm_config=linear_mm_config,
gemm_input_role=gemm_input_role,
)

@@ -269,17 +269,15 @@ class Float8Tensor(torch.Tensor):
_data: torch.Tensor
_scale: torch.Tensor
_orig_dtype: torch.dtype
# TODO(before land): change this to _linear_mm_config, wanted to do that after
# initial review
_mm_config: LinearMMConfig
__slots__ = ["_data", "_scale", "_orig_dtype", "_mm_config"]
_linear_mm_config: LinearMMConfig
__slots__ = ["_data", "_scale", "_orig_dtype", "_linear_mm_config"]

def __new__(
cls,
data: torch.Tensor,
scale: torch.Tensor,
orig_dtype: torch.dtype,
mm_config: Optional[LinearMMConfig],
linear_mm_config: Optional[LinearMMConfig],
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
):
assert (
@@ -301,18 +299,21 @@ def __new__(
self._data = data
self._scale = scale
self._orig_dtype = orig_dtype
self._mm_config = mm_config if mm_config is not None else LinearMMConfig()
self._linear_mm_config = (
linear_mm_config if linear_mm_config is not None else LinearMMConfig()
)
self._gemm_input_role = gemm_input_role

return self

def __repr__(self):
return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, mm_config={self._mm_config}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}"
return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}"

def __tensor_flatten__(self):
ctx = {
"_orig_dtype": self._orig_dtype,
"_mm_config": self._mm_config,
"_linear_mm_config": self._linear_mm_config,
"_gemm_input_role": self._gemm_input_role,
}
return ["_data", "_scale"], ctx

Expand All @@ -323,7 +324,8 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride
inner_tensors["_data"],
inner_tensors["_scale"],
metadata["_orig_dtype"],
metadata["_mm_config"],
metadata["_linear_mm_config"],
metadata["_gemm_input_role"],
)

def to_original_precision(self):
@@ -346,7 +348,7 @@ def to_float8(
scale: the scale to use to convert the tensor
float8_dtype: the float8 dtype to use
amax_buffer: a buffer to store the amax value in prior to conversion
mm_config: Defines the configuration for the scaled_mm
linear_mm_config: Defines the configuration for the 3 gemms in fwd/bwd of linear
Returns:
Float8Tensor: a float8 tensor
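Besides the rename, the `__tensor_unflatten__` hunk above now threads `_gemm_input_role` through reconstruction. The toy wrapper subclass below sketches that flatten/unflatten contract; it is illustrative only, and `Float8Tensor`'s real implementation also defines full op dispatch.

```python
# Toy wrapper subclass illustrating the __tensor_flatten__/__tensor_unflatten__
# contract: the ctx dict must carry every non-tensor attribute (here `_tag`,
# playing the role of _gemm_input_role), or reconstruction silently drops state.
import torch

class MetaTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data, tag):
        self = torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype)
        self._data = data
        self._tag = tag  # non-tensor metadata
        return self

    def __tensor_flatten__(self):
        # inner tensor attribute names + ctx with all non-tensor state
        return ["_data"], {"_tag": self._tag}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        # every ctx field must be passed back to the constructor
        return MetaTensor(inner_tensors["_data"], metadata["_tag"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise NotImplementedError("op dispatch is out of scope for this sketch")

t = MetaTensor(torch.randn(2, 2), tag="x")
names, ctx = t.__tensor_flatten__()
inner = {name: getattr(t, name) for name in names}
t2 = MetaTensor.__tensor_unflatten__(inner, ctx, None, None)
assert t2._tag == t._tag
```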
29 changes: 19 additions & 10 deletions float8_experimental/float8_tensor_parallel.py
@@ -5,6 +5,7 @@
cast_to_float8_e5m2_dynamic_bw,
)
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_tensor import GemmInputRole
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
@@ -45,7 +46,9 @@ def _prepare_input_fn(
)

input_tensor = cast_to_float8_e4m3_dynamic(
input_tensor, mod.forward_config
input_tensor,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.X,
) # DTensor(Float8Tensor)

# transform the input layouts to the desired layouts of ColwiseParallel
@@ -64,7 +67,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me
) # DTensor(torch.Tensor)

# fwd noop bwd cast to DTensor(Float8Tensor)
outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.backward_config)
outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config)

# back to local tensor
return outputs.to_local() if use_local_output else outputs
@@ -96,7 +99,9 @@ def _prepare_input_fn(
)

input_tensor = cast_to_float8_e4m3_dynamic(
input_tensor, mod.forward_config
input_tensor,
mod.linear_mm_config,
gemm_input_role=GemmInputRole.X,
) # DTensor(Float8Tensor)

if input_layouts != desired_input_layouts:
@@ -114,7 +119,7 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me
outputs = outputs.redistribute(placements=output_layouts, async_op=True)

# fwd noop bwd cast to DTensor(Float8Tensor)
outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.backward_config)
outputs = cast_to_float8_e5m2_dynamic_bw(outputs, mod.linear_mm_config)

# back to local tensor if use_local_output is True
return outputs.to_local() if use_local_output else outputs
@@ -191,7 +196,9 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
)

dt_inp = cast_to_float8_e4m3_dynamic(
dt_inp, self.fwd_linear_config
dt_inp,
self.linear_mm_config,
gemm_input_role=GemmInputRole.X,
) # DTensor(Float8Tensor)
if desired_layout is not None and input_layout != desired_layout:
dt_inp = dt_inp.redistribute(placements=(desired_layout,))
@@ -207,18 +214,20 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
if self.fwd_config_submodule_fqn is not None:
fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn)
assert isinstance(fwd_linear, Float8Linear)
fwd_linear_config = fwd_linear.forward_config
linear_mm_config = fwd_linear.linear_mm_config
else:
# search for ScaledMM configs for all the submodules and make sure they are the same
for mod in module.modules():
if isinstance(mod, Float8Linear):
if fwd_linear_config is None:
fwd_linear_config = mod.forward_config
fwd_linear_config = mod.linear_mm_config
else:
assert (
fwd_linear_config == mod.forward_config
), "All the Float8Linear modules should have same forward config!"
fwd_linear_config == mod.linear_mm_config
), "All the Float8Linear modules should have same linear_mm_config!"

self.fwd_linear_config = fwd_linear_config
self.linear_mm_config = fwd_linear_config
# TODO(this PR): something is broken here, fix it
assert self.linear_mm_config is not None
super()._apply(module, device_mesh)
return module
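The `_apply` hunk above searches submodules and requires every `Float8Linear` to share one `linear_mm_config`. Below is a standalone sketch of that search with a stand-in class; it is not the library's `Float8Linear`.

```python
# Standalone sketch of the config search in _apply above: with no
# fwd_config_submodule_fqn, every Float8Linear under the module must share one
# linear_mm_config. FakeFloat8Linear is a stand-in, not the library's class.
import torch.nn as nn

class FakeFloat8Linear(nn.Linear):
    def __init__(self, in_features, out_features, linear_mm_config):
        super().__init__(in_features, out_features)
        self.linear_mm_config = linear_mm_config

def find_shared_linear_mm_config(module: nn.Module):
    found = None
    for mod in module.modules():
        if isinstance(mod, FakeFloat8Linear):
            if found is None:
                found = mod.linear_mm_config
            else:
                assert (
                    found == mod.linear_mm_config
                ), "All the Float8Linear modules should have same linear_mm_config!"
    assert found is not None, "no Float8Linear submodule found"
    return found

cfg = ("illustrative", "config")  # placeholder for a real LinearMMConfig
m = nn.Sequential(FakeFloat8Linear(4, 4, cfg), FakeFloat8Linear(4, 2, cfg))
print(find_shared_linear_mm_config(m))
```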
(Diffs for the remaining 5 changed files are not shown.)