[WIP] Working Grouped gemm with group ID #48

Draft · wants to merge 21 commits into base: main

4 changes: 3 additions & 1 deletion CMakeLists.txt
@@ -299,7 +299,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/quantization/cutlass_w8a8/grouped_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
243 changes: 243 additions & 0 deletions benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
@@ -0,0 +1,243 @@
from typing import List, Tuple

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe,
                                                            fused_experts,
                                                            fused_topk)
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = [
    "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite",
    "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m"
]
DEFAULT_BATCH_SIZES = [16, 32, 64, 128, 256, 512]

NUM_GROUPS_OPTS = [8] #[8, 64]
PER_ACT_TOKEN_OPTS = [False] #[False, True]
PER_OUT_CH_OPTS = [False] #[False, True]
TOPKS = [2, 6]


def run_from_graph(a_q: torch.Tensor, a_scale: torch.Tensor,
                   w1_q: torch.Tensor, w2_q: torch.Tensor,
                   w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                   topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int,
                   n: int, k: int, e: int):
    with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(
                pipeline_parallel_size=1))):
        return cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale,
                           topk_weights, topk_ids, m, n, k, e)


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def bench_run(results: List[benchmark.Measurement], model: str,
              num_experts: int, topk: int, per_act_token: bool,
              per_out_ch: bool, mkn: Tuple[int, int, int]):
    label = "Quant Matmul"

    sub_label = ("{}, num_experts={}, per_act_token={} per_out_ch={}, "
                 "MKN=({})".format(model, num_experts, per_act_token,
                                   per_out_ch, mkn))

    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

    a_q, a_scale = ops.scaled_fp8_quant(a)

    w1_q = torch.empty((num_experts, 2 * n, k),
                       device="cuda",
                       dtype=torch.float8_e4m3fn)
    w2_q = torch.empty((num_experts, k, n),
                       device="cuda",
                       dtype=torch.float8_e4m3fn)
    w1_scale = torch.empty((num_experts, 1, 1),
                           device="cuda",
                           dtype=torch.float32)
    w2_scale = torch.empty((num_experts, 1, 1),
                           device="cuda",
                           dtype=torch.float32)

    for expert in range(num_experts):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
    w1_q_notransp = w1_q.clone()
    w2_q_notransp = w2_q.clone()
    w1_q = w1_q.transpose(1, 2)
    w2_q = w2_q.transpose(1, 2)

    score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)

    def replay_graph(graph):
        graph.replay()
        torch.cuda.synchronize()

    stream = torch.cuda.Stream()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        run_from_graph(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale,
                       topk_weights, topk_ids, m, n, k, num_experts)
    torch.cuda.synchronize()

    globals = {
        # Baseline params
        "a": a,
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        "w1_q_notransp": w1_q_notransp,
        "w2_q_notransp": w2_q_notransp,
        # Cutlass params
        "a_q": a_q,
        "a_scale": a_scale,
        "w1_q": w1_q,
        "w2_q": w2_q,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "m": m,
        "n": n,
        "k": k,
        "num_experts": num_experts,
        # Cutlass cuda graph params
        "graph": graph,
        # Gen params
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        # Kernels
        "fused_experts": fused_experts,
        "cutlass_moe": cutlass_moe,
        "replay_graph": replay_graph,
    }

    min_run_time = 1
    num_warmup = 5

    # Warmup
    for _ in range(num_warmup):
        # fused_experts(a, w1, w2, topk_weights, topk_ids)
        fused_experts(a,
                      w1_q_notransp,
                      w2_q_notransp,
                      topk_weights,
                      topk_ids,
                      use_fp8_w8a8=True,
                      w1_scale=w1_scale,
                      w2_scale=w2_scale,
                      a1_scale=a_scale)

    results.append(
        benchmark.Timer(
            # stmt="fused_experts(a, w1, w2, topk_weights, topk_ids)",
            stmt=
            "fused_experts(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, use_fp8_w8a8=True, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a_scale)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    for _ in range(num_warmup):
        cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
                    topk_ids, m, n, k, num_experts)

    results.append(
        benchmark.Timer(
            stmt=
            "cutlass_moe(a_q, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, m, n, k, num_experts)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    for _ in range(num_warmup):
        replay_graph(graph)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(graph)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time))


def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    results: List[benchmark.Measurement] = []

    for model in args.models:
        for layer in WEIGHT_SHAPES_MOE[model]:
            num_experts = layer[0]
            size_k = layer[1]
            size_n = layer[2]

            if len(args.limit_k) > 0 and size_k not in args.limit_k:
                continue

            if len(args.limit_n) > 0 and size_n not in args.limit_n:
                continue

            for per_act_token in PER_ACT_TOKEN_OPTS:
                for per_out_ch in PER_OUT_CH_OPTS:
                    for topk in TOPKS:
                        for size_m in DEFAULT_BATCH_SIZES:
                            mkn = (size_m, size_k, size_n)
                            bench_run(results, model, num_experts, topk,
                                      per_act_token, per_out_ch, mkn)

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark CUTLASS grouped GEMM MoE across specified "
        "models/shapes/batches")
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--batch-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_BATCH_SIZES)
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token",
                        nargs="+",
                        type=int,
                        default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
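For reference, the cutlass_moe path benchmarked above can be exercised on its own. The snippet below is a minimal sketch distilled from bench_run: it assumes the cutlass_moe and fused_topk functions imported by this script, a Hopper GPU with the new grouped-GEMM kernels built, and uses the deepseekv2-lite entry from WEIGHT_SHAPES_MOE for the shapes; it is illustrative only and not part of the PR.

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe,
                                                            fused_topk)

# Shapes from the deepseekv2-lite entry above: 64 experts, K=2048, N=1408.
m, k, n, num_experts, topk = 16, 2048, 1408, 64, 2
dtype = torch.half

a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

# Per-tensor fp8 quantization of the activations and of each expert's
# weights, exactly as bench_run() does above.
a_q, a_scale = ops.scaled_fp8_quant(a)
w1_q = torch.empty((num_experts, 2 * n, k), device="cuda",
                   dtype=torch.float8_e4m3fn)
w2_q = torch.empty((num_experts, k, n), device="cuda",
                   dtype=torch.float8_e4m3fn)
w1_scale = torch.empty((num_experts, 1, 1), device="cuda",
                       dtype=torch.float32)
w2_scale = torch.empty((num_experts, 1, 1), device="cuda",
                       dtype=torch.float32)
for expert in range(num_experts):
    w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
    w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])

# Routing, then the grouped-GEMM MoE path (weights transposed as in
# bench_run before being handed to cutlass_moe).
score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
out = cutlass_moe(a_q, a_scale, w1_q.transpose(1, 2), w2_q.transpose(1, 2),
                  w1_scale, w2_scale, topk_weights, topk_ids, m, n, k,
                  num_experts)
print(out.shape)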
16 changes: 16 additions & 0 deletions benchmarks/kernels/benchmark_shapes.py
@@ -73,3 +73,19 @@
        [7168, 8192],
    ],
}

WEIGHT_SHAPES_MOE = {
    "nm-testing/Mixtral-8x7B-Instruct-v0.1": [
        [8, 4096, 28672],
        [8, 14336, 4096],
    ],
    "nm-testing/deepseekv2-lite": [
        [64, 2048, 1408],
    ],
    "ibm-granite/granite-3.0-1b-a400m": [
        [32, 1024, 1024],
    ],
    "ibm-granite/granite-3.0-3b-a800m": [
        [40, 1024, 1536],
    ],
}
16 changes: 16 additions & 0 deletions csrc/cpu/torch_bindings.cpp
@@ -118,6 +118,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
  // CUTLASS w8a8 grouped GEMM  // TODO complete this
  // ops.def(
  //     "cutlass_grouped_mm(Tensor! out, Tensor a, Tensor b, Tensor a_scales,"
  //     "                   Tensor b_scales, Tensor problem_sizes,"
  //     "                   Tensor out_offsets, Tensor a_offsets,"
  //     "                   Tensor b_offsets) -> ()");
  // ops.impl("cutlass_grouped_mm", torch::kCUDA, &cutlass_grouped_mm);

  // ops.def(
  //     "compute_expert_offsets(Tensor! trg_a_ptrs,"
  //     "                       Tensor! a, Tensor topk_ids,"
  //     "                       Tensor! expert_offsets, SymInt num_experts) -> ()");
  // ops.impl("compute_expert_offsets", torch::kCUDA,
  //          &compute_expert_offsets);

// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
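For orientation, below is a hedged sketch of how the grouped-MM binding above might be called from Python once the commented-out registration is completed. The torch.ops._C namespace and the exact tensor layouts are assumptions inferred from the schema string in the comment; the kernel itself is implemented for CUDA in grouped_mm_c3x.cu, not in this CPU bindings file, and the wrapper name below is hypothetical.

import torch

def cutlass_grouped_mm_sketch(out, a, b, a_scales, b_scales,
                              problem_sizes, out_offsets, a_offsets,
                              b_offsets):
    # Argument order mirrors the commented-out schema above; all tensors are
    # assumed to be CUDA tensors prepared by the caller (e.g. cutlass_moe).
    torch.ops._C.cutlass_grouped_mm(out, a, b, a_scales, b_scales,
                                    problem_sizes, out_offsets, a_offsets,
                                    b_offsets)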