Skip to content

Commit

Permalink
[Kernel][Bugfix] Refactor and Fix CUTLASS 2:4 Sparse Kernels (vllm-pr…
Browse files Browse the repository at this point in the history
…oject#13198)

Signed-off-by: Tyler Michael Smith <[email protected]>
  • Loading branch information
tlrmchlsmth authored Feb 14, 2025
1 parent 2344192 commit c1e37bf
Show file tree
Hide file tree
Showing 16 changed files with 576 additions and 473 deletions.
10 changes: 5 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")
# Please keep this in sync with FetchContent_Declare line below.
set(CUTLASS_REVISION "v3.7.0" CACHE STRING "CUTLASS revision to use")

# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
Expand All @@ -245,6 +246,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
# Please keep this in sync with CUTLASS_REVISION line above.
GIT_TAG v3.7.0
GIT_PROGRESS TRUE

Expand All @@ -266,7 +268,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_compressor_entry.cu"
"csrc/cutlass_extensions/common.cpp")

set_gencode_flags_for_srcs(
Expand Down Expand Up @@ -359,8 +360,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
# require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
Expand Down Expand Up @@ -476,7 +476,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)

Expand Down
69 changes: 68 additions & 1 deletion csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,30 @@ namespace vllm::c3x {

using namespace cute;

template <typename T>
struct identity {
CUTLASS_HOST_DEVICE
T operator()(T lhs) const { return lhs; }
};

template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct TrivialEpilogue {
private:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Compute = cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::thread::Identity, ElementD, ElementAcc,
cutlass::FloatRoundStyle::round_to_nearest>;

public:
using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute, Accum>;
using ArgumentType = typename EVTCompute::Arguments;

template <typename... Args>
static ArgumentType prepare_args(Args... args) {
return {};
}
};

/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
Expand Down Expand Up @@ -174,6 +198,49 @@ struct ScaledEpilogueBias
}
};

/*
* This epilogue performs the same operation as ScaledEpilogueBias, but the
* bias is a column vector instead of a row vector. Useful e.g. if we are
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueColumnBias
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template ColLoad<ElementD>;

using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;

using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;

public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;

using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args, bias_args};
}
};

/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
Expand Down Expand Up @@ -314,4 +381,4 @@ struct ScaledEpilogueBiasAzpToken
}
};

}; // namespace vllm::c3x
}; // namespace vllm::c3x
3 changes: 1 addition & 2 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);

bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
torch::Tensor& e, torch::Tensor const& a);
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
#endif

void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
Expand Down
13 changes: 9 additions & 4 deletions csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,17 @@ struct cutlass_3x_gemm {

using EVTCompute = typename Epilogue::EVTCompute;

// These are the minimum alignments needed for the kernels to compile
static constexpr int AlignmentAB =
128 / cutlass::sizeof_bits<ElementAB>::value;
static constexpr int AlignmentCD = 4;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
EpilogueSchedule, EVTCompute>::CollectiveOp;
ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD,
AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;

static constexpr size_t CEStorageSize =
sizeof(typename CollectiveEpilogue::SharedStorage);
Expand All @@ -69,8 +74,8 @@ struct cutlass_3x_gemm {
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
ElementAB, cutlass::layout::RowMajor, 16,
ElementAB, cutlass::layout::ColumnMajor, 16,
ElementAB, cutlass::layout::RowMajor, AlignmentAB,
ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
ElementAcc, TileShape, ClusterShape,
Stages,
KernelSchedule>::CollectiveOp;
Expand Down
11 changes: 8 additions & 3 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,19 @@ struct cutlass_2x_gemm {

using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;

// These are the minimum alignments needed for the kernels to compile
static constexpr int AlignmentAB =
128 / cutlass::sizeof_bits<ElementAB>::value;
static constexpr int AlignmentCD = 4;

// clang-format off
using RowMajor = typename cutlass::layout::RowMajor;
using ColumnMajor = typename cutlass::layout::ColumnMajor;
using KernelType =
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
float, cutlass::layout::RowMajor, 4,
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
float, cutlass::layout::RowMajor, AlignmentCD,
ElementAcc, float, cutlass::arch::OpClassTensorOp,
Arch,
TileShape, WarpShape, InstructionShape,
Expand Down
165 changes: 0 additions & 165 deletions csrc/sparse/cutlass/sparse_compressor_c3x.cu

This file was deleted.

Loading

0 comments on commit c1e37bf

Please sign in to comment.