This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Update on "[wip] make all 3 gemms in Float8Linear configurable"
Summary:

not ready for review yet

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
vkuzo committed Jul 19, 2024
1 parent bfe17f9 commit 57a4b32
Showing 2 changed files with 17 additions and 16 deletions.
float8_experimental/float8_tensor.py (21 changes: 12 additions & 9 deletions)
@@ -56,7 +56,8 @@
     defaults=[False, False, False, False],
 )
 
-# The object below exists for convenience, to allow Float8Tensor to use
+# The object below is not user facing and exists for convenience,
+# to allow Float8Tensor to use
 # the right config based on which gemm from `y`, `dL_dX`, `dL_dW` is
 # being called.
 LinearMMConfig = namedtuple(
@@ -70,11 +71,14 @@
 )
 
 
-# Given a Float8Tensor, the enum below describes the expected role of this
-# tensor in the three gemms present in the fw + bw pass of a Linear layer.
-# This is used to choose the right config for a float8 gemm when the
-# gemm is performed.
 class GemmInputRole(enum.Enum):
+    """
+    Given a Float8Tensor, the enum below describes the expected role of this
+    tensor in the three gemms present in the fw + bw pass of a Linear layer.
+    This is used to choose the right config for a float8 gemm when the
+    gemm is performed.
+    """
+
     X = "x"
     W = "w"
     DL_DY = "dL_dY"
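
For reference, the three gemms the new docstring refers to can be written out in plain torch. This is a minimal illustration for an nn.Linear-style weight of shape [out_features, in_features], not code from this commit:

```python
import torch

x = torch.randn(16, 32)      # activations, role X
w = torch.randn(64, 32)      # weight, role W ([out_features, in_features])
y = x @ w.t()                # forward gemm         -> config for `y`
dL_dY = torch.randn(16, 64)  # incoming gradient, role DL_DY
dL_dX = dL_dY @ w            # input-gradient gemm  -> config for `dL_dX`
dL_dW = dL_dY.t() @ x        # weight-gradient gemm -> config for `dL_dW`
```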
@@ -97,14 +101,13 @@ def choose_scaled_mm_config(
             a_linear_mm_config.dL_dX == b_linear_mm_config.dL_dX
         ), f"linear_mm_config.dL_dX mismatch: {a_linear_mm_config.dL_dX} vs {b_linear_mm_config.dL_dX}"
         return a_linear_mm_config.dL_dX
-    else:
-        assert (
-            a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X
-        ), f"unexpected a_role {a_role} and b_role {b_role}"
+    elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X:
         assert (
             a_linear_mm_config.dL_dW == b_linear_mm_config.dL_dW
         ), f"linear_mm_config.dL_dW mismatch: {a_linear_mm_config.dL_dW} vs {b_linear_mm_config.dL_dW}"
         return a_linear_mm_config.dL_dW
+    else:
+        raise AssertionError(f"unexpected a_role {a_role} and b_role {b_role}")
 
 
 def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
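
A rough, self-contained sketch of how the pieces in this file fit together. The LinearMMConfig field names (y, dL_dX, dL_dW) come from the comments and code above; the ScaledMMConfig field names below are only placeholders, since its definition sits outside the hunk:

```python
from collections import namedtuple

# Placeholder field names; only the four False defaults are visible in the diff.
ScaledMMConfig = namedtuple(
    "ScaledMMConfig",
    ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"],
    defaults=[False, False, False, False],
)

# One ScaledMMConfig per gemm, keyed by which output it produces.
LinearMMConfig = namedtuple(
    "LinearMMConfig",
    ["y", "dL_dX", "dL_dW"],
    defaults=[ScaledMMConfig(), ScaledMMConfig(), ScaledMMConfig()],
)

# e.g. enable fast accumulation only for the forward gemm.
linear_mm_config = LinearMMConfig(y=ScaledMMConfig(use_fast_accum=True))

# choose_scaled_mm_config maps the (a_role, b_role) pair of the two operands to
# one of these fields: (DL_DY, X) -> dL_dW per the hunk above; the y and dL_dX
# branches sit above the hunk, so their exact role pairs are not shown here.
print(linear_mm_config.dL_dW)
```

Per the comment at the top of the file, a Float8Tensor carries such a config so that the right per-gemm settings can be looked up when each gemm is actually called.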
float8_experimental/float8_tensor_parallel.py (12 changes: 5 additions & 7 deletions)
@@ -174,6 +174,7 @@ def __init__(

         # fp8 specific fields
         self.float8_dtype = float8_dtype
+        self.linear_mm_config = None
         self.fwd_config_submodule_fqn = fwd_config_submodule_fqn
 
         if self.float8_dtype != torch.float8_e4m3fn:
@@ -210,24 +211,21 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         from float8_experimental.float8_linear import Float8Linear
 
-        fwd_linear_config = None
         if self.fwd_config_submodule_fqn is not None:
             fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn)
             assert isinstance(fwd_linear, Float8Linear)
-            linear_mm_config = fwd_linear.linear_mm_config
+            self.linear_mm_config = fwd_linear.linear_mm_config
         else:
             # search for ScaledMM configs for all the submodules and make sure they are the same
             for mod in module.modules():
                 if isinstance(mod, Float8Linear):
-                    if fwd_linear_config is None:
-                        fwd_linear_config = mod.linear_mm_config
+                    if self.linear_mm_config is None:
+                        self.linear_mm_config = mod.linear_mm_config
                     else:
                         assert (
-                            fwd_linear_config == mod.linear_mm_config
+                            self.linear_mm_config == mod.linear_mm_config
                         ), "All the Float8Linear modules should have same linear_mm_config!"
 
-        self.linear_mm_config = fwd_linear_config
+        # TODO(this PR): something is broken here, fix it
         assert self.linear_mm_config is not None
         super()._apply(module, device_mesh)
         return module
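
The _apply change above keeps the resolved config on self rather than in a local. The same resolution logic, pulled out into a standalone helper, might look roughly like this; resolve_linear_mm_config and float8_linear_cls are made-up names for illustration:

```python
from typing import Optional

import torch.nn as nn


def resolve_linear_mm_config(
    module: nn.Module,
    float8_linear_cls: type,
    fwd_config_submodule_fqn: Optional[str] = None,
):
    """Return the linear_mm_config shared by the float8 linears under `module`."""
    if fwd_config_submodule_fqn is not None:
        # An explicit FQN pins the config to one designated forward linear.
        fwd_linear = module.get_submodule(fwd_config_submodule_fqn)
        assert isinstance(fwd_linear, float8_linear_cls)
        return fwd_linear.linear_mm_config

    # Otherwise scan every submodule and require that they all agree.
    config = None
    for mod in module.modules():
        if isinstance(mod, float8_linear_cls):
            if config is None:
                config = mod.linear_mm_config
            else:
                assert (
                    config == mod.linear_mm_config
                ), "All the Float8Linear modules should have the same linear_mm_config!"
    assert config is not None, "no float8 linear submodules found"
    return config
```

The TODO left in the hunk suggests this path is still being worked out, which matches the [wip] in the commit title.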
