This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[1/x] clean up casting functions #339

Closed
wants to merge 1 commit
benchmarks/bench_padding.py: 25 additions & 4 deletions
@@ -4,7 +4,12 @@
import fire

import torch
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_tensor import (
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
    ToFloat8ConstrFunc,
)
from float8_experimental.float8_utils import pad_tensor_for_matmul
from tabulate import tabulate
from torch._inductor.utils import do_bench_using_profiling
@@ -50,9 +55,25 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
    b_config = ScaledMMConfig(
        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
    )

    a_fp8 = Float8Tensor.to_float8(A, scale_a, fp8_dtype, mm_config=a_config)
    b_fp8 = Float8Tensor.to_float8(B, scale_b, fp8_dtype, mm_config=b_config)
    a_config = LinearMMConfig(a_config, a_config, a_config)
    b_config = LinearMMConfig(b_config, b_config, b_config)

    a_fp8 = ToFloat8ConstrFunc.apply(
        A,
        scale_a,
        fp8_dtype,
        None,  # amax_buffer
        a_config,
        GemmInputRole.INPUT,
    )
    b_fp8 = ToFloat8ConstrFunc.apply(
        B,
        scale_b,
        fp8_dtype,
        None,  # amax_buffer
        b_config,
        GemmInputRole.WEIGHT,
    )

    return a_fp8 @ b_fp8

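For context on the call pattern above, here is a minimal sketch of casting one operand with the new API. The argument order for `ToFloat8ConstrFunc.apply` and the `ScaledMMConfig` fields are taken directly from the hunk above; `tensor_to_scale` from `float8_experimental.float8_utils` and the field ordering of `LinearMMConfig` are assumptions, not part of this diff.

```python
# Hedged sketch: cast a bf16 tensor to float8 via the new constructor function.
# Assumes the float8_experimental modules shown in this PR and a CUDA device;
# tensor_to_scale is assumed to exist in float8_utils (not part of this diff).
import torch

from float8_experimental.float8_tensor import (
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
    ToFloat8ConstrFunc,
)
from float8_experimental.float8_utils import tensor_to_scale

A = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
fp8_dtype = torch.float8_e4m3fn

cfg = ScaledMMConfig(
    emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
)
# one ScaledMMConfig per gemm; the ordering here is an assumption
linear_cfg = LinearMMConfig(cfg, cfg, cfg)

scale_a = tensor_to_scale(A, fp8_dtype)
a_fp8 = ToFloat8ConstrFunc.apply(
    A,
    scale_a,
    fp8_dtype,
    None,  # amax_buffer, unused for dynamic scaling
    linear_cfg,
    GemmInputRole.INPUT,
)
print(a_fp8)
```
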
float8_experimental/float8_dynamic_utils.py: 0 additions & 71 deletions

This file was deleted.
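The dynamic-scaling helpers that used to live in this file are now imported from `float8_scaling_utils` (see the `float8_linear.py` import hunk below). A hedged sketch of updating a call site; the exported names are inferred from that import block, not from the deleted file itself:

```python
# Before this PR (module deleted), float8_linear.py imported:
# from float8_experimental.float8_dynamic_utils import (
#     cast_to_float8_e4m3_dynamic,
#     cast_to_float8_e5m2_dynamic_bw,
# )

# After this PR (names assumed from the float8_linear.py imports below; the
# e5m2 backward cast is now the NoopFwToFloat8E5M2BwDynamic autograd function):
from float8_experimental.float8_scaling_utils import (
    cast_to_float8_e4m3_dynamic,
    NoopFwToFloat8E5M2BwDynamic,
)
```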

float8_experimental/float8_linear.py: 10 additions & 102 deletions
@@ -16,61 +16,29 @@

from float8_experimental.config import Float8LinearConfig, ScalingType

from float8_experimental.float8_dynamic_utils import (
from float8_experimental.float8_scaling_utils import (
    _maybe_initialize_amaxes_scales_for_float8_cast,
    cast_to_float8_delayed,
    cast_to_float8_e4m3_dynamic,
    cast_to_float8_e5m2_dynamic_bw,
    NoopFwToFloat8E5M2BwDelayed,
    NoopFwToFloat8E5M2BwDynamic,
)

from float8_experimental.float8_tensor import (
    Float8Tensor,
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
    to_fp8_no_autograd,
)

from float8_experimental.float8_utils import (
    amax_history_to_scale,
    e4m3_dtype,
    e5m2_dtype,
    tensor_to_amax,
)
from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_amax

from float8_experimental.fsdp_utils import (
    WeightWithDelayedFloat8CastTensor,
    WeightWithDynamicFloat8CastTensor,
)


def _maybe_initialize_amaxes_scales_for_float8_cast(
    x,
    cur_amax,
    amax_history,
    scale,
    scale_fn_name,
    float8_dtype,
    is_initialized,
    reduce_amax,
):
    """
    If x is about to be cast to `float8` and the amax buffers are not initialized,
    initializes them inplace.
    """
    if is_initialized:
        return
    with torch.no_grad():
        # Note: we need to enable distributed reduction here in order
        # to match numerics between single GPU and multi GPU code for
        # activations and gradients
        new_amax = tensor_to_amax(x, reduce_amax=reduce_amax)
        cur_amax.fill_(new_amax)
        amax_history[0] = new_amax
        new_scale = amax_history_to_scale(
            amax_history, float8_dtype, x.dtype, scale_fn_name
        )
        scale.copy_(new_scale)


# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
@torch._dynamo.allow_in_graph
class manual_float8_matmul(torch.autograd.Function):
@@ -127,66 +95,6 @@ def backward(ctx, grad_output_fp8):
        return grad_input, grad_weight.t()


@torch._dynamo.allow_in_graph
class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
    """
    Forward: no-op
    Backward: convert to float8_e5m2, initialize if needed
    """

    @staticmethod
    def forward(
        ctx,
        tensor,
        fp8_amax_grad_output,
        fp8_amax_history_grad_output,
        fp8_scale_grad_output,
        scale_fn_name,
        is_amax_initialized,
        linear_mm_config: LinearMMConfig,
    ):
        ctx.save_for_backward(
            fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output
        )
        ctx.scale_fn_name = scale_fn_name
        ctx.is_amax_initialized = is_amax_initialized
        ctx.linear_mm_config = linear_mm_config
        return tensor

    @staticmethod
    def backward(ctx, go):
        (
            fp8_amax_grad_output,
            fp8_amax_history_grad_output,
            fp8_scale_grad_output,
        ) = ctx.saved_tensors
        scale_fn_name = ctx.scale_fn_name
        is_amax_initialized = ctx.is_amax_initialized

        _maybe_initialize_amaxes_scales_for_float8_cast(
            go,
            fp8_amax_grad_output,
            fp8_amax_history_grad_output,
            fp8_scale_grad_output,
            scale_fn_name,
            e5m2_dtype,
            is_amax_initialized,
            reduce_amax=True,
        )

        fp8_amax_grad_output.fill_(tensor_to_amax(go))

        res = to_fp8_no_autograd(
            go,
            fp8_scale_grad_output,
            e5m2_dtype,
            linear_mm_config=ctx.linear_mm_config,
            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
        )
        empty_grads = None, None, None, None, None, None
        return res, *empty_grads


class Float8Linear(torch.nn.Linear):
"""
Note: this is **not** a public API and is only intended to be used
@@ -352,7 +260,7 @@ def cast_input_to_float8(
                is_amax_initialized,
                reduce_amax=True,
            )
            input_fp8 = Float8Tensor.to_float8(
            input_fp8 = cast_to_float8_delayed(
                input,
                self.fp8_scale_input,
                e4m3_dtype,
@@ -384,7 +292,7 @@ def cast_weight_to_float8(
                    reduce_amax=False,
                )

                weight_fp8 = Float8Tensor.to_float8(
                weight_fp8 = cast_to_float8_delayed(
                    weight,
                    self.fp8_scale_weight,
                    e4m3_dtype,
@@ -407,7 +315,7 @@ def cast_weight_to_float8(
    def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
        if self.scaling_type_grad_output is ScalingType.DELAYED:
            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
            output = NoopFwToFloat8E5M2Bw.apply(
            output = NoopFwToFloat8E5M2BwDelayed.apply(
                output,
                self.fp8_amax_grad_output,
                self.fp8_amax_history_grad_output,
@@ -418,7 +326,7 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
            )
        else:
            assert self.scaling_type_grad_output is ScalingType.DYNAMIC
            output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config)
            output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config)
        return output

    def float8_pre_forward(self, input):
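The grad-output path above swaps `cast_to_float8_e5m2_dynamic_bw` for `NoopFwToFloat8E5M2BwDynamic.apply`. The dynamic variant itself is not shown in this diff; below is a minimal sketch of what such a no-op-forward / cast-on-backward autograd function looks like, modeled on the deleted `NoopFwToFloat8E5M2Bw` class above. The `to_fp8_no_autograd` call signature is copied from that deleted code, while `tensor_to_scale` is an assumed helper; this is not the actual `float8_scaling_utils` implementation.

```python
# Hedged sketch of a dynamically scaled e5m2 backward cast, assuming the
# pre-PR float8_tensor exports and a tensor_to_scale helper in float8_utils.
import torch

from float8_experimental.float8_tensor import (
    GemmInputRole,
    LinearMMConfig,
    to_fp8_no_autograd,
)
from float8_experimental.float8_utils import e5m2_dtype, tensor_to_scale


@torch._dynamo.allow_in_graph
class NoopFwToFloat8E5M2BwDynamicSketch(torch.autograd.Function):
    """
    Forward: no-op.
    Backward: cast the incoming gradient to float8_e5m2 using a scale
    computed on the fly from the gradient itself (dynamic scaling).
    """

    @staticmethod
    def forward(ctx, tensor, linear_mm_config: LinearMMConfig):
        ctx.linear_mm_config = linear_mm_config
        return tensor

    @staticmethod
    def backward(ctx, go):
        # compute the scale from this gradient; no amax history is involved
        go_scale = tensor_to_scale(go, e5m2_dtype)
        res = to_fp8_no_autograd(
            go,
            go_scale,
            e5m2_dtype,
            linear_mm_config=ctx.linear_mm_config,
            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
        )
        return res, None
```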