This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

fix numerics integration test and test delayed vs dynamic (#291)
Summary:
Pull Request resolved: #291

1. The SAM test wasn't easy to use because it had real weights and hence
   required real data for useful testing, which is not convenient in an
   integration test. Switched to a LLaMa FFN with random weights, and
   tightened all the thresholds to actually check that numerics are close.
2. Extended the numerics test to check all combinations of delayed vs.
   dynamic scaling.
3. To enable (2), extended the module swap utility to configure delayed vs.
   dynamic scaling at the model level, for now without an option to
   customize further (a usage sketch follows the changed-file summary below).

Reviewed By: drisspg

Differential Revision: D59305796

fbshipit-source-id: 4b1cd097ff82ce81a774cab535b0c890d47a2ae8
vkuzo authored and facebook-github-bot committed Jul 3, 2024
1 parent 3cb42e1 commit 1e71def
Showing 5 changed files with 224 additions and 81 deletions.
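
For orientation before the per-file diffs: with the extended utility, delayed vs. dynamic scaling is chosen per tensor (x, w, dL_dY) at swap time, for the whole model. A minimal usage sketch following the new signature added in float8_linear_utils.py below (the toy model here is illustrative, not from this PR):

import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# any model with nn.Linear children; swapped in place at the model level
model = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))
swap_linear_with_float8_linear(
    model,
    Float8Linear,
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DELAYED,
    scaling_type_dL_dY=TensorScalingType.DELAYED,
)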
33 changes: 32 additions & 1 deletion float8_experimental/float8_linear_utils.py
@@ -191,10 +191,41 @@ def swap_linear_with_float8_linear(
     skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    scaling_type_x: TensorScalingType = TensorScalingType.DELAYED,
+    scaling_type_w: TensorScalingType = TensorScalingType.DELAYED,
+    scaling_type_dL_dY: TensorScalingType = TensorScalingType.DELAYED,
 ) -> Optional[nn.Module]:
+    """
+    Swaps `torch.nn.Linear` in `module` with `Float8Linear` or `Float8DynamicLinear`.
+
+    Args:
+        module: Module to modify.
+        module_cls: `Float8Linear` or `Float8DynamicLinear`.
+        from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
+        skip_fqn_list: If specified, a list of module FQNs to skip.
+        emulate: If True, emulation is used instead of hardware accelerated gemm
+        linear_layer_filter: If specified, only the linear layers
+            that pass the filter function will be swapped.
+        scaling_type_x (TensorScalingType): scaling type for `x`
+        scaling_type_w (TensorScalingType): scaling type for `w`
+        scaling_type_dL_dY (TensorScalingType): scaling type for `dL_dY`
+
+    Returns:
+        nn.Module: The modified module with swapped linear layers.
+    """
+    if module_cls is Float8DynamicLinear:
+        from_float = lambda m: module_cls.from_float(m, emulate=emulate)
+    else:
+        from_float = lambda m: module_cls.from_float(
+            m,
+            emulate=emulate,
+            scaling_type_x=scaling_type_x,
+            scaling_type_w=scaling_type_w,
+            scaling_type_dL_dY=scaling_type_dL_dY,
+        )
     return swap_linear_layers(
         module,
-        lambda m: module_cls.from_float(m, emulate=emulate),
+        from_float,
         skip_fqn_list=skip_fqn_list,
         linear_layer_filter=linear_layer_filter,
     )
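For readers new to the distinction this PR exercises: dynamic scaling derives a tensor's float8 scale from that tensor's current absolute maximum on every call, while delayed scaling reuses an amax history recorded in earlier iterations, which is why (as the new test notes) the first delayed-scaling iteration behaves like dynamic scaling. A conceptual sketch under those assumptions, not the library's internals:

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # largest representable e4m3 value

def dynamic_scale(x: torch.Tensor) -> torch.Tensor:
    # dynamic: recompute the scale from the current tensor every time
    return E4M3_MAX / x.abs().max().clamp(min=1e-12)

def delayed_scale(amax_history: torch.Tensor) -> torch.Tensor:
    # delayed: derive the scale from amax values recorded in prior iterations
    # (here, their max); an unpopulated history degenerates to dynamic-like
    # behavior on the first iteration
    return E4M3_MAX / amax_history.max().clamp(min=1e-12)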
1 change: 0 additions & 1 deletion pyproject.toml
@@ -19,7 +19,6 @@ dependencies = [
 
 [project.optional-dependencies]
 test = [
-    "transformers==4.38.2",
     "pandas >= 2.0",
     "tqdm==4.66.2",
     "fire==0.5.0",
2 changes: 1 addition & 1 deletion test/test_everything.sh
@@ -5,9 +5,9 @@ set -e
 IS_ROCM=$(rocm-smi --version || true)
 
 pytest test/test_base.py
-pytest test/test_sam.py
 pytest test/test_compile.py
 pytest test/test_inference_flows.py
+pytest test/test_numerics_integration.py
 
 # These tests do not work on ROCm yet
 if [ -z "$IS_ROCM" ]
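The new numerics integration test takes the SAM test's slot in the suite; it can also be run on its own with the same pytest test/test_numerics_integration.py invocation used above.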
191 changes: 191 additions & 0 deletions test/test_numerics_integration.py
@@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

# Tests LLaMa FeedForward numerics with float8

import copy
from typing import Optional

import pytest

import torch
import torch.nn as nn
import torch.nn.functional as F
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
linear_requires_sync,
LinearType,
swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_utils import compute_error, IS_ROCM

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


torch.manual_seed(0)


# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
class FeedForward(nn.Module):
    """
    FeedForward module

    Args:
        dim (int): Input dimension.
        hidden_dim (int): Hidden dimension of the feedforward layer.
        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
        ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.

    Attributes:
        w1 (Linear): Linear transformation for the first layer.
        w2 (Linear): Linear transformation for the second layer.
        w3 (Linear): Linear transformation for the third layer.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
        for linear in (self.w2, self.w3):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
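
# (Illustration only, not part of the committed file: with the shapes the test
# below passes in, dim=4096, hidden_dim=16384, multiple_of=1024,
# ffn_dim_multiplier=1.3, the hidden_dim arithmetic above evaluates to
# int(2 * 16384 / 3) = 10922, then int(1.3 * 10922) = 14198, then rounding up
# to a multiple of 1024 gives 14336, so each swapped linear is 4096 x 14336
# or 14336 x 4096.)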


class TestFloat8NumericsIntegrationTest:
    @pytest.mark.parametrize(
        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
    )
    @pytest.mark.parametrize(
        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
    )
    @pytest.mark.parametrize(
        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
    )
    @pytest.mark.parametrize("linear_cls", [Float8Linear, Float8DynamicLinear])
    @pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
    @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
    def test_encoder_fw_bw(
        self,
        linear_cls,
        scaling_type_x: TensorScalingType,
        scaling_type_w: TensorScalingType,
        scaling_type_dL_dY: TensorScalingType,
    ):
        linear_type = (
            LinearType.DELAYED if linear_cls == Float8Linear else LinearType.DYNAMIC
        )
        if linear_type is LinearType.DYNAMIC:
            # Only test one combination of scaling types, as they are a no-op
            # for Float8DynamicLinear. It would be cleaner to split into two
            # tests, but IMO not worth it since Float8DynamicLinear will be
            # deleted soon
            is_all_dynamic = (
                scaling_type_x is TensorScalingType.DYNAMIC
                and scaling_type_w is TensorScalingType.DYNAMIC
                and scaling_type_dL_dY is TensorScalingType.DYNAMIC
            )
            if not is_all_dynamic:
                pytest.skip()

        # TODO(later): maybe add float16 back if it becomes important
        data_dtype = torch.bfloat16

        # LLaMa 3 70B shapes
        model_ref = (
            FeedForward(
                dim=4096,
                hidden_dim=16384,
                multiple_of=1024,
                ffn_dim_multiplier=1.3,
            )
            .cuda()
            .to(data_dtype)
        )

        # for now just test the encoder to simplify things
        model_fp8 = copy.deepcopy(model_ref)
        swap_linear_with_float8_linear(
            model_fp8,
            linear_cls,
            emulate=False,
            scaling_type_x=scaling_type_x,
            scaling_type_w=scaling_type_w,
            scaling_type_dL_dY=scaling_type_dL_dY,
        )

        lr = 0.01
        optim_ref = torch.optim.SGD(model_ref.parameters(), lr=lr)
        optim_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr)

        # Note: you need two different inputs to properly test numerics
        # of delayed scaling, because the first time around the initialization
        # logic of delayed scaling behaves as dynamic scaling
        # TODO(future): also make unit tests do this properly
        shape = (1, 8192, 4096)
        data1 = torch.randn(*shape, device="cuda", dtype=data_dtype)
        data2 = torch.randn(*shape, device="cuda", dtype=data_dtype)

        model_ref(data1).sum().backward()
        # zero out grads without stepping, since we just want to compare grads
        # of the second datum
        optim_ref.zero_grad()
        model_ref_out = model_ref(data2)
        model_ref_out.sum().backward()

        if linear_requires_sync(
            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
        ):
            sync_float8_amax_and_scale_history(model_fp8)
        model_fp8(data1).sum().backward()
        # zero out grads without stepping, since we just want to compare grads
        # of the second datum
        optim_fp8.zero_grad()
        if linear_requires_sync(
            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
        ):
            sync_float8_amax_and_scale_history(model_fp8)
        model_fp8_out = model_fp8(data2)
        model_fp8_out.sum().backward()

        out_sqnr = compute_error(model_ref_out, model_fp8_out)
        assert out_sqnr > 20.0

        ref_name_to_grad = {
            name: param.grad for name, param in model_ref.named_parameters()
        }

        grad_sqnr_threshold = 20.0

        for name, param in model_fp8.named_parameters():
            ref_grad = ref_name_to_grad[name]
            cur_grad = param.grad
            sqnr = compute_error(ref_grad, cur_grad)
            assert sqnr > grad_sqnr_threshold


if __name__ == "__main__":
    pytest.main([__file__])
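
Two notes on what the new test checks. The parametrization grid is 2 linear classes x 2^3 scaling-type combinations = 16 cases, of which the Float8DynamicLinear cases collapse to the single all-dynamic combination, so 9 cases actually exercise the numerics. The 20 dB floor on the output and on every parameter gradient is an SQNR threshold; assuming compute_error follows the usual 20*log10 norm-ratio definition, a minimal sketch of what that measures (illustrative, not the library's implementation):

import torch

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in dB; above 20 dB the error norm is
    # at least roughly 10x smaller than the reference norm
    return 20 * torch.log10(torch.norm(ref) / torch.norm(ref - approx))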
78 changes: 0 additions & 78 deletions test/test_sam.py

This file was deleted.

