diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index ca9cf15780e25..d2ae926daa7c0 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -30,6 +30,12 @@ docker exec cpu-test bash -c " --ignore=tests/models/test_jamba.py \ --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported +# Run compressed-tensor test +docker exec cpu-test bash -c " + pytest -s -v \ + tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ + tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynanmic_per_token" + # online inference docker exec cpu-test bash -c " export VLLM_CPU_KVCACHE_SPACE=10 diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 2b60835255cb4..34b4c95e34ffc 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -2,6 +2,10 @@ FROM ubuntu:22.04 AS cpu-test-1 +ENV CCACHE_DIR=/root/.cache/ccache + +ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache + RUN --mount=type=cache,target=/var/cache/apt \ apt-get update -y \ && apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \ @@ -26,6 +30,19 @@ RUN --mount=type=cache,target=/root/.cache/pip \ pip install --upgrade pip && \ pip install -r requirements-build.txt +# install oneDNN +RUN git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git + +RUN --mount=type=cache,target=/root/.cache/ccache \ + cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \ + -DONEDNN_BUILD_DOC=OFF \ + -DONEDNN_BUILD_EXAMPLES=OFF \ + -DONEDNN_BUILD_TESTS=OFF \ + -DONEDNN_BUILD_GRAPH=OFF \ + -DONEDNN_ENABLE_WORKLOAD=INFERENCE \ + -DONEDNN_ENABLE_PRIMITIVE=MATMUL && \ + cmake --build ./oneDNN/build --target install --config Release + FROM cpu-test-1 AS build WORKDIR /workspace/vllm @@ -41,7 +58,6 @@ COPY ./ ./ ARG VLLM_CPU_DISABLE_AVX512 ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512} -ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=cache,target=/root/.cache/ccache \ VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \ diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 3ba3a2b6a93cd..8470e9ea9ebd9 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -1,4 +1,5 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_CXX_STANDARD 17) # # Define environment variables for special configurations @@ -83,12 +84,7 @@ endif() message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") -list(APPEND LIBS "numa") - - -# -# Define extension targets -# +list(APPEND LIBS dnnl numa) # # _C extension @@ -102,6 +98,16 @@ set(VLLM_EXT_SRC "csrc/cpu/pos_encoding.cpp" "csrc/cpu/torch_bindings.cpp") +if (AVX512_FOUND AND NOT AVX512_DISABLED) + set(VLLM_EXT_SRC + "csrc/cpu/quant.cpp" + ${VLLM_EXT_SRC}) +endif() + +# +# Define extension targets +# + define_gpu_extension_target( _C DESTINATION vllm diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index f50620a5287d4..5b1d3d6442b2b 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -24,8 +24,8 @@ namespace vec_op { #define CPU_KERNEL_GUARD_OUT(NAME) #else #define CPU_KERNEL_GUARD_IN(NAME) \ - std::cout << #NAME << " invoked." << std::endl; -#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; + RECORD_FUNCTION(#NAME, c10::ArrayRef({})); +#define CPU_KERNEL_GUARD_OUT(NAME) #endif #define FORCE_INLINE __attribute__((always_inline)) inline @@ -106,6 +106,12 @@ struct BF16Vec16 : public Vec { explicit BF16Vec16(const FP32Vec16 &); void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } + + void save(void* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm256_mask_storeu_epi16(ptr, mask, reg); + } }; #ifdef __AVX512F__ @@ -313,8 +319,28 @@ struct FP32Vec16 : public Vec { return FP32Vec16(_mm512_div_ps(reg, b.reg)); } + FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const { + return FP32Vec16(_mm512_min_ps(max.reg, _mm512_max_ps(min.reg, reg))); + } + + FP32Vec16 max(const FP32Vec16& b) const { + return FP32Vec16(_mm512_max_ps(reg, b.reg)); + } + + FP32Vec16 max(const FP32Vec16& b, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + return FP32Vec16(_mm512_mask_max_ps(reg, mask, reg, b.reg)); + } + + FP32Vec16 abs() const { + return FP32Vec16(_mm512_abs_ps(reg)); + } + float reduce_sum() const { return _mm512_reduce_add_ps(reg); } + float reduce_max() const { return _mm512_reduce_max_ps(reg); } + template float reduce_sub_sum(int idx) { static_assert(VEC_ELEM_NUM % group_size == 0); constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); @@ -323,6 +349,12 @@ struct FP32Vec16 : public Vec { } void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } + + void save(float* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm512_mask_storeu_ps(ptr, mask, reg); + } }; #else struct FP32Vec16 : public Vec { @@ -433,6 +465,32 @@ struct FP32Vec16 : public Vec { }; #endif +#ifdef __AVX512F__ +struct INT8Vec16: public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m128i reg; + int8_t values[VEC_ELEM_NUM]; + }; + + __m128i reg; + + explicit INT8Vec16(const FP32Vec16& vec) : reg( + _mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) + ) {} + + void save(int8_t* ptr) const { + _mm_storeu_epi8(ptr, reg); + } + + void save(int8_t* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm_mask_storeu_epi8(ptr, mask, reg); + } +}; +#endif + template struct VecType { using vec_type = void; }; template using vec_t = typename VecType::vec_type; diff --git a/csrc/cpu/dnnl_helper.hpp b/csrc/cpu/dnnl_helper.hpp new file mode 100644 index 0000000000000..024ad4ae43da8 --- /dev/null +++ b/csrc/cpu/dnnl_helper.hpp @@ -0,0 +1,168 @@ +#ifndef DNNL_HELPER_HPP +#define DNNL_HELPER_HPP + +#include + +#include "oneapi/dnnl/dnnl.hpp" + +namespace { +template +struct DNNLType { + static constexpr dnnl::memory::data_type type = + dnnl::memory::data_type::undef; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; +}; + +template +constexpr inline dnnl::memory::data_type get_dnnl_type() { + return DNNLType>::type; +} +}; // namespace + +template +class DNNLPrimitiveHelper { + public: + // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias) + // A: [M, K], row-major + // B: [K, N], column-major + // C: [M, N], row-major + // bias: [N], row-major, optional + // a_scales: [MS] + // b_scales: [NS] + // Note: Due to the limitation of oneDNN + // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is + // not supported. + template + static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c, + const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N, + dnnl_dim_t K, const float* a_scales, + const float* b_scales, dnnl_dim_t MS, + dnnl_dim_t NS) { + auto&& OutputType = get_dnnl_type(); + auto&& BiasType = get_dnnl_type(); + + dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1}); + dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K}); + dnnl::memory::desc c_md({M, N}, OutputType, {N, 1}); + + dnnl::primitive_attr attr; + if constexpr (!InputNoScale) { + if (MS == 1) { + // per-tensor + attr.set_scales_mask(DNNL_ARG_SRC, 0); + } else { + // per-token + TORCH_CHECK(false, "per-token quantization is unsupported."); + } + } + + if (NS == 1) { + // per-tensor + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); + } else { + // per-channel + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); + } + + dnnl::matmul::primitive_desc matmul_pd; + if (bias) { + dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); + matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, + bias_md, c_md, attr); + } else { + matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, + c_md, attr); + } + dnnl::matmul matmul(matmul_pd); + + auto& engine = default_engine(); + + dnnl::memory a_m(a_md, engine, (void*)a); + dnnl::memory b_m(b_md, engine, (void*)b); + dnnl::memory c_m(c_md, engine, (void*)c); + dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine, + (void*)a_scales); + dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine, + (void*)b_scales); + + auto& stream = default_stream(); + if constexpr (InputNoScale) { + if (bias) { + dnnl::memory::desc bias_md({N}, BiasType, {1}); + dnnl::memory bias_m(bias_md, engine, (void*)bias); + matmul.execute( + stream, { + {DNNL_ARG_SRC, a_m}, + {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_BIAS, bias_m}, + {DNNL_ARG_DST, c_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, + }); + } else { + matmul.execute( + stream, { + {DNNL_ARG_SRC, a_m}, + {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_DST, c_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, + }); + } + } else { + if (bias) { + dnnl::memory::desc bias_md({N}, BiasType, {1}); + dnnl::memory bias_m(bias_md, engine, (void*)bias); + matmul.execute( + stream, { + {DNNL_ARG_SRC, a_m}, + {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_BIAS, bias_m}, + {DNNL_ARG_DST, c_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, + }); + } else { + matmul.execute( + stream, { + {DNNL_ARG_SRC, a_m}, + {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_DST, c_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, + }); + } + } + stream.wait(); + } + + private: + static dnnl::engine& default_engine() { + static dnnl::engine engine(dnnl::engine::kind::cpu, 0); + return engine; + } + + static dnnl::stream& default_stream() { + static dnnl::stream stream(default_engine()); + return stream; + } +}; + +#endif diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp new file mode 100644 index 0000000000000..0cfc19097fded --- /dev/null +++ b/csrc/cpu/quant.cpp @@ -0,0 +1,294 @@ +#include "cpu_types.hpp" +#include "dnnl_helper.hpp" + +namespace { +template +struct KernelVecType { + using load_vec_type = void; + using cvt_vec_type = void; +}; + +template <> +struct KernelVecType { + using load_vec_type = vec_op::FP32Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; + +template <> +struct KernelVecType { + using load_vec_type = vec_op::BF16Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; + +#ifdef __AVX512F__ +template +void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + const float* scale, const int num_tokens, + const int hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t inv_scale(1.0 / *scale); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j); + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + + if (j + vec_elem_num == hidden_size) { + elems_int8.save(output + i * hidden_size + j); + } else { + elems_int8.save(output + i * hidden_size + j, hidden_size - j); + } + } +} + +template +void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + float* scale, const int num_tokens, + const int hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + cvt_vec_t max_abs(0.0); + { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + max_abs = max_abs.max(elems_fp32.abs()); + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + + if (j + vec_elem_num == hidden_size) { + max_abs = max_abs.max(elems_fp32.abs()); + } else { + max_abs = max_abs.max(elems_fp32.abs(), hidden_size - j); + } + } + + float scale_val = max_abs.reduce_max() / 127.0f; + scale[i] = scale_val; + const cvt_vec_t inv_scale(1.0 / scale_val); + + { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j); + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + vec_op::INT8Vec16 elems_int8(elems_fp32); + + if (j + vec_elem_num == hidden_size) { + elems_int8.save(output + i * hidden_size + j); + } else { + elems_int8.save(output + i * hidden_size + j, hidden_size - j); + } + } + } +} + +template +void dynamic_output_scale_impl(const float* input, scalar_t* output, + const float* scale, const scalar_t* bias, + const int num_tokens, const int hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + int j = 0; + cvt_vec_t token_scale_vec(scale[i]); + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + cvt_vec_t elems_fp32(input + i * hidden_size + j); + elems_fp32 = elems_fp32 * token_scale_vec; + + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j); + } + + cvt_vec_t elems_fp32(input + i * hidden_size + j); + elems_fp32 = elems_fp32 * token_scale_vec; + + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + + load_vec_t elems_out(elems_fp32); + + if (j + vec_elem_num == hidden_size) { + elems_out.save(output + i * hidden_size + j); + } else { + elems_out.save(output + i * hidden_size + j, hidden_size - j); + } + } +} +#else +template +void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + const float* scale, const int num_tokens, + const int hidden_size) { + TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.") +} + +template +void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + float* scale, const int num_tokens, + const int hidden_size) { + TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.") +} + +template +void dynamic_output_scale_impl() { + TORCH_CHECK(false, "dynamic_output_scale_impl requires AVX512 support.") +} +#endif +} // namespace + +void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& a_scales, // [1] or [M] + const torch::Tensor& b_scales, // [1] or [OC] + const c10::optional& bias // [OC] +) { + CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, + "int8_scaled_mm only supports INT8 inputs.") + TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); + TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && + b.size(1) == c.size(1)); + TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); + TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && + b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + + if (bias) { + TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && + bias->dim() == 1); + } + + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "cutlass_scaled_mm", [&] { + if (a_scales.numel() != 1) { + // per-token + // Note: oneDNN doesn't support per-token activation quantization + torch::Tensor tmp_fp32_out = + torch::empty_like(c, ::at::ScalarType::Float); + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), (void*)(0), a.size(0), b.size(1), + a.size(1), (float*)(0), b_scales.data_ptr(), 0, + b_scales.numel()); + if (bias.has_value()) { + dynamic_output_scale_impl( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), bias->data_ptr(), c.size(0), + c.size(1)); + } else { + dynamic_output_scale_impl( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), (scalar_t*)(0), c.size(0), c.size(1)); + } + } else { + // per-tensor + if (bias.has_value()) { + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), c.data_ptr(), + bias->data_ptr(), a.size(0), b.size(1), a.size(1), + a_scales.data_ptr(), b_scales.data_ptr(), + a_scales.numel(), b_scales.numel()); + } else { + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), c.data_ptr(), + (void*)(0), a.size(0), b.size(1), a.size(1), + a_scales.data_ptr(), b_scales.data_ptr(), + a_scales.numel(), b_scales.numel()); + } + } + }); +} + +// static-per-tensor quantization. +void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] + const torch::Tensor& input, // [..., hidden_size] + const torch::Tensor& scale) { + CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(scale.numel() == 1); + + const int hidden_size = input.size(-1); + const int num_tokens = input.numel() / hidden_size; + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_impl", [&] { + static_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), num_tokens, hidden_size); + }); +} + +// dynamic-per-token quantization. +void dynamic_scaled_int8_quant( + torch::Tensor& out, // [..., hidden_size] + const torch::Tensor& input, // [..., hidden_size] + torch::Tensor& scale // [..., 1] +) { + CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + + int const hidden_size = input.size(-1); + int const num_tokens = input.numel() / hidden_size; + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), num_tokens, hidden_size); + }); +} diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index cf7d977da7c1c..9d6b1962b828c 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -4,7 +4,12 @@ #include -void init_cpu_threads_env(const std::string& cpu_ids); +std::string init_cpu_threads_env(const std::string& cpu_ids); + +void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a, + const torch::Tensor& b, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, + const c10::optional& bias); TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -84,6 +89,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! key, int head_size," " Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); + + // Quantization +#ifdef __AVX512F__ + // Compute int8 quantized tensor for given scaling factor. + ops.def( + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> " + "()"); + ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); + // Compute int8 quantized tensor and scaling factor + ops.def( + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> " + "()"); + ops.impl("dynamic_scaled_int8_quant", torch::kCPU, + &dynamic_scaled_int8_quant); + // W8A8 GEMM, supporting symmetric per-tensor or per-row/column + // quantization. + ops.def( + "cutlass_scaled_mm(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales, Tensor? bias) -> ()"); + ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); +#endif } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { @@ -111,7 +138,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) { // CPU utils - utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env); + utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp index 5782580baa861..1138a55df2f05 100644 --- a/csrc/cpu/utils.cpp +++ b/csrc/cpu/utils.cpp @@ -5,7 +5,7 @@ #include "cpu_types.hpp" -void init_cpu_threads_env(const std::string& cpu_ids) { +std::string init_cpu_threads_env(const std::string& cpu_ids) { bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str()); TORCH_CHECK(omp_cpu_mask->size > 0); std::vector omp_cpu_ids; @@ -51,15 +51,40 @@ void init_cpu_threads_env(const std::string& cpu_ids) { torch::set_num_threads((int)omp_cpu_ids.size()); TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads()); TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads()); + + std::vector> thread_core_mapping; + thread_core_mapping.reserve(omp_cpu_ids.size()); + omp_lock_t writelock; + omp_init_lock(&writelock); + #pragma omp parallel for schedule(static, 1) for (size_t i = 0; i < omp_cpu_ids.size(); ++i) { - cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size); - size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size); - CPU_ZERO_S(size, mask); - CPU_SET_S(omp_cpu_ids[i], size, mask); - sched_setaffinity(0, sizeof(cpu_set_t), mask); - CPU_FREE(mask); + cpu_set_t mask; + CPU_ZERO(&mask); + CPU_SET(omp_cpu_ids[i], &mask); + int ret = sched_setaffinity(0, sizeof(cpu_set_t), &mask); + if (ret == -1) { + TORCH_CHECK(false, + "sched_setaffinity failed. errno: " + std::to_string(errno)); + } + + omp_set_lock(&writelock); + thread_core_mapping.emplace_back(gettid(), omp_cpu_ids[i]); + omp_unset_lock(&writelock); } + omp_destroy_lock(&writelock); + numa_free_nodemask(omp_cpu_mask); + + std::stringstream ss; + ss << "OMP threads binding of Process " << getpid() << ":\n"; + std::sort(thread_core_mapping.begin(), thread_core_mapping.end(), + [](auto&& a, auto&& b) { return a.second < b.second; }); + for (auto&& item : thread_core_mapping) { + ss << "\t" + << "OMP tid: " << item.first << ", core " << item.second << "\n"; + } + + return ss.str(); } diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 7dd20636c892f..627b2abaabcf9 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -56,7 +56,7 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args): assert qkv_proj.weight_scale.dtype is torch.float32 assert qkv_proj.input_scale.dtype is torch.float32 - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy(["Hello my name is"], max_tokens=20) assert output @@ -85,7 +85,7 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args): assert qkv_proj.scheme.strategy == strategy assert qkv_proj.weight.dtype is torch.int8 - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy(["Hello my name is"], max_tokens=20) assert output diff --git a/vllm/config.py b/vllm/config.py index b3e91701c60a4..26e4b169587e1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -876,7 +876,8 @@ def __init__( from vllm.executor import ray_utils backend = "mp" ray_found = ray_utils.ray_is_available() - if cuda_device_count_stateless() < self.world_size: + if (torch.cuda.is_available() + and cuda_device_count_stateless() < self.world_size): if not ray_found: raise ValueError("Unable to load Ray which is " "required for multi-node inference, " diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index ec9b24ce1318f..7380b73ad6548 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -5,7 +5,8 @@ import torch import vllm.envs as envs -from vllm.config import CacheConfig, ModelConfig, SchedulerConfig +from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, + SchedulerConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, ResultHandler, WorkerMonitor) @@ -60,6 +61,8 @@ def _init_executor(self) -> None: self.cache_config = _verify_and_get_cache_config(self.cache_config) self.scheduler_config = _verify_and_get_scheduler_config( self.scheduler_config) + self.parallel_config = _verify_and_get_parallel_config( + self.parallel_config) # Multiprocessing-based executor does not support multi-node setting. # Since it only works for single node, we can use the loopback address @@ -359,6 +362,16 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: return config +def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig: + if (config.distributed_executor_backend is not None + and config.distributed_executor_backend != "mp"): + logger.warning( + "%s is not supported on CPU, fallback to mp distributed executor " + "backend.", config.distributed_executor_backend) + config.distributed_executor_backend = "mp" + return config + + def _driver_method_invoker(driver, method: str, *args, **kwargs): return getattr(driver, method)(*args, **kwargs) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 1170d55f31993..b5b2570966600 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -116,15 +116,19 @@ def get_config_filenames(cls) -> List[str]: def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool: - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - supported = capability >= min_capability - if error and not supported: - raise RuntimeError( - "Quantization scheme is not supported for ", - f"the current GPU. Min capability: {min_capability}. ", - f"Current capability: {capability}.") - return supported + capability = current_platform.get_device_capability() # type: ignore + + if capability is not None: + capability = capability[0] * 10 + capability[1] + supported = capability >= min_capability + if error and not supported: + raise RuntimeError( + "Quantization scheme is not supported for ", + f"the current GPU. Min capability: {min_capability}. ", + f"Current capability: {capability}.") + return supported + else: + return False def _is_static_tensor_w8a8(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index f59eb805ea907..ac869e56ce198 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -95,8 +95,9 @@ def _get_quantization_config( """Get the quantization config.""" if model_config.quantization is not None: quant_config = get_quant_config(model_config, load_config) - if not current_platform.is_tpu(): - capability = current_platform.get_device_capability() + capability = current_platform.get_device_capability() # type: ignore + + if capability is not None: capability = capability[0] * 10 + capability[1] if capability < quant_config.get_min_capability(): raise ValueError( diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index aedf3c3a950ee..a483614d067e9 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -42,6 +42,13 @@ except Exception: pass +is_cpu = False +try: + from importlib.metadata import version + is_cpu = "cpu" in version("vllm") +except Exception: + pass + if is_tpu: # people might install pytorch built with cuda but run on tpu # so we need to check tpu first @@ -53,6 +60,9 @@ elif is_rocm: from .rocm import RocmPlatform current_platform = RocmPlatform() +elif is_cpu: + from .cpu import CpuPlatform + current_platform = CpuPlatform() else: current_platform = UnspecifiedPlatform() diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py new file mode 100644 index 0000000000000..4736e898b6a52 --- /dev/null +++ b/vllm/platforms/cpu.py @@ -0,0 +1,15 @@ +import torch + +from .interface import Platform, PlatformEnum + + +class CpuPlatform(Platform): + _enum = PlatformEnum.CPU + + @staticmethod + def get_device_name(device_id: int = 0) -> str: + return "cpu" + + @staticmethod + def inference_mode(): + return torch.no_grad() diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 25b6f26676ef0..676f4c9fccf5a 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,5 +1,5 @@ import enum -from typing import Tuple +from typing import Optional, Tuple import torch @@ -8,6 +8,7 @@ class PlatformEnum(enum.Enum): CUDA = enum.auto() ROCM = enum.auto() TPU = enum.auto() + CPU = enum.auto() UNSPECIFIED = enum.auto() @@ -23,9 +24,12 @@ def is_rocm(self) -> bool: def is_tpu(self) -> bool: return self._enum == PlatformEnum.TPU + def is_cpu(self) -> bool: + return self._enum == PlatformEnum.CPU + @staticmethod - def get_device_capability(device_id: int = 0) -> Tuple[int, int]: - raise NotImplementedError + def get_device_capability(device_id: int = 0) -> Optional[Tuple[int, int]]: + return None @staticmethod def get_device_name(device_id: int = 0) -> str: diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 5e32bee1c5511..393fc230da0b9 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch from .interface import Platform, PlatformEnum @@ -8,10 +6,6 @@ class TpuPlatform(Platform): _enum = PlatformEnum.TPU - @staticmethod - def get_device_capability(device_id: int = 0) -> Tuple[int, int]: - raise RuntimeError("TPU does not have device capability.") - @staticmethod def inference_mode(): return torch.no_grad() diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 52d1806018f51..5e36fba6ccdea 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -207,7 +207,8 @@ def stop_profile(self): def init_device(self) -> None: if self.local_omp_cpuid != "all": - torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) + ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) + logger.info(ret) self.init_distributed_environment() # Set random seed.