fix numerics integration test and test delayed vs dynamic (#291)
Summary: Pull Request resolved: #291

1. The SAM test wasn't easy to use because it had real weights and hence required real data for useful testing, which is not convenient for an integration test. Switched to a LLaMa FFN with random weights, and made all the thresholds tight to actually check that the numerics are close.
2. Extended the numerics test to check all combinations of delayed vs dynamic scaling.
3. To be able to do (2), extended the module swap utility to configure delayed vs dynamic scaling at the model level, for now without an option to customize further (see the sketch after the commit metadata below).

Reviewed By: drisspg

Differential Revision: D59305796

fbshipit-source-id: 4b1cd097ff82ce81a774cab535b0c890d47a2ae8
1 parent: 3cb42e1 · commit: 1e71def
Showing 5 changed files with 224 additions and 81 deletions.
@@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

# Tests LLaMa FeedForward numerics with float8

import copy
from typing import Optional

import pytest

import torch
import torch.nn as nn
import torch.nn.functional as F
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
    linear_requires_sync,
    LinearType,
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_utils import compute_error, IS_ROCM

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


torch.manual_seed(0)


# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
class FeedForward(nn.Module):
    """
    FeedForward module
    Args:
        dim (int): Input dimension.
        hidden_dim (int): Hidden dimension of the feedforward layer.
        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
        ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
    Attributes:
        w1 (Linear): Linear transformation for the first layer.
        w2 (Linear): Linear transformation for the second layer.
        w3 (Linear): Linear transformation for the third layer.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
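        # note: for the LLaMa 3 70B-style shapes used in the test below
        # (dim=4096, hidden_dim=16384, multiple_of=1024, ffn_dim_multiplier=1.3),
        # the three steps above resolve to a final hidden_dim of 14336:
        # int(2 * 16384 / 3) = 10922 -> int(1.3 * 10922) = 14198 -> rounded up
        # to the next multiple of 1024 = 14336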

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
        for linear in (self.w2, self.w3):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)


class TestFloat8NumericsIntegrationTest:
    @pytest.mark.parametrize(
        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
    )
    @pytest.mark.parametrize(
        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
    )
    @pytest.mark.parametrize(
        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
    )
    @pytest.mark.parametrize("linear_cls", [Float8Linear, Float8DynamicLinear])
    @pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
    @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
    def test_encoder_fw_bw(
        self,
        linear_cls,
        scaling_type_x: TensorScalingType,
        scaling_type_w: TensorScalingType,
        scaling_type_dL_dY: TensorScalingType,
    ):
        linear_type = (
            LinearType.DELAYED if linear_cls == Float8Linear else LinearType.DYNAMIC
        )
        if linear_type is LinearType.DYNAMIC:
            # Only test one combination of scaling types, as they are a no-op
            # for Float8DynamicLinear. It would be cleaner to split into two
            # tests, but IMO not worth it since Float8DynamicLinear will be
            # deleted soon
            is_all_dynamic = (
                scaling_type_x is TensorScalingType.DYNAMIC
                and scaling_type_w is TensorScalingType.DYNAMIC
                and scaling_type_dL_dY is TensorScalingType.DYNAMIC
            )
            if not is_all_dynamic:
                pytest.skip()

        # TODO(later): maybe add float16 back if it becomes important
        data_dtype = torch.bfloat16

        # LLaMa 3 70B shapes
        model_ref = (
            FeedForward(
                dim=4096,
                hidden_dim=16384,
                multiple_of=1024,
                ffn_dim_multiplier=1.3,
            )
            .cuda()
            .to(data_dtype)
        )

        # for now just test the encoder to simplify things
        model_fp8 = copy.deepcopy(model_ref)
        swap_linear_with_float8_linear(
            model_fp8,
            linear_cls,
            emulate=False,
            scaling_type_x=scaling_type_x,
            scaling_type_w=scaling_type_w,
            scaling_type_dL_dY=scaling_type_dL_dY,
        )

        lr = 0.01
        optim_ref = torch.optim.SGD(model_ref.parameters(), lr=lr)
        optim_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr)

        # Note: you need two different inputs to properly test numerics
        # of delayed scaling, because the first time around the initialization
        # logic of delayed scaling behaves as dynamic scaling
        # TODO(future): also make unit tests do this properly
        shape = (1, 8192, 4096)
        data1 = torch.randn(*shape, device="cuda", dtype=data_dtype)
        data2 = torch.randn(*shape, device="cuda", dtype=data_dtype)

        model_ref(data1).sum().backward()
        # zero out grads without stepping, since we just want to compare grads
        # of the second datum
        optim_ref.zero_grad()
        model_ref_out = model_ref(data2)
        model_ref_out.sum().backward()

        if linear_requires_sync(
            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
        ):
            sync_float8_amax_and_scale_history(model_fp8)
        model_fp8(data1).sum().backward()
        # zero out grads without stepping, since we just want to compare grads
        # of the second datum
        optim_fp8.zero_grad()
        if linear_requires_sync(
            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
        ):
            sync_float8_amax_and_scale_history(model_fp8)
        model_fp8_out = model_fp8(data2)
        model_fp8_out.sum().backward()

        out_sqnr = compute_error(model_ref_out, model_fp8_out)
        assert out_sqnr > 20.0

        ref_name_to_grad = {
            name: param.grad for name, param in model_ref.named_parameters()
        }

        grad_sqnr_threshold = 20.0

        for name, param in model_fp8.named_parameters():
            ref_grad = ref_name_to_grad[name]
            cur_grad = param.grad
            sqnr = compute_error(ref_grad, cur_grad)
            assert sqnr > grad_sqnr_threshold


if __name__ == "__main__":
    pytest.main([__file__])
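To run just this test outside of CI, the same pytest.main entry point shown above can be given a -k filter; the file path here is hypothetical, since the filename is not part of this diff view:

import pytest

# run only the parametrized test added in this commit; replace the path with
# the actual location of the file in the repo (hypothetical path shown)
pytest.main(["test/test_numerics_integration.py", "-k", "test_encoder_fw_bw"])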