Update Triton extend backend interface (#3309)
ispobock authored Feb 5, 2025
1 parent 7aad8d1 commit de55333
Showing 5 changed files with 427 additions and 69 deletions.
@@ -17,12 +17,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
+            extend_attention_fwd,
             flash_decode_attention_fwd,
             flash_decode_sparse_attention_fwd,
         )
-        from sglang.srt.layers.attention.triton_ops.extend_attention import (
-            extend_attention_fwd,
-        )
 
         super().__init__()
 
68 changes: 52 additions & 16 deletions python/sglang/srt/layers/attention/triton_backend.py
@@ -37,6 +37,9 @@ def __init__(self, model_runner: ModelRunner):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.qo_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )
 
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
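
Note (illustration, not part of the diff): kv_indptr and qo_indptr are CSR-style offset buffers, allocated once with max_bs + 1 entries and sliced to bs + 1 per batch so the hot path does no allocation. A minimal sketch of the convention, with invented lengths:

import torch

# Entry 0 stays 0; entry i+1 holds the cumulative length through request i,
# so request i owns slots indptr[i]:indptr[i+1] of the flat index array.
max_bs = 8
indptr = torch.zeros(max_bs + 1, dtype=torch.int32)  # preallocated once

seq_lens = torch.tensor([3, 1, 2])                   # this batch, bs = 3
bs = seq_lens.numel()
indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
print(indptr[: bs + 1])  # tensor([0, 3, 4, 6], dtype=torch.int32)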
@@ -54,6 +57,9 @@ def __init__(self, model_runner: ModelRunner):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
 
+        bs = forward_batch.batch_size
+        kv_indptr = self.kv_indptr
+
         if forward_batch.forward_mode.is_decode():
             attn_logits = torch.empty(
                 (
@@ -68,31 +74,59 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
 
             max_extend_len = None
 
-            kv_indptr = self.kv_indptr
-            bs = len(forward_batch.req_pool_indices)
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
+                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
-                forward_batch.req_to_token_pool.req_to_token,
+                self.req_to_token,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 kv_indptr,
                 None,
                 kv_indices,
-                forward_batch.req_to_token_pool.req_to_token.stride(0),
+                self.req_to_token.stride(0),
             )
+
+            qo_indptr = None
+            custom_mask = None
         else:
+            kv_indptr[1 : bs + 1] = torch.cumsum(
+                forward_batch.extend_prefix_lens, dim=0
+            )
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                forward_batch.extend_prefix_lens.sum().item(),
+                dtype=torch.int32,
+                device=self.device,
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.extend_prefix_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            qo_indptr = self.qo_indptr
+            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
-            kv_indptr = None
-            kv_indices = None
-
-        self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
+        self.forward_metadata = (
+            attn_logits,
+            max_extend_len,
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            custom_mask,
+        )
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
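
Note (illustration, not part of the diff): the new extend branch gathers KV indices only over each request's cached prefix (extend_prefix_lens); the newly fed tokens are delimited separately by qo_indptr. The sketch below emulates in plain PyTorch what create_flashinfer_kv_indices_triton writes; the real kernel is Triton, and all table values here are invented:

import torch

# Flatten each request's prefix token slots from the paged req_to_token
# table into one contiguous kv_indices array, CSR-delimited by kv_indptr.
req_to_token = torch.arange(32, dtype=torch.int32).reshape(4, 8)  # [max_reqs, max_context]
req_pool_indices = torch.tensor([2, 0])   # table rows owned by this batch
prefix_lens = torch.tensor([5, 3])        # forward_batch.extend_prefix_lens

kv_indptr = torch.zeros(len(prefix_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)            # [0, 5, 8]
kv_indices = torch.empty(int(prefix_lens.sum()), dtype=torch.int32)
for i in range(len(prefix_lens)):
    row, plen = req_pool_indices[i], prefix_lens[i]
    kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = req_to_token[row, :plen]
# kv_indices -> [16..20] from table row 2, then [0..2] from table row 0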
@@ -144,6 +178,8 @@ def init_forward_metadata_capture_cuda_graph(
             None,
             kv_indptr,
             kv_indices,
+            None,
+            None,
         )
 
     def init_forward_metadata_replay_cuda_graph(
@@ -197,19 +233,19 @@ def forward_extend(
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        _, max_extend_len, _, _ = self.forward_metadata
+        _, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = (
+            self.forward_metadata
+        )
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
             v.contiguous(),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
-            forward_batch.req_to_token_pool.req_to_token,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
-            forward_batch.extend_seq_lens,
-            forward_batch.extend_start_loc,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
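
Note (illustration, not part of the diff): this call shows the interface change named in the commit title: extend_attention_fwd now takes the ragged qo_indptr / kv_indptr / kv_indices triple in place of req_to_token, req_pool_indices, seq_lens, extend_seq_lens, and extend_start_loc. Queries remain packed along dim 0, and qo_indptr alone recovers each request's rows; a sketch with invented shapes:

import torch

# All requests' new tokens are packed along dim 0 of q; request i's rows
# are q[qo_indptr[i]:qo_indptr[i+1]], so no separate start-location tensor
# is needed.
num_heads, head_dim = 8, 128
extend_seq_lens = torch.tensor([3, 1, 4])
qo_indptr = torch.zeros(len(extend_seq_lens) + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_seq_lens, dim=0)   # [0, 3, 4, 8]

q = torch.randn(int(extend_seq_lens.sum()), num_heads, head_dim)
q_req1 = q[qo_indptr[1] : qo_indptr[2]]  # the single new token of request 1
assert q_req1.shape[0] == int(extend_seq_lens[1])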
@@ -235,7 +271,7 @@ def forward_decode(
         else:
             o = torch.empty_like(q)
 
-        attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
+        attn_logits, _, kv_indptr, kv_indices, _, _ = self.forward_metadata
 
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
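
Note (illustration, not part of the diff): decode unpacks only attn_logits, kv_indptr, and kv_indices from the widened 6-tuple. With exactly one new query token per request, a query-side indptr would be trivial, so qo_indptr and custom_mask stay None. A sketch with invented lengths:

import torch

# Decode query i attends to kv_indices[kv_indptr[i] : kv_indptr[i + 1]];
# the per-query KV span lengths are exactly seq_lens, one query per request.
seq_lens = torch.tensor([7, 12, 3])
kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)          # [0, 7, 19, 22]
per_query_kv_len = kv_indptr[1:] - kv_indptr[:-1]
assert torch.equal(per_query_kv_len, seq_lens.to(per_query_kv_len.dtype))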