From c1700d76efda66e160fae0b8fbe95836362e0022 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Mon, 15 Jul 2024 13:53:33 -0700
Subject: [PATCH] [wip] make all 3 gemms in Float8Linear configurable

Summary:

not ready for review yet

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 1ef7943cbe21517d40975c69d0be4a719c7bf20d
Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/315
---
 float8_experimental/__init__.py             |   8 +-
 float8_experimental/float8_dynamic_utils.py |  30 ++--
 float8_experimental/float8_linear.py        |  47 ++++--
 float8_experimental/float8_ops.py           |  97 ++++++++++---
 float8_experimental/float8_tensor.py        | 149 +++++++++++++++-----
 float8_experimental/fsdp_utils.py           |  10 +-
 float8_experimental/inference.py            |   2 +
 test/test_base.py                           | 103 +++++++++-----
 8 files changed, 323 insertions(+), 123 deletions(-)

diff --git a/float8_experimental/__init__.py b/float8_experimental/__init__.py
index 8822796..85f4b32 100644
--- a/float8_experimental/__init__.py
+++ b/float8_experimental/__init__.py
@@ -5,11 +5,15 @@
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
 from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    GemmInputRole,
+    ScaledMMConfig,
+)
 
 # Needed to load Float8Tensor with weights_only = True
 from torch.serialization import add_safe_globals
 
-add_safe_globals([Float8Tensor, ScaledMMConfig])
+add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole])
 
 __all__ = ["Float8Tensor", "Float8Linear"]
diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py
index 9ad76f7..ccf6e5b 100644
--- a/float8_experimental/float8_dynamic_utils.py
+++ b/float8_experimental/float8_dynamic_utils.py
@@ -10,7 +10,8 @@
 
 from float8_experimental.float8_tensor import (
     Float8Tensor,
-    ScaledMMConfig,
+    GemmInputRole,
+    LinearMMConfig,
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
@@ -28,9 +29,9 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     def forward(
         ctx,
         tensor,
-        mm_config: ScaledMMConfig,
+        linear_mm_config: LinearMMConfig,
     ):
-        ctx.mm_config = mm_config
+        ctx.linear_mm_config = linear_mm_config
         return tensor
 
     @staticmethod
@@ -39,21 +40,34 @@ def backward(ctx, gradY):
             return gradY, None
         gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
         fp8_tensor = to_fp8_no_autograd(
-            gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
+            gradY,
+            gradY_scale,
+            e5m2_dtype,
+            linear_mm_config=ctx.linear_mm_config,
+            gemm_input_role=GemmInputRole.DL_DY,
         )
         return fp8_tensor, None
 
 
 def cast_to_float8_e4m3_dynamic(
-    inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
+    inpt_tensor: torch.Tensor,
+    linear_mm_config: LinearMMConfig,
+    reduce_amax: bool = False,
+    gemm_input_role: GemmInputRole = GemmInputRole.X,
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
     scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
+    return Float8Tensor.to_float8(
+        inpt_tensor,
+        scale,
+        e4m3_dtype,
+        linear_mm_config=linear_mm_config,
+        gemm_input_role=gemm_input_role,
+    )
 
 
 def cast_to_float8_e5m2_dynamic_bw(
-    gradY: torch.Tensor, mm_config: ScaledMMConfig
+    gradY: torch.Tensor, linear_mm_config: LinearMMConfig
 ) -> torch.Tensor:
-    return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
+    return NoopFwToFloat8E5M2Bw.apply(gradY, linear_mm_config)
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 7850738..53cab47 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -23,6 +23,8 @@
 
 from float8_experimental.float8_tensor import (
     Float8Tensor,
+    GemmInputRole,
+    LinearMMConfig,
     ScaledMMConfig,
     to_fp8_no_autograd,
 )
@@ -85,12 +87,12 @@ def forward(
         fp8_scale_dL_dY,
         scale_fn_name,
         is_amax_initialized,
-        mm_config: ScaledMMConfig,
+        linear_mm_config: LinearMMConfig,
     ):
         ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
         ctx.scale_fn_name = scale_fn_name
         ctx.is_amax_initialized = is_amax_initialized
-        ctx.mm_config = mm_config
+        ctx.linear_mm_config = linear_mm_config
         return tensor
 
     @staticmethod
@@ -113,7 +115,11 @@ def backward(ctx, go):
             fp8_amax_dL_dY.fill_(tensor_to_amax(go))
 
         res = to_fp8_no_autograd(
-            go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
+            go,
+            fp8_scale_dL_dY,
+            e5m2_dtype,
+            linear_mm_config=ctx.linear_mm_config,
+            gemm_input_role=GemmInputRole.DL_DY,
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
@@ -192,12 +198,18 @@ def __init__(self, *args, **kwargs):
 
         self.create_buffers()
 
-        # Defines the behavior of the matmul in the forward and backward pass
-        self.forward_config = ScaledMMConfig(
-            emulate, True if not emulate else False, False, config.pad_inner_dim
-        )
-        self.backward_config = ScaledMMConfig(
-            emulate, False, False, config.pad_inner_dim
+        # TODO(future): user level configuration of gemms
+        self.linear_mm_config = LinearMMConfig(
+            # x
+            ScaledMMConfig(
+                emulate, True if not emulate else False, False, config.pad_inner_dim
+            ),
+            # w
+            ScaledMMConfig(
+                emulate, True if not emulate else False, False, config.pad_inner_dim
+            ),
+            # dL_dY
+            ScaledMMConfig(emulate, False, False, config.pad_inner_dim),
         )
 
         # Note: is_amax_initialized is not a buffer to avoid data dependent
@@ -308,11 +320,12 @@ def cast_x_to_float8(
                 self.fp8_scale_x,
                 e4m3_dtype,
                 self.fp8_amax_x,
-                self.forward_config,
+                linear_mm_config=self.linear_mm_config,
+                gemm_input_role=GemmInputRole.X,
             )
         else:
             assert self.scaling_type_x is TensorScalingType.DYNAMIC
-            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
+            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config)
         return x_fp8
 
     def cast_w_to_float8(
@@ -339,14 +352,17 @@ def cast_w_to_float8(
                 self.fp8_scale_w,
                 e4m3_dtype,
                 self.fp8_amax_w,
-                self.forward_config,
+                linear_mm_config=self.linear_mm_config,
+                gemm_input_role=GemmInputRole.W,
             )
         else:
             assert self.scaling_type_w is TensorScalingType.DYNAMIC
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 w_fp8 = self.weight
             else:
-                w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
+                w_fp8 = cast_to_float8_e4m3_dynamic(
+                    self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
+                )
         return w_fp8
 
     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
@@ -359,11 +375,11 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
                 self.fp8_scale_dL_dY,
                 scale_fn_name,
                 self.is_amax_initialized,
-                self.backward_config,
+                self.linear_mm_config,
             )
         else:
             assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
+            y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config)
         return y
 
     def float8_pre_forward(self, x):
@@ -455,6 +471,7 @@ def from_float(
         if config.enable_fsdp_fp8_all_gather:
             if scaling_type_w is TensorScalingType.DYNAMIC:
                 new_mod.weight = torch.nn.Parameter(
+                    # TODO(this PR): change callsites below to linear_mm_config
                     WeightWithDynamicFloat8CastTensor(
                         new_mod.weight,
                         new_mod.forward_config,
diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py
index 3a50cc8..7859cfe 100644
--- a/float8_experimental/float8_ops.py
+++ b/float8_experimental/float8_ops.py
@@ -9,8 +9,8 @@
 
 from float8_experimental.float8_python_api import addmm_float8_unwrapped
 from float8_experimental.float8_tensor import (
+    choose_scaled_mm_config,
     Float8Tensor,
-    merge_mm_configs,
     ScaledMMConfig,
 )
 from float8_experimental.float8_utils import is_row_major, pad_tensor_for_matmul
@@ -50,7 +50,11 @@ def decorator(func):
 def float8_desugar_op(aten_op, args, kwargs=None):
     new_data = aten_op(args[0]._data, *args[1:], **kwargs)
     return Float8Tensor(
-        new_data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+        new_data,
+        args[0]._scale,
+        args[0]._orig_dtype,
+        args[0]._mm_config,
+        args[0]._gemm_input_role,
     )
 
 
@@ -60,7 +64,11 @@ def float8_split(aten_op, args, kwargs=None):
 
     def make_float8(data):
         return Float8Tensor(
-            data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+            data,
+            args[0]._scale,
+            args[0]._orig_dtype,
+            args[0]._mm_config,
+            args[0]._gemm_input_role,
         )
 
     out = map(make_float8, new_data_tensors)
@@ -76,6 +84,7 @@ def float8_cat(aten_op, args, kwargs=None):
     scale = chunked_tensors[0]._scale
     mm_config = chunked_tensors[0]._mm_config
     fp8_dtype = chunked_tensors[0]._data.dtype
+    gemm_input_role = chunked_tensors[0]._gemm_input_role
     chunk_data = []
     for chunk in chunked_tensors:
         assert isinstance(
@@ -93,11 +102,14 @@ def float8_cat(aten_op, args, kwargs=None):
         assert (
             chunk._data.dtype == fp8_dtype
         ), "Expecting all chunks to be of the same dtype as a result of a split"
+        assert (
+            chunk._gemm_input_role is gemm_input_role
+        ), "Expecting all chunks to have the same gemm_input_role as a result of a split"
         chunk_data.append(chunk._data.view(torch.uint8))
 
     new_data = aten_op(chunk_data, *args[1:], **kwargs)
     new_data = new_data.view(fp8_dtype)
-    return Float8Tensor(new_data, scale, orig_dtype, mm_config)
+    return Float8Tensor(new_data, scale, orig_dtype, mm_config, gemm_input_role)
 
 
 @implements([aten.sum.dim_IntList])
@@ -125,10 +137,18 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     a_scale = a._scale
     b_data = b._data
 
-    if a._mm_config.pad_inner_dim:
-        assert (
-            b._mm_config.pad_inner_dim
-        ), "Both mm configs must have pad_inner_dim set to True"
+    scaled_mm_config = choose_scaled_mm_config(
+        a._gemm_input_role,
+        a._mm_config,
+        b._gemm_input_role,
+        b._mm_config,
+    )
+
+    if scaled_mm_config.pad_inner_dim:
+        # TODO(before land): assert this when choosing config
+        # assert (
+        #     b._mm_config.pad_inner_dim
+        # ), "Both mm configs must have pad_inner_dim set to True"
         assert a._data.size(1) == b._data.size(
             0
         ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
@@ -155,10 +175,13 @@ def float8_mm(aten_op, args, kwargs=None):
     )
     a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
-    a_mm_config: ScaledMMConfig = a._mm_config
-    b_mm_config: ScaledMMConfig = b._mm_config
-    mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
-    if mm_config.emulate:
+    scaled_mm_config = choose_scaled_mm_config(
+        a._gemm_input_role,
+        a._mm_config,
+        b._gemm_input_role,
+        b._mm_config,
+    )
+    if scaled_mm_config.emulate:
         return torch.ops.aten.mm_float8_emulated(
             a._data, a._scale, b._data, b._scale, output_dtype
         )
@@ -170,7 +193,7 @@ def float8_mm(aten_op, args, kwargs=None):
         output_dtype,
         output_scale=None,
         bias=None,
-        use_fast_accum=mm_config.use_fast_accum,
+        use_fast_accum=scaled_mm_config.use_fast_accum,
     )
     return tensor_out
 
@@ -188,10 +211,13 @@ def float8_addmm(aten_op, args, kwargs=None):
     a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
     assert bias.dtype == output_dtype, "bias dtype must match output dtype"
-    a_mm_config: ScaledMMConfig = a._mm_config
-    b_mm_config: ScaledMMConfig = b._mm_config
-    mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
-    if mm_config.emulate:
+    scaled_mm_config = choose_scaled_mm_config(
+        a._gemm_input_role,
+        a._mm_config,
+        b._gemm_input_role,
+        b._mm_config,
+    )
+    if scaled_mm_config.emulate:
         out = torch.ops.aten.mm_float8_emulated(
             a._data, a._scale, b._data, b._scale, output_dtype
         )
@@ -204,7 +230,7 @@ def float8_addmm(aten_op, args, kwargs=None):
         output_dtype,
         output_scale=None,
         bias=bias,
-        use_fast_accum=mm_config.use_fast_accum,
+        use_fast_accum=scaled_mm_config.use_fast_accum,
     )
     return tensor_out
 
@@ -229,7 +255,11 @@ def autocast_to_copy(aten_op, args, kwargs=None):
         torch.bfloat16,
     }, "Only support floating point conversion for autocast w/ Float8Tensor"
     return Float8Tensor(
-        args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._mm_config
+        args[0]._data,
+        args[0]._scale,
+        kwargs["dtype"],
+        args[0]._mm_config,
+        args[0]._gemm_input_role,
     )
 
 
@@ -252,7 +282,11 @@ def allgather_fp8(aten_op, args, kwargs=None):
         fp8_data = fp8_data.contiguous()
     fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
     return Float8Tensor(
-        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
+        fp8_out,
+        fp8_input._scale,
+        fp8_input._orig_dtype,
+        fp8_input._mm_config,
+        fp8_input._gemm_input_role,
     )
 
 
@@ -264,7 +298,11 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
     fp8_data = fp8_input._data
     fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
    return Float8Tensor(
-        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
+        fp8_out,
+        fp8_input._scale,
+        fp8_input._orig_dtype,
+        fp8_input._mm_config,
+        fp8_input._gemm_input_role,
     )
 
 
@@ -282,7 +320,11 @@ def index_put_fp8(aten_op, args, kwargs=None):
     fp8_values_data = fp8_values._data
     fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs)
     return Float8Tensor(
-        fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config
+        fp8_out,
+        fp8_self._scale,
+        fp8_self._orig_dtype,
+        fp8_self._mm_config,
+        fp8_self._gemm_input_role,
     )
 
 
@@ -314,7 +356,16 @@ def copy_fp8(aten_op, args, kwargs=None):
         assert (
             self._data.dtype == src._data.dtype
        ), "Expecting both Float8Tensors to be of the same dtypet"
+        assert (
+            self._gemm_input_role == src._gemm_input_role
+        ), "Expecting both Float8Tensors to have the same gemm_input_role"
         fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs)
-        return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config)
+        return Float8Tensor(
+            fp8_out,
+            self._scale,
+            self._orig_dtype,
+            self._mm_config,
+            self._gemm_input_role,
+        )
     else:
         raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor")
diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index 26d4688..6f58fdb 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import enum
 from collections import namedtuple
 from typing import Dict, Optional
 
@@ -18,6 +19,31 @@
 
 aten = torch.ops.aten
 
+#
+# A note on configuration of float8 logic in a linear
+# TODO(future): move all the configs to separate file
+#
+# There are three gemms in a forward + backward of a Linear layer:
+#
+# 1. x @ w_t = y (forward pass)
+# 2. dL_dY @ w = dL_dX (backward pass)
+# 3. x_t @ dL_dY = dL_dW (backward pass)
+#
+# In the formulas above, there are:
+# A. six input tensors (x, x_t, w, w_t, dL_dY, dL_dY_t).
+#    - Note that dL_dY_t is implied because of memory format requirements
+#      of float8 gemms
+# B. three output tensors (y, dL_dX, dL_dW)
+#
+# We want each input tensor, gemm, and output tensor to be configurable.
+# The state of this configuration today is:
+#
+# i. pairs of input tensors (non-t and t variants) have their scaling
+#    configurable via the scaling_type_{x_w_dL_dY} arguments to Float8Linear
+# ii. each gemm + output is configurable via ScaledMMConfig, which is not user facing
+# iii. LinearMMConfig is a container for the three ScaledMMConfig objects needed
+#      to configure all three gemms, also not user facing
+
 
 # ScaledMMConfig is a namedtuple that defines the configuration for the scaled_mm in the forward and backward pass.
 # emulate: whether to emulate the matmuls in fp32
@@ -30,27 +56,55 @@
     defaults=[False, False, False, False],
 )
 
+# The object below exists for convenience, to allow Float8Tensor to use
+# the right config based on which gemm from `y`, `dL_dX`, `dL_dW` is
+# being called.
+LinearMMConfig = namedtuple(
+    "LinearMMConfig",
+    ["y", "dL_dX", "dL_dW"],
+    defaults=[
+        ScaledMMConfig(False, True, False, False),
+        ScaledMMConfig(False, False, False, False),
+        ScaledMMConfig(False, False, False, False),
+    ],
+)
 
-def merge_mm_configs(
-    a_mm_config: ScaledMMConfig, b_mm_config: ScaledMMConfig
-) -> ScaledMMConfig:
-    """Merges two mm_configs together emulate behavior must match,
-    However we want to use_fast_accum in forward and not in backward.
-    We do this by populating the fields of the backproping grad. Same applies for fp8_output.
-    For both use_fast_accum and fp8_output, if either config is False, the merged config will be False.
-    """
-    assert (
-        a_mm_config.emulate == b_mm_config.emulate
-    ), "Both mm_configs must have the same emulate value, but got {} and {}".format(
-        a_mm_config.emulate, b_mm_config.emulate
-    )
-    return ScaledMMConfig(
-        emulate=a_mm_config.emulate,
-        use_fast_accum=a_mm_config.use_fast_accum and b_mm_config.use_fast_accum,
-        fp8_output=a_mm_config.fp8_output and b_mm_config.fp8_output,
-        pad_inner_dim=a_mm_config.pad_inner_dim and b_mm_config.pad_inner_dim,
-    )
 
+# Given a Float8Tensor, the enum below describes the expected role of this
+# tensor in the three gemms present in the fw + bw pass of a Linear layer.
+# This is used to choose the right config for a float8 gemm when the
+# gemm is performed.
+class GemmInputRole(enum.Enum):
+    X = "x"
+    W = "w"
+    DL_DY = "dL_dY"
+
+
+# choose which scaled_mm_config to use based on gemm inputs
+def choose_scaled_mm_config(
+    a_role: GemmInputRole,
+    a_linear_mm_config: LinearMMConfig,
+    b_role: GemmInputRole,
+    b_linear_mm_config: LinearMMConfig,
+):
+    if a_role is GemmInputRole.X and b_role is GemmInputRole.W:
+        assert (
+            a_linear_mm_config.y == b_linear_mm_config.y
+        ), f"linear_mm_config.y mismatch: {a_linear_mm_config.y} vs {b_linear_mm_config.y}"
+        return a_linear_mm_config.y
+    elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.W:
+        assert (
+            a_linear_mm_config.dL_dX == b_linear_mm_config.dL_dX
+        ), f"linear_mm_config.dL_dX mismatch: {a_linear_mm_config.dL_dX} vs {b_linear_mm_config.dL_dX}"
+        return a_linear_mm_config.dL_dX
+    else:
+        assert (
+            a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X
+        ), f"unexpected a_role {a_role} and b_role {b_role}"
+        assert (
+            a_linear_mm_config.dL_dW == b_linear_mm_config.dL_dW
+        ), f"linear_mm_config.dL_dW mismatch: {a_linear_mm_config.dL_dW} vs {b_linear_mm_config.dL_dW}"
+        return a_linear_mm_config.dL_dW
 
 
 def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
@@ -72,7 +126,8 @@
     x: torch.Tensor,
     x_scale: torch.Tensor,
     float8_dtype: torch.dtype,
-    mm_config: Optional[ScaledMMConfig],
+    linear_mm_config: Optional[LinearMMConfig],
+    gemm_input_role: Optional[GemmInputRole],
 ) -> "Float8Tensor":
     """Convert a tensor to float8 without autograd
     This is used in multiple places in the codebase to convert a tensor to float8
@@ -90,7 +145,10 @@ def to_fp8_no_autograd(
         x: the tensor to convert
         scale: the scale to use to convert the tensor
         float8_dtype: the float8 dtype to use
-        mm_config: Defines the configuration for the scaled_mm
+        linear_mm_config: Defines the configuration for the scaled_mm for
+            the 3 fwd/bwd gemms of linear
+        gemm_input_role: Defines the role of this tensor (x, w or dL_dY) in
+            the 3 fwd/bwd gemms of linear
     """
     x_scaled = x * x_scale
     bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
@@ -104,7 +162,11 @@ def to_fp8_no_autograd(
         local_bits = bits_fp8.to_local()
         local_scale = x_scale.to_local()
         inner_float8_tensor = Float8Tensor(
-            local_bits, local_scale, x.dtype, mm_config=mm_config
+            local_bits,
+            local_scale,
+            x.dtype,
+            mm_config=linear_mm_config,
+            gemm_input_role=gemm_input_role,
         )
         return DTensor.from_local(
             inner_float8_tensor,
@@ -115,7 +177,13 @@ def to_fp8_no_autograd(
             stride=bits_fp8.stride(),
         )
 
-    return Float8Tensor(bits_fp8, x_scale, x.dtype, mm_config=mm_config)
+    return Float8Tensor(
+        bits_fp8,
+        x_scale,
+        x.dtype,
+        mm_config=linear_mm_config,
+        gemm_input_role=gemm_input_role,
+    )
 
 
 @torch._dynamo.allow_in_graph
@@ -133,7 +201,8 @@ def forward(
         scale: torch.Tensor,
         float8_dtype=e4m3_dtype,
         amax_buffer: Optional[torch.Tensor] = None,
-        mm_config: Optional[ScaledMMConfig] = None,
+        linear_mm_config: Optional[LinearMMConfig] = None,
+        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
     ):
         """Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer.
         Args
@@ -146,11 +215,17 @@ def forward(
         if amax_buffer is not None:
             amax_buffer.fill_(tensor_to_amax(tensor))
 
-        return to_fp8_no_autograd(tensor, scale, float8_dtype, mm_config=mm_config)
+        return to_fp8_no_autograd(
+            tensor,
+            scale,
+            float8_dtype,
+            linear_mm_config=linear_mm_config,
+            gemm_input_role=gemm_input_role,
+        )
 
     @staticmethod
     def backward(ctx, g):
-        return g, None, None, None, None
+        return g, None, None, None, None, None
 
 
 @torch._dynamo.allow_in_graph
@@ -194,7 +269,9 @@ class Float8Tensor(torch.Tensor):
     _data: torch.Tensor
     _scale: torch.Tensor
     _orig_dtype: torch.dtype
-    _mm_config: ScaledMMConfig
+    # TODO(before land): change this to _linear_mm_config, wanted to do that after
+    # initial review
+    _mm_config: LinearMMConfig
     __slots__ = ["_data", "_scale", "_orig_dtype", "_mm_config"]
 
     def __new__(
@@ -202,7 +279,8 @@ def __new__(
         data: torch.Tensor,
         scale: torch.Tensor,
         orig_dtype: torch.dtype,
-        mm_config: Optional[ScaledMMConfig],
+        mm_config: Optional[LinearMMConfig],
+        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
     ):
         assert (
             scale.numel() == 1
@@ -223,12 +301,13 @@ def __new__(
         self._data = data
         self._scale = scale
         self._orig_dtype = orig_dtype
-        self._mm_config = mm_config if mm_config is not None else ScaledMMConfig()
+        self._mm_config = mm_config if mm_config is not None else LinearMMConfig()
+        self._gemm_input_role = gemm_input_role
 
         return self
 
     def __repr__(self):
-        return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, mm_config={self._mm_config}\nas_orig_prec={self.to_original_precision()}"
+        return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, mm_config={self._mm_config}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}"
 
     def __tensor_flatten__(self):
         ctx = {
@@ -257,7 +336,8 @@ def to_float8(
         scale: torch.Tensor,
         float8_dtype: torch.dtype,
         amax_buffer: Optional[torch.Tensor] = None,
-        mm_config: Optional[ScaledMMConfig] = None,
+        linear_mm_config: Optional[LinearMMConfig] = None,
+        gemm_input_role: Optional[GemmInputRole] = GemmInputRole.X,
     ):
         """Converts a higher precision tensor to float8 in a differentiable way.
 
@@ -272,7 +352,12 @@ def to_float8(
             Float8Tensor: a float8 tensor
         """
         return ToFloat8ConstrFunc.apply(
-            tensor, scale, float8_dtype, amax_buffer, mm_config
+            tensor,
+            scale,
+            float8_dtype,
+            amax_buffer,
+            linear_mm_config,
+            gemm_input_role,
         )
 
     @classmethod
diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
index 81d53b5..444ea27 100644
--- a/float8_experimental/fsdp_utils.py
+++ b/float8_experimental/fsdp_utils.py
@@ -12,11 +12,7 @@
 import torch.utils._pytree as pytree
 
 from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic
-from float8_experimental.float8_tensor import (
-    Float8Tensor,
-    merge_mm_configs,
-    ScaledMMConfig,
-)
+from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
 from float8_experimental.float8_utils import e4m3_dtype, EPS
 from torch._prims_common import suggest_memory_format
 
@@ -129,7 +125,7 @@ def unwrap(t):
             if mm_config is None:
                 mm_config = t._mm_config
             else:
-                mm_config = merge_mm_configs(mm_config, t._mm_config)
+                assert t._mm_config == mm_config
             return t._tensor
 
         args, kwargs = pytree.tree_map_only(
@@ -257,7 +253,7 @@ def unwrap(t):
             if mm_config is None:
                 mm_config = t._mm_config
             else:
-                mm_config = merge_mm_configs(mm_config, t._mm_config)
+                assert t._mm_config == mm_config
             nonlocal amax_buffer
             if amax_buffer is None:
                 amax_buffer = t._amax_buffer
diff --git a/float8_experimental/inference.py b/float8_experimental/inference.py
index 1c931ee..e329671 100644
--- a/float8_experimental/inference.py
+++ b/float8_experimental/inference.py
@@ -20,6 +20,7 @@
 
 from float8_experimental.float8_tensor import (
     Float8Tensor,
+    GemmInputRole,
     ScaledMMConfig,
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
@@ -126,6 +127,7 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
             scale,
             dtype,
             self.forward_config,
+            gemm_input_role=GemmInputRole.W,
         )
         self.weight = nn.Parameter(quantized_weight)
         self.weight.requires_grad = False
diff --git a/test/test_base.py b/test/test_base.py
index 2c7c3f4..8e1d1cc 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -26,7 +26,8 @@
 from float8_experimental.float8_python_api import addmm_float8_unwrapped
 from float8_experimental.float8_tensor import (
     Float8Tensor,
-    merge_mm_configs,
+    GemmInputRole,
+    LinearMMConfig,
     ScaledMMConfig,
 )
 from float8_experimental.float8_utils import (
@@ -438,38 +439,36 @@ def test_different_configs_error(self):
         x_fp32 = torch.randn(16, 16, device="cuda")
         x_scale = torch.tensor(1.0, device="cuda")
         fp8_dtype = e4m3_dtype
-        a = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype)
+        linear_config_a = LinearMMConfig(
+            ScaledMMConfig(False, True, False, False),
+            ScaledMMConfig(False, False, False, False),
+            ScaledMMConfig(False, False, False, False),
+        )
+        linear_config_b = LinearMMConfig(
+            ScaledMMConfig(True, True, False, False),
+            ScaledMMConfig(True, False, False, False),
+            ScaledMMConfig(True, False, False, False),
+        )
+        a = Float8Tensor.to_float8(
+            x_fp32,
+            x_scale,
+            fp8_dtype,
+            linear_mm_config=linear_config_a,
+            gemm_input_role=GemmInputRole.X,
+        )
         b = Float8Tensor.to_float8(
-            x_fp32, x_scale, fp8_dtype, mm_config=ScaledMMConfig(True)
+            x_fp32,
+            x_scale,
+            fp8_dtype,
+            linear_mm_config=linear_config_b,
+            gemm_input_role=GemmInputRole.W,
         )
         with pytest.raises(
             AssertionError,
-            match="Both mm_configs must have the same emulate value, but got False and True",
+            match="linear_mm_config.y mismatch",
         ):
             a @ b
 
-    def test_merge_configs(self):
-        a = ScaledMMConfig(False, True, True)
-        b = ScaledMMConfig(True, False, False)
-        with pytest.raises(
-            AssertionError,
-            match="Both mm_configs must have the same emulate value, but got False and True",
-        ):
-            merge_mm_configs(a, b)
-        a = ScaledMMConfig(False, True, True)
-        b = ScaledMMConfig(False, False, False)
-        c = merge_mm_configs(a, b)
-        assert c.emulate is False
-        assert c.use_fast_accum is False
-        assert c.fp8_output is False
-
-        a = ScaledMMConfig(False, True, False)
-        b = ScaledMMConfig(False, True, False)
-        c = merge_mm_configs(a, b)
-        assert c.emulate is False
-        assert c.use_fast_accum is True
-        assert c.fp8_output is False
-
     @unittest.skipIf(
         not is_H100,
         "CUDA not available",
@@ -489,8 +488,12 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()
 
-        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
-        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)
+        a_fp8 = Float8Tensor.to_float8(
+            a, a_scale, input_dtype, gemm_input_role=GemmInputRole.X
+        )
+        b_fp8 = Float8Tensor.to_float8(
+            b, b_scale, input_dtype, gemm_input_role=GemmInputRole.W
+        )
 
         with pytest.raises(
             RuntimeError,
@@ -500,19 +503,47 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
         ):
             a_fp8 @ b_fp8
 
-        pad_config = ScaledMMConfig(False, use_fast_accum, False, True)
+        scaled_mm_config = ScaledMMConfig(False, use_fast_accum, False, True)
+        pad_config = LinearMMConfig(
+            scaled_mm_config, scaled_mm_config, scaled_mm_config
+        )
 
-        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype, mm_config=pad_config)
-        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype, mm_config=pad_config)
+        a_fp8 = Float8Tensor.to_float8(
+            a,
+            a_scale,
+            input_dtype,
+            linear_mm_config=pad_config,
+            gemm_input_role=GemmInputRole.X,
+        )
+        b_fp8 = Float8Tensor.to_float8(
+            b,
+            b_scale,
+            input_dtype,
+            linear_mm_config=pad_config,
+            gemm_input_role=GemmInputRole.W,
+        )
         out_padded = a_fp8 @ b_fp8
         out_padded.to(compare_type)
 
-        emulated_conifg = ScaledMMConfig(True, use_fast_accum, False, False)
+        emulated_scaled_mm_config = ScaledMMConfig(True, use_fast_accum, False, False)
+        emulated_config = LinearMMConfig(
+            emulated_scaled_mm_config,
+            emulated_scaled_mm_config,
+            emulated_scaled_mm_config,
+        )
         a_fp8 = Float8Tensor.to_float8(
-            a, a_scale, input_dtype, mm_config=emulated_conifg
+            a,
+            a_scale,
+            input_dtype,
+            linear_mm_config=emulated_config,
+            gemm_input_role=GemmInputRole.X,
        )
         b_fp8 = Float8Tensor.to_float8(
-            b, b_scale, input_dtype, mm_config=emulated_conifg
+            b,
+            b_scale,
+            input_dtype,
+            linear_mm_config=emulated_config,
+            gemm_input_role=GemmInputRole.W,
         )
         out_emualted = a_fp8 @ b_fp8
         out_emualted.to(compare_type)
@@ -564,8 +595,8 @@ def test_swap_root_linear(self):
             module = nn.Linear(3, 3)
             module = swap_linear_with_float8_linear(module, emulate=emulate)
             self.assertIsInstance(module, Float8Linear)
-            self.assertEqual(module.forward_config.emulate, emulate)
-            self.assertEqual(module.backward_config.emulate, emulate)
+            self.assertEqual(module.linear_mm_config.y.emulate, emulate)
+            self.assertEqual(module.linear_mm_config.y.emulate, emulate)
 
     def test_swap_root_linear_with_children_raises(self):
         for emulate in [True, False]:
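
Note for reviewers (not part of the patch): below is a minimal, standalone sketch of the per-gemm config routing that the float8_tensor.py hunks above introduce. The ScaledMMConfig, LinearMMConfig, and GemmInputRole definitions mirror the added code; choose_scaled_mm_config_sketch is a simplified stand-in for the patch's choose_scaled_mm_config (it omits the config-equality asserts and takes a single LinearMMConfig), and the demo values at the bottom are illustrative only.

import enum
from collections import namedtuple

# Mirrors the namedtuple/enum definitions added in float8_tensor.py by this patch.
ScaledMMConfig = namedtuple(
    "ScaledMMConfig",
    ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"],
    defaults=[False, False, False, False],
)
LinearMMConfig = namedtuple("LinearMMConfig", ["y", "dL_dX", "dL_dW"])


class GemmInputRole(enum.Enum):
    X = "x"
    W = "w"
    DL_DY = "dL_dY"


def choose_scaled_mm_config_sketch(a_role, b_role, linear_mm_config):
    # y     = x     @ w_t    -> use the `y` config
    # dL_dX = dL_dY @ w      -> use the `dL_dX` config
    # dL_dW = x_t   @ dL_dY  -> use the `dL_dW` config
    if a_role is GemmInputRole.X and b_role is GemmInputRole.W:
        return linear_mm_config.y
    if a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.W:
        return linear_mm_config.dL_dX
    assert a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X
    return linear_mm_config.dL_dW


# One config object per linear layer; every Float8Tensor produced for that
# layer carries it, plus a GemmInputRole describing which operand it is.
cfg = LinearMMConfig(
    y=ScaledMMConfig(use_fast_accum=True),  # forward output gemm
    dL_dX=ScaledMMConfig(),                 # grad-input gemm
    dL_dW=ScaledMMConfig(),                 # grad-weight gemm
)
print(choose_scaled_mm_config_sketch(GemmInputRole.X, GemmInputRole.W, cfg))      # -> cfg.y
print(choose_scaled_mm_config_sketch(GemmInputRole.DL_DY, GemmInputRole.W, cfg))  # -> cfg.dL_dX
print(choose_scaled_mm_config_sketch(GemmInputRole.DL_DY, GemmInputRole.X, cfg))  # -> cfg.dL_dW

In the patch itself, both gemm operands carry the full LinearMMConfig of their layer, and choose_scaled_mm_config asserts that the two operands' configs agree before returning the per-gemm ScaledMMConfig.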