[Feat] Add window attention for gemma-2 (#1056)
Ying1123 authored Aug 14, 2024
1 parent ad3e4f1 commit 0909bb0
Showing 11 changed files with 319 additions and 126 deletions.
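For context, Gemma-2 alternates full-attention layers with sliding-window layers in which a query may only attend to the most recent keys, which is why this commit threads a per-layer sliding_window_size through the attention backends. The snippet below is a minimal, self-contained illustration of that masking rule in plain PyTorch; the helper name and the window/sequence sizes are made up for the example, and the real kernels in this commit come from FlashInfer via its window_left argument, whose exact off-by-one convention may differ from this sketch.

import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # True = attention allowed: causal attention restricted to the
    # `window` most recent positions (including the query itself).
    idx = torch.arange(seq_len)
    q = idx[:, None]
    k = idx[None, :]
    return (k <= q) & (k > q - window)

# Example: with window=4, token 10 may attend to tokens 7, 8, 9, 10 only.
mask = sliding_window_causal_mask(seq_len=16, window=4)
assert mask[10].nonzero().flatten().tolist() == [7, 8, 9, 10]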
2 changes: 1 addition & 1 deletion python/sglang/bench_latency.py
@@ -64,7 +64,7 @@ class BenchArgs:
run_name: str = "before"
batch_size: Tuple[int] = (1,)
input_len: Tuple[int] = (1024,)
output_len: Tuple[int] = (4,)
output_len: Tuple[int] = (16,)
result_filename: str = ""
correctness_test: bool = False
# This is only used for correctness test
59 changes: 40 additions & 19 deletions python/sglang/srt/layers/radix_attention.py
@@ -34,6 +34,7 @@ def __init__(
scaling: float,
num_kv_heads: int,
layer_id: int,
sliding_window_size: int = -1,
logit_cap: int = -1,
v_head_dim: int = -1,
):
@@ -46,6 +47,7 @@ def __init__(
self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
self.scaling = scaling
self.layer_id = layer_id
self.sliding_window_size = sliding_window_size

if (
not global_server_args_dict.get("disable_flashinfer", False)
@@ -113,39 +115,51 @@ def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
return o

def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
# Using two wrappers is unnecessary in the current PR, but they are prepared for future PRs
prefill_wrapper_ragged = input_metadata.flashinfer_prefill_wrapper_ragged
prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
if self.sliding_window_size != -1:
prefill_wrapper_ragged = prefill_wrapper_ragged[0]
prefill_wrapper_paged = prefill_wrapper_paged[0]
else:
if isinstance(prefill_wrapper_ragged, list):
prefill_wrapper_ragged = prefill_wrapper_ragged[1]
if isinstance(prefill_wrapper_paged, list):
prefill_wrapper_paged = prefill_wrapper_paged[1]

if not input_metadata.flashinfer_use_ragged:
self.store_kv_cache(k, v, input_metadata)

o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
o = prefill_wrapper_paged.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
causal=True,
sm_scale=self.scaling,
window_left=self.sliding_window_size,
logits_soft_cap=self.logit_cap,
)
else:
o1, s1 = (
input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
causal=True,
sm_scale=self.scaling,
logits_soft_cap=self.logit_cap,
)
o1, s1 = prefill_wrapper_ragged.forward_return_lse(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
causal=True,
sm_scale=self.scaling,
window_left=self.sliding_window_size,
logits_soft_cap=self.logit_cap,
)

if input_metadata.extend_no_prefix:
o = o1
else:
o2, s2 = (
input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
causal=False,
sm_scale=self.scaling,
logits_soft_cap=self.logit_cap,
)
# TODO: window attention + radix attention will come in the next PR
assert self.sliding_window_size == -1
o2, s2 = prefill_wrapper_paged.forward_return_lse(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
causal=False,
sm_scale=self.scaling,
logits_soft_cap=self.logit_cap,
)

o, _ = merge_state(o1, s1, o2, s2)
@@ -158,9 +172,16 @@ def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
return o.view(-1, self.tp_q_head_num * self.head_dim)
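When flashinfer_use_ragged is set, the extend pass above produces two partial results: o1/s1 from the ragged wrapper over the newly extended tokens (causal) and o2/s2 from the paged wrapper over the cached prefix (non-causal), which merge_state then combines using the per-head log-sum-exp values. A reference version of that combination is sketched below, assuming natural-log LSE values and [tokens, heads, head_dim] / [tokens, heads] layouts; FlashInfer's merge_state kernel is the authoritative implementation.

import torch

def merge_state_reference(o1, s1, o2, s2):
    # Re-weight each partial attention output by the softmax mass it
    # covered (exp of its log-sum-exp), then renormalize.
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max)  # [tokens, heads]
    w2 = torch.exp(s2 - s_max)
    o = (o1 * w1[..., None] + o2 * w2[..., None]) / (w1 + w2)[..., None]
    s = s_max + torch.log(w1 + w2)
    return o, s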

def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
decode_wrapper = input_metadata.flashinfer_decode_wrapper
if self.sliding_window_size != -1:
decode_wrapper = decode_wrapper[0]
else:
if isinstance(decode_wrapper, list):
decode_wrapper = decode_wrapper[1]

self.store_kv_cache(k, v, input_metadata)

o = input_metadata.flashinfer_decode_wrapper.forward(
o = decode_wrapper.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
sm_scale=self.scaling,
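The dispatch pattern used throughout this file is that wrapper index 0 serves sliding-window layers and index 1 serves full-attention layers, with a fallback to the plain (non-list) wrapper when the model has no window layers. A standalone sketch of that selection follows; select_wrapper and wrappers are illustrative names for this example, not sglang API.

from typing import Sequence, Union

def select_wrapper(wrappers: Union[object, Sequence[object]], sliding_window_size: int):
    # Index 0: sliding-window layers; index 1: full-attention layers.
    # A plain (non-list) wrapper is used as-is when no window layers exist.
    if isinstance(wrappers, (list, tuple)):
        return wrappers[0] if sliding_window_size != -1 else wrappers[1]
    return wrappers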
203 changes: 145 additions & 58 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -16,7 +16,7 @@
"""ModelRunner runs the forward passes of the models."""
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Optional

import numpy as np
import torch
@@ -154,6 +154,7 @@ def from_schedule_batch(
model_runner: "ModelRunner",
batch: ScheduleBatch,
forward_mode: ForwardMode,
sliding_window_size: Optional[int] = None,
):
ret = cls(
forward_mode=forward_mode,
@@ -197,7 +198,7 @@
):
flashinfer_use_ragged = True
ret.init_flashinfer_handlers(
model_runner, prefix_lens, flashinfer_use_ragged
model_runner, prefix_lens, flashinfer_use_ragged, sliding_window_size
)

return ret
@@ -216,7 +217,11 @@ def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
self.triton_max_extend_len = int(torch.max(extend_seq_lens))

def init_flashinfer_handlers(
self, model_runner, prefix_lens, flashinfer_use_ragged
self,
model_runner,
prefix_lens,
flashinfer_use_ragged,
sliding_window_size=None,
):
update_flashinfer_indices(
self.forward_mode,
@@ -225,6 +230,7 @@ def init_flashinfer_handlers(
self.seq_lens,
prefix_lens,
flashinfer_use_ragged=flashinfer_use_ragged,
sliding_window_size=sliding_window_size,
)

(
Expand All @@ -248,72 +254,153 @@ def update_flashinfer_indices(
prefix_lens,
flashinfer_decode_wrapper=None,
flashinfer_use_ragged=False,
sliding_window_size=None,
):
"""Init auxiliary variables for FlashInfer attention backend."""
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
head_dim = model_runner.model_config.head_dim
batch_size = len(req_pool_indices)

    if sliding_window_size is None:
        if flashinfer_use_ragged:
            paged_kernel_lens = prefix_lens
        else:
            paged_kernel_lens = seq_lens

        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
        kv_indices = torch.cat(
            [
                model_runner.req_to_token_pool.req_to_token[
                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
                ]
                for i in range(batch_size)
            ],
            dim=0,
        ).contiguous()
        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

        if forward_mode == ForwardMode.DECODE:
            # CUDA graph uses different flashinfer_decode_wrapper
            if flashinfer_decode_wrapper is None:
                flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper

            flashinfer_decode_wrapper.end_forward()
            flashinfer_decode_wrapper.begin_forward(
                kv_indptr,
                kv_indices,
                kv_last_page_len,
                num_qo_heads,
                num_kv_heads,
                head_dim,
                1,
            )
        else:
            # extend part
            qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
            qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

            if flashinfer_use_ragged:
                model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
                model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
                    qo_indptr,
                    qo_indptr,
                    num_qo_heads,
                    num_kv_heads,
                    head_dim,
                )

            # cached part
            model_runner.flashinfer_prefill_wrapper_paged.end_forward()
            model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
                qo_indptr,
                kv_indptr,
                kv_indices,
                kv_last_page_len,
                num_qo_heads,
                num_kv_heads,
                head_dim,
                1,
            )
    else:
        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
        for wrapper_id in range(2):
            if flashinfer_use_ragged:
                paged_kernel_lens = prefix_lens
            else:
                paged_kernel_lens = seq_lens

if wrapper_id == 0 and forward_mode == ForwardMode.DECODE:
paged_kernel_lens = torch.minimum(
paged_kernel_lens, torch.tensor(sliding_window_size)
)
kv_start_idx = seq_lens - paged_kernel_lens
else:
kv_start_idx = torch.zeros(batch_size, dtype=torch.int32, device="cuda")

kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
kv_indices = torch.cat(
[
model_runner.req_to_token_pool.req_to_token[
req_pool_indices_cpu[i],
kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
]
for i in range(batch_size)
],
dim=0,
).contiguous()

if forward_mode == ForwardMode.DECODE:
# CUDA graph uses different flashinfer_decode_wrapper
if flashinfer_decode_wrapper is None:
flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper

flashinfer_decode_wrapper[wrapper_id].end_forward()
flashinfer_decode_wrapper[wrapper_id].begin_forward(
kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
1,
)
else:
# extend part
qo_indptr = torch.zeros(
(batch_size + 1,), dtype=torch.int32, device="cuda"
)
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

if flashinfer_use_ragged:
model_runner.flashinfer_prefill_wrapper_ragged[
wrapper_id
].end_forward()
model_runner.flashinfer_prefill_wrapper_ragged[
wrapper_id
].begin_forward(
qo_indptr,
qo_indptr,
num_qo_heads,
num_kv_heads,
head_dim,
)

# cached part
model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
1,
)
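The sliding-window bookkeeping above hinges on the wrapper_id == 0 decode branch: the number of paged KV entries per request is clamped to the window, and kv_start_idx shifts the req_to_token slice so that only the most recent sliding_window_size tokens are gathered. A small standalone check of that arithmetic, with made-up sequence lengths:

import torch

sliding_window_size = 4
seq_lens = torch.tensor([3, 9, 17])  # hypothetical per-request lengths

# Same arithmetic as the wrapper_id == 0 decode branch above.
paged_kernel_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size))
kv_start_idx = seq_lens - paged_kernel_lens

print(paged_kernel_lens.tolist())  # [3, 4, 4]  -> at most `window` cached tokens attended
print(kv_start_idx.tolist())       # [0, 5, 13] -> tokens before these offsets fall outside the window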