This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

make all 3 gemms in Float8Linear support configurability, not user facing #315

Closed (wants to merge 5 commits)
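For context, the "3 gemms" in the title are the three matmuls a linear layer performs across its forward and backward passes, each of which this PR wires up for separate configuration. A minimal sketch (illustrative only, not code from this PR; shapes are arbitrary):

```python
import torch

x = torch.randn(16, 32, requires_grad=True)  # activation
w = torch.randn(64, 32, requires_grad=True)  # weight

y = x @ w.t()                # gemm 1: forward output
dL_dY = torch.randn(16, 64)  # incoming gradient from the next layer
dL_dX = dL_dY @ w            # gemm 2: gradient w.r.t. the activation
dL_dW = dL_dY.t() @ x        # gemm 3: gradient w.r.t. the weight
```

The diffs below thread a single LinearMMConfig (a bundle of three ScaledMMConfig entries) plus a GemmInputRole tag for each cast tensor through the casting helpers, replacing the previous forward_config/backward_config pair; per the title, the new knobs are not yet exposed to users.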
float8_experimental/__init__.py (9 changes: 7 additions & 2 deletions)
@@ -5,11 +5,16 @@
# LICENSE file in the root directory of this source tree.
# Lets define a few top level things here
from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    GemmInputRole,
+    LinearMMConfig,
+    ScaledMMConfig,
+)

# Needed to load Float8Tensor with weights_only = True
from torch.serialization import add_safe_globals

-add_safe_globals([Float8Tensor, ScaledMMConfig])
+add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig])

__all__ = ["Float8Tensor", "Float8Linear"]
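A brief usage sketch of what the registration above enables (not part of the diff; the checkpoint path is hypothetical): with these classes on the safe-globals allowlist, a checkpoint containing Float8Tensor state can be deserialized under torch.load's restricted weights_only mode.

```python
import torch

import float8_experimental  # importing the package runs the add_safe_globals call above

# Hypothetical checkpoint that contains Float8Tensor / LinearMMConfig objects.
state_dict = torch.load("float8_checkpoint.pt", weights_only=True)
```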
float8_experimental/float8_dynamic_utils.py (30 changes: 22 additions & 8 deletions)
@@ -8,7 +8,8 @@

from float8_experimental.float8_tensor import (
Float8Tensor,
-    ScaledMMConfig,
+    GemmInputRole,
+    LinearMMConfig,
tensor_already_casted_to_fp8,
to_fp8_no_autograd,
)
@@ -26,9 +27,9 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
def forward(
ctx,
tensor,
-        mm_config: ScaledMMConfig,
+        linear_mm_config: LinearMMConfig,
):
-        ctx.mm_config = mm_config
+        ctx.linear_mm_config = linear_mm_config
return tensor

@staticmethod
@@ -37,21 +38,34 @@ def backward(ctx, gradY):
return gradY, None
gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
fp8_tensor = to_fp8_no_autograd(
-            gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
+            gradY,
+            gradY_scale,
+            e5m2_dtype,
+            linear_mm_config=ctx.linear_mm_config,
+            gemm_input_role=GemmInputRole.DL_DY,
)
return fp8_tensor, None


def cast_to_float8_e4m3_dynamic(
-    inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
+    inpt_tensor: torch.Tensor,
+    linear_mm_config: LinearMMConfig,
+    reduce_amax: bool = False,
+    gemm_input_role: GemmInputRole = GemmInputRole.X,
[Review thread on the new gemm_input_role parameter]
Contributor: maybe no default value for gemm role?
Author: we can clean up in a separate PR, there is extra complexity because we'd need to change the argument order
(An illustrative sketch of that argument-order constraint follows this file's diff.)
) -> Float8Tensor:
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
+    return Float8Tensor.to_float8(
+        inpt_tensor,
+        scale,
+        e4m3_dtype,
+        linear_mm_config=linear_mm_config,
+        gemm_input_role=gemm_input_role,
+    )


def cast_to_float8_e5m2_dynamic_bw(
-    gradY: torch.Tensor, mm_config: ScaledMMConfig
+    gradY: torch.Tensor, linear_mm_config: LinearMMConfig
) -> torch.Tensor:
-    return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
+    return NoopFwToFloat8E5M2Bw.apply(gradY, linear_mm_config)
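To make the argument-order point from the review thread above concrete, here is an illustrative sketch (not proposed code; the _sketch suffix is hypothetical). A positional parameter without a default cannot follow one with a default, so dropping the GemmInputRole.X default would mean either moving gemm_input_role ahead of reduce_amax, which touches existing call sites, or making it keyword-only:

```python
# Invalid: a required positional parameter cannot follow a defaulted one.
#
#   def cast_to_float8_e4m3_dynamic(
#       inpt_tensor, linear_mm_config, reduce_amax=False, gemm_input_role
#   ): ...
#   SyntaxError: non-default argument follows default argument

# One possible shape of the cleanup: make the role keyword-only and required.
def cast_to_float8_e4m3_dynamic_sketch(
    inpt_tensor,
    linear_mm_config,
    *,
    gemm_input_role,  # required, no default
    reduce_amax: bool = False,
):
    ...
```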
float8_experimental/float8_linear.py (50 changes: 33 additions & 17 deletions)
@@ -23,6 +23,8 @@

from float8_experimental.float8_tensor import (
Float8Tensor,
+    GemmInputRole,
+    LinearMMConfig,
ScaledMMConfig,
to_fp8_no_autograd,
)
Expand Down Expand Up @@ -85,12 +87,12 @@ def forward(
fp8_scale_dL_dY,
scale_fn_name,
is_amax_initialized,
-        mm_config: ScaledMMConfig,
+        linear_mm_config: LinearMMConfig,
):
ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
ctx.scale_fn_name = scale_fn_name
ctx.is_amax_initialized = is_amax_initialized
-        ctx.mm_config = mm_config
+        ctx.linear_mm_config = linear_mm_config
return tensor

@staticmethod
@@ -113,7 +115,11 @@ def backward(ctx, go):
fp8_amax_dL_dY.fill_(tensor_to_amax(go))

res = to_fp8_no_autograd(
-            go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
+            go,
+            fp8_scale_dL_dY,
+            e5m2_dtype,
+            linear_mm_config=ctx.linear_mm_config,
+            gemm_input_role=GemmInputRole.DL_DY,
)
empty_grads = None, None, None, None, None, None
return res, *empty_grads
@@ -192,12 +198,18 @@ def __init__(self, *args, **kwargs):

self.create_buffers()

-        # Defines the behavior of the matmul in the forward and backward pass
-        self.forward_config = ScaledMMConfig(
-            emulate, True if not emulate else False, False, config.pad_inner_dim
-        )
-        self.backward_config = ScaledMMConfig(
-            emulate, False, False, config.pad_inner_dim
+        # TODO(future): user level configuration of gemms
+        self.linear_mm_config = LinearMMConfig(
[Review thread on the default LinearMMConfig construction]
Contributor: [I think this might be another stylistic thing, so no need to change] I think I would actually make this a func and then document it thoroughly. It's not very clear from reading this what everything does, so I would explain in that func the exact recipe that we choose by default.
Author: this isn't user facing, so we can clean up at any time
(An illustrative helper along those lines is sketched at the end of this page.)
+            # x
+            ScaledMMConfig(
+                emulate, True if not emulate else False, False, config.pad_inner_dim
+            ),
+            # w
+            ScaledMMConfig(
+                emulate, True if not emulate else False, False, config.pad_inner_dim
+            ),
+            # dL_dY
+            ScaledMMConfig(emulate, False, False, config.pad_inner_dim),
)

# Note: is_amax_initialized is not a buffer to avoid data dependent
@@ -308,11 +320,12 @@ def cast_x_to_float8(
self.fp8_scale_x,
e4m3_dtype,
self.fp8_amax_x,
-                self.forward_config,
+                linear_mm_config=self.linear_mm_config,
+                gemm_input_role=GemmInputRole.X,
)
else:
assert self.scaling_type_x is TensorScalingType.DYNAMIC
-            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
+            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config)
return x_fp8

def cast_w_to_float8(
@@ -339,14 +352,17 @@ def cast_w_to_float8(
self.fp8_scale_w,
e4m3_dtype,
self.fp8_amax_w,
-                self.forward_config,
+                linear_mm_config=self.linear_mm_config,
+                gemm_input_role=GemmInputRole.W,
)
else:
assert self.scaling_type_w is TensorScalingType.DYNAMIC
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
-                w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
+                w_fp8 = cast_to_float8_e4m3_dynamic(
+                    self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
+                )
return w_fp8

def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
@@ -359,11 +375,11 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
self.fp8_scale_dL_dY,
scale_fn_name,
self.is_amax_initialized,
-                self.backward_config,
+                self.linear_mm_config,
)
else:
assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
+            y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config)
return y

def float8_pre_forward(self, x):
@@ -457,7 +473,7 @@ def from_float(
new_mod.weight = torch.nn.Parameter(
WeightWithDynamicFloat8CastTensor(
new_mod.weight,
-                    new_mod.forward_config,
+                    new_mod.linear_mm_config,
)
)
else:
@@ -468,7 +484,7 @@
new_mod.fp8_amax_w,
new_mod.fp8_amax_history_w,
new_mod.fp8_scale_w,
-                    new_mod.forward_config,
+                    new_mod.linear_mm_config,
new_mod.is_amax_initialized,
)
)
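Finally, an illustrative take on the reviewer's suggestion from the __init__ hunk above: pull the default recipe into a small, documented helper. This is a sketch only, not code from the PR; the helper name is hypothetical and the comment about the second ScaledMMConfig field (fast accumulation) is inferred from the call sites in the diff.

```python
from float8_experimental.float8_tensor import LinearMMConfig, ScaledMMConfig


def _default_linear_mm_config(emulate: bool, pad_inner_dim: bool) -> LinearMMConfig:
    """Build the default gemm recipe used by Float8Linear (illustrative sketch).

    Mirrors the construction in Float8Linear.__init__ above:
      * the x and w entries enable fast accumulation whenever real float8 gemms
        are used (`not emulate` is equivalent to the diff's
        `True if not emulate else False`),
      * the dL_dY entry keeps fast accumulation off,
      * all three honor the global pad_inner_dim setting.
    """
    forward_cfg = ScaledMMConfig(emulate, not emulate, False, pad_inner_dim)
    grad_cfg = ScaledMMConfig(emulate, False, False, pad_inner_dim)
    return LinearMMConfig(forward_cfg, forward_cfg, grad_cfg)  # x, w, dL_dY, as above
```

Whether this lands as a helper or stays inline is left open in the thread, since the config is not user facing.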