[V1] LoRA - Add triton kernels for V1 #13096

Draft: wants to merge 2 commits into base: main
158 changes: 140 additions & 18 deletions benchmarks/kernels/benchmark_lora.py
@@ -23,6 +23,9 @@
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.lora.v1.ops.triton_ops.v1_expand import v1_expand
from vllm.lora.v1.ops.triton_ops.v1_shrink import v1_shrink
from vllm.lora.v1.punica_wrapper.punica_gpu_v1 import V1KernelMeta
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@@ -172,6 +175,8 @@ class OpType(Enum):
SGMV_EXPAND = auto()
BGMV_EXPAND = auto()
BGMV_EXPAND_SLICE = auto()
V1_SHRINK = auto()
V1_EXPAND = auto()

@staticmethod
def from_str(s: str) -> "OpType":
@@ -185,28 +190,43 @@ def from_str(s: str) -> "OpType":
return OpType.BGMV_EXPAND
if s.lower() == "bgmv_expand_slice":
return OpType.BGMV_EXPAND_SLICE
if s.lower() == "v1_shrink":
return OpType.V1_SHRINK
if s.lower() == "v1_expand":
return OpType.V1_EXPAND
raise ValueError(f"Unrecognized str {s} to convert to OpType")

def is_shrink_fn(self) -> bool:
return self in [OpType.SGMV_SHRINK, OpType.BGMV_SHRINK]
return self in [
OpType.SGMV_SHRINK, OpType.BGMV_SHRINK, OpType.V1_SHRINK
]

def is_expand_fn(self) -> bool:
return self in [OpType.SGMV_EXPAND, OpType.BGMV_EXPAND]
return self in [
OpType.SGMV_EXPAND, OpType.BGMV_EXPAND, OpType.V1_EXPAND
]

def is_prefill_op(self) -> bool:
return self in [OpType.SGMV_SHRINK, OpType.SGMV_EXPAND]
return self in [
OpType.SGMV_SHRINK, OpType.SGMV_EXPAND, OpType.V1_SHRINK,
OpType.V1_EXPAND
]

def is_decode_op(self) -> bool:
return self in [
OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE
OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE,
OpType.V1_SHRINK, OpType.V1_EXPAND
]

def is_expand_slice_fn(self) -> bool:
return self in [OpType.BGMV_EXPAND_SLICE]

def num_slices(self) -> List[int]:
if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]:
# SGMV kernels supports slices
if self in [
OpType.SGMV_EXPAND, OpType.SGMV_SHRINK, OpType.V1_SHRINK,
OpType.V1_EXPAND
]:
# SGMV kernels and V1 kernels support slices
return [1, 2, 3]
if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]:
return [1]
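
For orientation, a small hedged example (not part of the diff) of what the predicates above imply for the new op types: unlike the SGMV kernels (prefill only) and BGMV kernels (decode only), the V1 kernels are exercised for both prefill and decode, and, like SGMV, they fuse multiple slices into one launch.

# Illustrative sketch, assuming OpType from benchmarks/kernels/benchmark_lora.py.
op = OpType.from_str("v1_shrink")                  # OpType.V1_SHRINK
assert op.is_prefill_op() and op.is_decode_op()    # V1 kernels cover both phases
assert op.num_slices() == [1, 2, 3]                # fused multi-slice support, like SGMV
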
@@ -251,11 +271,13 @@ def matmul_shapes(
m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)

b_shape = (num_loras, n, k) # col-major
if self == OpType.SGMV_SHRINK:
# SGMV shrink supports num_slices inherently in the kernel
if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
# SGMV shrink and V1 shrink kernels support num_slices inherently
# in the kernel.
return ((m, k), b_shape, (num_slices, m, n))
if self == OpType.SGMV_EXPAND:
# SGMV expand supports num_slices inherently in the kernel
if self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
# SGMV expand and V1 expand kernels support num_slices inherently
# in the kernel
return ((num_slices, m, k), b_shape, (m, n * num_slices))
if self == OpType.BGMV_SHRINK:
return ((m, k), b_shape, (m, n))
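
To make the shape conventions above concrete, here is a hedged walk-through with hypothetical sizes (the variable names mirror matmul_shapes; none of the numbers come from the PR):

# Hypothetical sizes, for illustration only.
num_loras, num_slices = 4, 3
num_tokens, hidden_size, lora_rank = 128, 4096, 16

# V1_SHRINK: (m, k, n) = (num_tokens, hidden_size, lora_rank)
shrink_input = (num_tokens, hidden_size)                # (128, 4096)
shrink_weights = (num_loras, lora_rank, hidden_size)    # (4, 16, 4096), col-major
shrink_output = (num_slices, num_tokens, lora_rank)     # (3, 128, 16)

# V1_EXPAND: (m, k, n) = (num_tokens, lora_rank, hidden_size)
expand_input = (num_slices, num_tokens, lora_rank)      # (3, 128, 16)
expand_weights = (num_loras, hidden_size, lora_rank)    # (4, 4096, 16), col-major
expand_output = (num_tokens, hidden_size * num_slices)  # (128, 12288)
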
@@ -282,25 +304,30 @@ def emulate_bgmv_expand_slice(kwargs_list: List[Dict[str, Any]]):
return bgmv_expand
if self == OpType.BGMV_EXPAND_SLICE:
return emulate_bgmv_expand_slice
if self == OpType.V1_SHRINK:
return v1_shrink
if self == OpType.V1_EXPAND:
return v1_expand

raise ValueError(f"Unrecognized optype {self}")

def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
lora_weights: List[torch.Tensor],
**kwargs) -> Callable:
"""Each benchmark operation expected the input, lora_weights and outputs
"""Each benchmark operation expects the input, lora_weights and outputs
in a slightly different format. Refer to self.matmul_shapes().
run_ref_group_gemm accounts for those differences in executing a
reference group gemm for correctness testing.
"""
w_dtype = lora_weights[0].dtype
num_slices = len(lora_weights)
if self == OpType.SGMV_SHRINK:
if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
for slice_idx in range(num_slices):
ref_group_gemm(ref_out=output[slice_idx, :],
input=input,
lora_weights=lora_weights[slice_idx],
**kwargs)
if self == OpType.SGMV_EXPAND:
elif self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
hidden_size = lora_weights[0].shape[1]
for slice_idx in range(num_slices):
slice_offset = slice_idx * hidden_size
@@ -309,19 +336,19 @@ def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
input=input[slice_idx].clone().to(dtype=w_dtype),
lora_weights=lora_weights[slice_idx],
**kwargs)
if self == OpType.BGMV_SHRINK:
elif self == OpType.BGMV_SHRINK:
assert num_slices == 1
ref_group_gemm(ref_out=output,
input=input,
lora_weights=lora_weights[0],
**kwargs)
if self == OpType.BGMV_EXPAND:
elif self == OpType.BGMV_EXPAND:
assert num_slices == 1
ref_group_gemm(ref_out=output,
input=input.clone().to(dtype=w_dtype),
lora_weights=lora_weights[0],
**kwargs)
if self == OpType.BGMV_EXPAND_SLICE:
elif self == OpType.BGMV_EXPAND_SLICE:
hidden_size = lora_weights[0].shape[1]
for slice_idx in range(num_slices):
slice_offset = slice_idx * hidden_size
@@ -330,7 +357,8 @@ def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
input=input[slice_idx].clone().to(dtype=w_dtype),
lora_weights=lora_weights[slice_idx],
**kwargs)
raise ValueError(f"Unrecognized optype {self}")
else:
raise ValueError(f"Unrecognized optype {self}")


@dataclass
@@ -391,6 +419,8 @@ class BenchmarkTensors:
seq_start_loc: torch.Tensor
prompt_lora_mapping: torch.Tensor
token_lora_mapping: torch.Tensor
# v1 kernel metadata
v1_kernel_meta: Optional[V1KernelMeta] = None

def io_types(self) -> str:
return (f"{dtype_to_str(self.input.dtype)}x"
@@ -433,10 +463,19 @@ def make(ctx: BenchmarkContext,
total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
seq_len_tensor, "cpu")

v1_kernel_meta = None
if op_type in [OpType.V1_SHRINK, OpType.V1_EXPAND]:
v1_kernel_meta = V1KernelMeta.make(
max_loras=ctx.num_loras,
max_num_tokens=token_lora_indices_tensor.size(0),
device="cpu")
v1_kernel_meta.prepare_tensors(
token_lora_mapping=token_lora_indices_tensor)

return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
seq_len_tensor, seq_start_loc_tensor,
prompt_lora_indices_tensor,
token_lora_indices_tensor)
token_lora_indices_tensor, v1_kernel_meta)
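
A minimal sketch of the V1KernelMeta flow used in make() above, run on CPU with a hypothetical token-to-LoRA mapping; the field names are the same ones consumed later by as_v1_shrink_kwargs / as_v1_expand_kwargs:

import torch

from vllm.lora.v1.punica_wrapper.punica_gpu_v1 import V1KernelMeta

# Hypothetical mapping: tokens 0-2 use LoRA 0, tokens 3-5 use LoRA 2.
token_lora_mapping = torch.tensor([0, 0, 0, 2, 2, 2])

meta = V1KernelMeta.make(max_loras=4,
                         max_num_tokens=token_lora_mapping.size(0),
                         device="cpu")
meta.prepare_tensors(token_lora_mapping=token_lora_mapping)

# Fields prepared for the kernels: meta.token_lora_mapping,
# meta.token_indices_sorted_by_lora_ids, meta.num_tokens_per_lora,
# meta.lora_token_start_loc and meta.active_lora_ids.
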

def sanity_check(self) -> None:
"""
@@ -469,6 +508,13 @@ def to_device(tensor: torch.Tensor):
for i in range(len(self.lora_weights_lst)):
self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])

# v1 meta
if self.v1_kernel_meta:
for field_name in V1KernelMeta.__dataclass_fields__:
field = getattr(self.v1_kernel_meta, field_name)
assert isinstance(field, torch.Tensor)
setattr(self.v1_kernel_meta, field_name, to_device(field))

def metadata(self) -> Tuple[int, int, int]:
"""
Return num_seqs, num_tokens and max_seq_len
@@ -668,6 +714,78 @@ def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
})
return {'kwargs_list': kwargs_list}

def as_v1_shrink_kwargs(self) -> Dict[str, Any]:
assert self.v1_kernel_meta is not None
self.sanity_check()
self.to_device(self.input.device)

_, num_tokens, _, num_slices = self.metadata()

# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
# Expected input shape [num_tokens, hidden_size]
assert len(i_shape) == 2
assert i_shape[0] == num_tokens
hidden_size = i_shape[1]
# Expected lora weight shape [num_loras, lora_rank, hidden_size]
assert len(lw_shape) == 3
assert lw_shape[2] == hidden_size
lora_rank = lw_shape[1]
# Expected output shape [num_slices, num_tokens, lora_rank]
assert len(o_shape) == 3
assert o_shape == (num_slices, num_tokens, lora_rank)

return {
'inputs': self.input,
'lora_a_weights': self.lora_weights_lst,
'output_tensor': self.output,
'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
'token_indices_sorted_by_lora_ids':
self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
'lora_ids': self.v1_kernel_meta.active_lora_ids,
'scaling': 1.0,
}

def as_v1_expand_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
assert self.v1_kernel_meta is not None
self.sanity_check()
self.to_device(self.input.device)

_, num_tokens, _, num_slices = self.metadata()

# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
# Expected input shape : [num_slices, num_tokens, lora_rank]
assert len(i_shape) == 3
assert i_shape[0] == num_slices
assert i_shape[1] == num_tokens
lora_rank = i_shape[2]
# Expected lora weight shape : [num_lora, hidden_size, lora_rank]
assert len(lw_shape) == 3
assert lw_shape[2] == lora_rank
hidden_size = lw_shape[1]
# Expected output shape : [num_tokens, hidden_size * num_slices]
assert len(o_shape) == 2
assert o_shape == (num_tokens, hidden_size * num_slices)

return {
'inputs': self.input,
'lora_b_weights': self.lora_weights_lst,
'output_tensor': self.output,
'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
'token_indices_sorted_by_lora_ids':
self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
'lora_ids': self.v1_kernel_meta.active_lora_ids,
'offset_start': 0,
'add_inputs': add_inputs,
}

def bench_fn_kwargs(self,
op_type: OpType,
add_inputs: Optional[bool] = None) -> Dict[str, Any]:
@@ -686,6 +804,10 @@ def bench_fn_kwargs(self,
return self.as_bgmv_expand_kwargs(add_inputs)
if op_type == OpType.BGMV_EXPAND_SLICE:
return self.as_bgmv_expand_slice_kwargs(add_inputs)
if op_type == OpType.V1_SHRINK:
return self.as_v1_shrink_kwargs()
if op_type == OpType.V1_EXPAND:
return self.as_v1_expand_kwargs(add_inputs)
raise ValueError(f"Unrecognized optype {self}")

def test_correctness(self, op_type: OpType,
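
Putting the pieces together, a hedged end-to-end sketch (not part of the PR) of how the benchmark is expected to drive the new kernels; bt_shrink and bt_expand stand for hypothetical BenchmarkTensors instances built via BenchmarkTensors.make(...) for the respective op types, and the kwargs keys are assumed to match the v1_shrink / v1_expand kernel signatures:

# Illustrative only; assumes bt_shrink / bt_expand were created with
# BenchmarkTensors.make(...) for OpType.V1_SHRINK / OpType.V1_EXPAND.
shrink_kwargs = bt_shrink.bench_fn_kwargs(OpType.V1_SHRINK)
v1_shrink(**shrink_kwargs)      # output: (num_slices, num_tokens, lora_rank)

expand_kwargs = bt_expand.bench_fn_kwargs(OpType.V1_EXPAND, add_inputs=True)
v1_expand(**expand_kwargs)      # output: (num_tokens, hidden_size * num_slices)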