add unit test for block-wise fp8 (#3156)
yizhang2077 authored Jan 27, 2025
1 parent fb11a43 commit 1e3e521
Showing 2 changed files with 130 additions and 0 deletions.
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -52,6 +52,7 @@
"test_w8a8_quantization.py",
"test_session_control.py",
"test_fp8_kvcache.py",
"test_fp8_kernel.py",
],
"nightly": [
"test_nightly_gsm8k_eval.py",
129 changes: 129 additions & 0 deletions test/srt/test_fp8_kernel.py
@@ -0,0 +1,129 @@
import unittest

import torch

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
)


class TestFP8Base(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.M = 256
        # N is intentionally not a multiple of group_size to cover the non-aligned path
cls.N = 1024 + 64
cls.K = 512
cls.group_size = 128
cls.quant_type = torch.float8_e4m3fn
cls.output_type = torch.float16

@staticmethod
def _make_A(M, K, group_size, out_dtype):
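        # Build a synthetic activation A together with its expected per-token-group
        # quantization (quant_A) and scales; each group saturates at fmax, so the
        # expected scale is exactly group_amax / fmax.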
quant_A = torch.rand(
M, K // group_size, group_size, dtype=torch.float32, device="cuda"
)
        # map uniform [0, 1) samples to [-1, 1)
        quant_A = quant_A * 2 - 1
        # scale so each group's absolute max equals the fp8 format's max finite value (fmax)
finfo = torch.finfo(out_dtype)
fmax = finfo.max
scaling = fmax / quant_A.abs().amax(-1, keepdim=True)
quant_A *= scaling
quant_A = quant_A.to(out_dtype).to(torch.float32)

# create scale and A
scale = torch.rand(M, K // group_size, dtype=torch.float32, device="cuda")
scale /= fmax
A = quant_A * scale[..., None]

A = A.reshape(M, K)
quant_A = quant_A.reshape(M, K).to(out_dtype)
return A, quant_A, scale

@staticmethod
def _make_B(K, N, group_size, out_dtype):
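        # Build a synthetic weight B with block-wise (group_size x group_size)
        # scales; K and N are padded up to multiples of group_size and sliced
        # back afterwards so non-aligned shapes are exercised.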
def _aligned_size(a, b):
return (a + b - 1) // b * b

K_aligned = _aligned_size(K, group_size)
N_aligned = _aligned_size(N, group_size)

quant_B = torch.rand(
K_aligned // group_size,
group_size,
N_aligned // group_size,
group_size,
dtype=torch.float32,
device="cuda",
)
quant_B = quant_B * 2 - 1

        # scale so each (group_size x group_size) block's absolute max equals fmax
finfo = torch.finfo(out_dtype)
fmax = finfo.max
scaling = fmax / quant_B.abs().amax((1, 3), keepdim=True)
quant_B *= scaling
quant_B = quant_B.to(out_dtype).to(torch.float32)

scale = torch.rand(
K_aligned // group_size,
1,
N_aligned // group_size,
1,
dtype=torch.float32,
device="cuda",
)
scale /= fmax

B = quant_B * scale

B = B.reshape(K_aligned, N_aligned)[:K, :N]
quant_B = quant_B.reshape(K_aligned, N_aligned).to(out_dtype)[:K, :N]
scale = scale.reshape(K_aligned // group_size, N_aligned // group_size)
return B, quant_B, scale


class TestPerTokenGroupQuantFP8(TestFP8Base):
def test_per_token_group_quant_fp8(self):
        if torch.cuda.get_device_capability()[0] < 9:
            self.skipTest("requires a GPU with compute capability 9.0 (Hopper) or newer")
A, A_quant_gt, scale_gt = self._make_A(
M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
)
A_quant, scale = per_token_group_quant_fp8(
x=A, group_size=self.group_size, dtype=self.quant_type
)
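        # The scales should be recovered exactly by construction; the quantized
        # values are allowed a tiny fraction of mismatches from fp8 rounding.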
torch.testing.assert_close(scale, scale_gt)
diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs()
diff_count = (diff > 1e-5).count_nonzero()
assert diff_count / diff.numel() < 1e-4


class TestW8A8BlockFP8Matmul(TestFP8Base):
def test_w8a8_block_fp8_matmul(self):
        if torch.cuda.get_device_capability()[0] < 9:
            self.skipTest("requires a GPU with compute capability 9.0 (Hopper) or newer")
A, A_quant_gt, A_scale_gt = self._make_A(
M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
)
B, B_quant_gt, B_scale_gt = self._make_B(
K=self.K, N=self.N, group_size=self.group_size, out_dtype=self.quant_type
)
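        # Reference output from the dequantized tensors in fp16; the block-wise
        # fp8 kernel result is compared against it with a loose absolute tolerance.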
C_gt = A.to(self.output_type) @ B.to(self.output_type)
C = w8a8_block_fp8_matmul(
A=A_quant_gt,
B=B_quant_gt.T.contiguous(),
As=A_scale_gt,
Bs=B_scale_gt.T.contiguous(),
block_size=[128, 128],
output_dtype=self.output_type,
)
torch.testing.assert_close(C, C_gt, atol=0.5, rtol=1e-4)


if __name__ == "__main__":
unittest.main()
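
For context, a minimal pure-PyTorch sketch of the per-token group quantization scheme that the test's ground truth assumes: each group of group_size contiguous elements in a row gets one scale equal to its absolute max divided by the fp8 format's max finite value. The helper name per_token_group_quant_fp8_ref below is hypothetical (not part of the SGLang API), and the actual Triton kernel may differ in details such as minimum-scale clamping.

import torch

def per_token_group_quant_fp8_ref(x, group_size, dtype=torch.float8_e4m3fn):
    # x: (M, K) float tensor with K divisible by group_size.
    M, K = x.shape
    fmax = torch.finfo(dtype).max
    x_grouped = x.reshape(M, K // group_size, group_size)
    # One scale per (token, group); the clamp guards against all-zero groups.
    scale = x_grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fmax
    x_quant = (x_grouped / scale).to(dtype)
    return x_quant.reshape(M, K), scale.squeeze(-1)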
