diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index afe53797322f9..c3eddacec2727 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -2,7 +2,7 @@ Run `pytest tests/kernels/test_cutlass.py`. """ -from typing import Optional, Type +from typing import Type import pytest import torch @@ -11,6 +11,8 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform +from .utils import baseline_scaled_mm, to_fp8, to_int8 + MNK_FACTORS = [ (1, 256, 128), (1, 16384, 1024), @@ -41,34 +43,10 @@ capability = capability[0] * 10 + capability[1] -def to_fp8(tensor: torch.Tensor): - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) - - -def to_int8(tensor: torch.Tensor): - return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) - - def rand_int8(shape: tuple, device: str = "cuda"): return to_int8(torch.rand(shape, device=device) * 255 - 128) -def baseline_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: Type[torch.dtype], - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - output = (scale_a * (scale_b * (torch.mm( - a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype) - if bias is not None: - output = output + bias - - return output - - def cutlass_fp8_gemm_helper(m: int, n: int, k: int, diff --git a/tests/kernels/test_cutlass_2of4_sparse.py b/tests/kernels/test_cutlass_2of4_sparse.py new file mode 100644 index 0000000000000..56495df34aa6c --- /dev/null +++ b/tests/kernels/test_cutlass_2of4_sparse.py @@ -0,0 +1,214 @@ +"""Tests for sparse cutlass kernels + +Run `pytest tests/kernels/test_semi_structured.py`. +""" +from typing import Tuple, Type + +import pytest +import torch +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + sparse_cutlass_supported) +from vllm.platforms import current_platform + +from .utils import baseline_scaled_mm, to_fp8, to_int8 + +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +capability = current_platform.get_device_capability() +capability = capability[0] * 10 + capability[1] + + +def to_bf16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.bfloat16) + + +def to_fp16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.float16) + + +def prune_to_2_4(tensor): + # Reshape tensor to [N, 4] where N is number of groups of 4 + original_shape = tensor.shape + reshaped = tensor.reshape(-1, 4) + + # Get indices of top 2 absolute values in each group of 4 + _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) + + # Create binary mask + mask = torch.zeros_like(reshaped) + mask.scatter_(dim=1, + index=indices, + src=torch.ones_like(indices, dtype=mask.dtype)) + + # Apply mask and reshape back + pruned = reshaped * mask + + # Turn all -0.0 to 0.0 + pruned[pruned == -0.0] = 0.0 + + return pruned.reshape(original_shape) + + +def make_rand_sparse_tensors( + dtype: torch.dtype, m: int, n: int, k: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + b = prune_to_2_4(b.t()).t() + + if dtype == torch.int8: + a, b = to_int8(a), to_int8(b) + elif dtype == torch.float8_e4m3fn: + a, b = to_fp8(a), to_fp8(b) + elif dtype == torch.float16: + a, b = to_fp16(a), to_fp16(b) + elif dtype == torch.bfloat16: + a, b = to_bf16(a), to_bf16(b) + else: + raise ValueError("unsupported dtype") + + b_compressed, e = ops.cutlass_sparse_compress(b.t()) + + # Compressed B, Metadata, Original A, B + return b_compressed, e, a, b + + +@pytest.mark.skipif(not sparse_cutlass_supported(), + reason="Sparse CUTLASS is not supported on this GPU type.") +# Test working with a subset of A and B for sparse matmul +def test_cutlass_sparse_subset(): + + big_m = 1024 + m, n, k = 512, 512, 512 + + # Create tensors + b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, + big_m, n, k) + a = whole_a[0:m, 0:k] + scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 + scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 + + out = ops.cutlass_scaled_sparse_mm(a, + b_comp, + e, + scale_a, + scale_b, + out_dtype=torch.bfloat16) + baseline = baseline_scaled_mm(a, + b, + scale_a, + scale_b, + out_dtype=torch.bfloat16) + + torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) + + +MNK_FACTORS = [ + (1, 256, 128), + (1, 16384, 1024), + (1, 24576, 512), + (16, 256, 512), + (16, 16384, 128), + (16, 24576, 4096), + (32, 8192, 4096), + (32, 16384, 4096), + (33, 1024, 1024), + (33, 8192, 128), + (64, 2048, 512), + (64, 16384, 1024), + (100, 8192, 512), + (128, 32768, 4096), + (256, 4096, 4096), + (512, 256, 1024), + (512, 8192, 4096), + (512, 16384, 128), + (512, 24576, 128), +] + + +# Test working with a subset of A and B for sparse matmul +@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.") +@pytest.mark.skipif(not sparse_cutlass_supported(), + reason="Sparse CUTLASS is not supported on this GPU type.") +@pytest.mark.parametrize("m, k, n", MNK_FACTORS) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype]): + + # Create tensors + b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) + scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32) + scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32) + + out = ops.cutlass_scaled_sparse_mm(a, + b_comp, + e, + scale_a, + scale_b, + out_dtype=dtype) + baseline = F.linear(a, b.T) + + torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif(not sparse_cutlass_supported(), + reason="Sparse CUTLASS is not supported on this GPU type.") +@pytest.mark.parametrize("m, k, n", MNK_FACTORS) +@pytest.mark.skipif(not current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.") +def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int): + + # Create tensors + b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k) + scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) + scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) + + out = ops.cutlass_scaled_sparse_mm(a, + b_comp, + e, + scale_a, + scale_b, + out_dtype=torch.bfloat16) + + baseline = baseline_scaled_mm(a, + b, + scale_a, + scale_b, + out_dtype=torch.bfloat16) + + torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0) + + +@pytest.mark.skipif(not sparse_cutlass_supported(), + reason="Sparse CUTLASS is not supported on this GPU type.") +@pytest.mark.parametrize("m,k,n", MNK_FACTORS) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("use_bias", [True, False]) +def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool, + per_out_ch: bool, use_bias: bool): + + # Create tensors + b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) + scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) + scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) + + out = ops.cutlass_scaled_sparse_mm(a, + b_comp, + e, + scale_a, + scale_b, + out_dtype=torch.bfloat16) + + baseline = baseline_scaled_mm(a, + b, + scale_a, + scale_b, + out_dtype=torch.bfloat16) + + torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0) diff --git a/tests/kernels/test_semi_structured.py b/tests/kernels/test_semi_structured.py deleted file mode 100644 index 4316d6ab30e33..0000000000000 --- a/tests/kernels/test_semi_structured.py +++ /dev/null @@ -1,134 +0,0 @@ -"""Tests for sparse cutlass kernels - -Run `pytest tests/kernels/test_semi_structured.py`. -""" -from typing import Optional, Tuple, Type - -import pytest -import torch - -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - sparse_cutlass_supported) -from vllm.platforms import current_platform - -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - -capability = current_platform.get_device_capability() -capability = capability[0] * 10 + capability[1] - - -def to_fp8(tensor: torch.Tensor): - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) - - -def to_int8(tensor: torch.Tensor): - return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) - - -def rand_int8(shape: tuple, device: str = "cuda"): - return to_int8(torch.rand(shape, device=device) * 255 - 128) - - -def to_bf16(tensor: torch.Tensor) -> torch.Tensor: - return tensor.to(dtype=torch.bfloat16) - - -def to_fp16(tensor: torch.Tensor) -> torch.Tensor: - return tensor.to(dtype=torch.float16) - - -def prune_to_2_4(tensor): - # Reshape tensor to [N, 4] where N is number of groups of 4 - original_shape = tensor.shape - reshaped = tensor.reshape(-1, 4) - - # Get indices of top 2 absolute values in each group of 4 - _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) - - # Create binary mask - mask = torch.zeros_like(reshaped) - mask.scatter_(dim=1, - index=indices, - src=torch.ones_like(indices, dtype=mask.dtype)) - - # Apply mask and reshape back - pruned = reshaped * mask - - # Turn all -0.0 to 0.0 - pruned[pruned == -0.0] = 0.0 - - return pruned.reshape(original_shape) - - -def make_rand_sparse_tensors( - dtype: torch.dtype, m: int, n: int, k: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 - - b = prune_to_2_4(b.t()).t() - - if dtype == torch.int8: - a, b = to_int8(a), to_int8(b) - elif dtype == torch.float8_e4m3fn: - a, b = to_fp8(a), to_fp8(b) - elif dtype == torch.float16: - a, b = to_fp16(a), to_fp16(b) - elif dtype == torch.bfloat16: - a, b = to_bf16(a), to_bf16(b) - else: - raise ValueError("unsupported dtype") - - b_compressed, e = ops.cutlass_sparse_compress(b.t()) - - # Compressed B, Metadata, Original A, B - return b_compressed, e, a, b - - -def baseline_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: Type[torch.dtype], - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - output = (scale_a * (scale_b * (torch.mm( - a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype) - if bias is not None: - output = output + bias - - return output - - -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse FP8 is not yet supported on this GPU type.") -# Test working with a subset of A and B for sparse matmul -def test_cutlass_sparse_subset(): - - big_m = 1024 - m, n, k = 512, 512, 512 - - # Create tensors - b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, - big_m, n, k) - a = whole_a[0:m, 0:k] - scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 - scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 - - out = ops.cutlass_scaled_sparse_mm(a, - b_comp, - e, - scale_a, - scale_b, - out_dtype=torch.bfloat16) - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.bfloat16) - - torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 8011398551b9d..fb2c9f5d30583 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -5,7 +5,7 @@ import unittest from numbers import Number from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, - Union) + Type, Union) import pytest import torch @@ -1100,3 +1100,28 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, kwargs, test_utils=test_utils, raise_exception=raise_exception) if cond else {} + + +# For testing quantized linear kernels +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.Tensor): + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def baseline_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: Type[torch.dtype], + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = (scale_a * (scale_b * (torch.mm( + a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype) + if bias is not None: + output = output + bias + + return output diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 0cd86cef0a475..bf0d454ad511c 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -313,8 +313,10 @@ def check_model(model): assert output +@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.") @pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse FP8 is not yet supported on this GPU type.") + reason="2of4 Sparse is not yet supported on this GPU type." + ) @pytest.mark.parametrize( "args_2of4", [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")]) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index b2fc2360f47f1..dd2dd02eaf723 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -9,6 +9,7 @@ QuantizationType) from pydantic import BaseModel +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) @@ -27,6 +28,8 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.platforms import current_platform +logger = init_logger(__name__) + __all__ = ["CompressedTensorsLinearMethod"] SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config" @@ -79,6 +82,8 @@ def get_quant_method( return UnquantizedLinearMethod() if isinstance(layer, LinearBase): scheme = self.get_scheme(layer=layer, layer_name=prefix) + if scheme is None: + return UnquantizedLinearMethod() layer.scheme = scheme return CompressedTensorsLinearMethod(self) if isinstance(layer, Attention): @@ -340,10 +345,10 @@ def _get_scheme_from_parts( raise NotImplementedError( "No compressed-tensors compatible scheme was found.") - def get_scheme( - self, - layer: torch.nn.Module, - layer_name: Optional[str] = None) -> "CompressedTensorsScheme": + def get_scheme(self, + layer: torch.nn.Module, + layer_name: Optional[str] = None + ) -> Optional["CompressedTensorsScheme"]: """ compressed-tensors supports non uniform in the following way: @@ -353,10 +358,7 @@ def get_scheme( which can be a full layer_name, a regex for a layer_name, or an nn.Module name. - We first check whether a layer is in the ignore group and use - CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the layer - - We then detect whether a layer_name is found in any target and + Detect whether a layer_name is found in any target and use the quantization scheme corresponding to the matched target to select the CompressedTensorsScheme used for infernece. """ @@ -394,6 +396,13 @@ def get_scheme( if self.supports_cutlass_24(weight_quant=weight_quant, input_quant=input_quant, sparsity_scheme=sparsity_scheme): + # FIXME(tlrmchlsmth): layers using W16A16 CUTLASS 2:4 sparse kernels + # currently produce bad output in some cases + if weight_quant is None: + logger.warning_once( + "CompressedTensors24 scheme is disabled for the w16a16 " + "case. Falling back to UnquantizedLinearMethod") + return None # Have a valid sparsity scheme # Validate layer is supported by Cutlass 2:4 Kernel scheme = CompressedTensors24(quantized=weight_quant is not None