wip fix tensor on wrong device
Signed-off-by: Lucas Wilkinson <[email protected]>
LucasWilkinson committed Feb 11, 2025
1 parent 5daa921 commit 7bffc5c
Showing 4 changed files with 57 additions and 45 deletions.
11 changes: 11 additions & 0 deletions csrc/cache_kernels.cu
@@ -691,6 +691,17 @@ void gather_cache(
"seq_starts must be int32");
}

TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(),
"src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
"src_cache and cu_seq_lens must be on the same device");
if (seq_starts.has_value()) {
TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
"src_cache and seq_starts must be on the same device");
}

int64_t block_table_stride = block_table.stride(0);
int64_t cache_block_stride = src_cache.stride(0);
int64_t cache_entry_stride = src_cache.stride(1);
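The new checks catch a host/device mismatch at the launch boundary instead of letting it surface as a CUDA fault later. A minimal repro sketch, assuming the Python wrapper updated later in this commit (vllm/_custom_ops.py); shapes and dtypes here are illustrative, not the exact MLA cache layout:

```python
import torch
from vllm import _custom_ops as ops

# Illustrative layout: [num_blocks, block_size, entry_size] cache, one 128-token sequence.
src_cache = torch.randn(16, 64, 576, device="cuda")
dst = torch.empty(128, 576, device="cuda")
block_table = torch.arange(16, dtype=torch.int32, device="cuda").unsqueeze(0)
cu_seq_lens = torch.tensor([0, 128], dtype=torch.int32)  # accidentally left on the CPU

# With the checks above, this now fails fast with
# "src_cache and cu_seq_lens must be on the same device".
ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size=1)
```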
2 changes: 1 addition & 1 deletion tests/kernels/test_cache.py
@@ -758,7 +758,7 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,

opcheck(
torch.ops._C_cache_ops.gather_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size),
(src_cache, dst, block_table, cu_seq_lens, batch_size, None),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)

9 changes: 6 additions & 3 deletions vllm/_custom_ops.py
@@ -1054,9 +1054,12 @@ def convert_fp8(output: torch.Tensor,
torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)


def gather_cache(src_cache: torch.Tensor, dst: torch.Tensor,
block_table: torch.Tensor, cu_seq_lens: torch.Tensor,
batch_size: int, seq_starts: Optional[torch.Tensor]) -> None:
def gather_cache(src_cache: torch.Tensor,
dst: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
batch_size: int,
seq_starts: Optional[torch.Tensor] = None) -> None:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)
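With `seq_starts` now defaulting to `None`, existing five-argument call sites are unchanged, while the chunked-prefill path can pass per-sequence start offsets so gathering begins partway into each cached sequence. A small usage sketch; the tensors and the 256-token offset are illustrative:

```python
import torch
from vllm import _custom_ops as ops

def gather_examples(src_cache, dst, block_table, cu_seq_lens, batch_size):
    # Existing call sites: seq_starts is simply omitted and defaults to None.
    ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)

    # Chunked prefill: start gathering 256 tokens into each sequence, with
    # cu_seq_lens describing how many tokens of that chunk each sequence contributes.
    seq_starts = torch.full((batch_size,), 256, dtype=torch.int32,
                            device=src_cache.device)
    ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size,
                     seq_starts=seq_starts)
```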

80 changes: 39 additions & 41 deletions vllm/attention/backends/triton_mla.py
@@ -276,8 +276,8 @@ class TritonMLAMetadata(MLACommonMetadata):

# For chunked prefill
chunk_prefill_workspace_size: Optional[int] = None
chunk_cu_seq_lens: Optional[List[torch.Tensor]] = None
chunk_seq_starts: Optional[List[torch.Tensor]] = None
chunk_cu_seq_lens: Optional[torch.Tensor] = None
chunk_seq_starts: Optional[torch.Tensor] = None
chunk_iter_toks: Optional[List[int]] = None
chunk_max_seq_lens: Optional[List[int]] = None
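The chunked-prefill bookkeeping moves from Python lists of per-chunk tensors to single batched tensors with one row per chunk. A standalone sketch of the intended field shapes; this dataclass is illustrative, not the real TritonMLAMetadata, and the shapes follow from the build() hunk below:

```python
import torch
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ChunkedPrefillFields:
    """Illustrative subset of the chunked-prefill metadata fields."""
    chunk_prefill_workspace_size: Optional[int] = None
    # Batched instead of per-chunk lists: row i describes chunk i across all prefills.
    chunk_cu_seq_lens: Optional[torch.Tensor] = None  # [num_chunks, num_prefills + 1]
    chunk_seq_starts: Optional[torch.Tensor] = None   # [num_chunks, num_prefills]
    chunk_iter_toks: Optional[List[int]] = None       # tokens gathered per chunk
    chunk_max_seq_lens: Optional[List[int]] = None    # longest per-sequence slice per chunk
```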

@@ -665,45 +665,48 @@ def build(self, seq_lens: List[int], query_lens: List[int],
self.multimodal_placeholder_maps.items()
}

chunk_prefill_workspace_size = \
self.runner.model_config.max_model_len * 4
chunk_cu_seq_lens = None
chunk_seq_starts = None
chunk_iter_toks = None
chunk_max_seq_lens = None

print("seq_start_loc_tensor", seq_start_loc_tensor.device)

chunked_prefill_enabled = self.input_builder.chunked_prefill_enabled
if chunked_prefill_enabled:
chunk_prefill_workspace_size = \
self.runner.model_config.max_model_len * 4
if chunked_prefill_enabled and self.num_prefills > 0:
page_size = self.runner.block_size
seq_chunk_size = chunk_prefill_workspace_size // self.num_prefills
# align seq_chunk_size to page_size by rounding down
seq_chunk_size = seq_chunk_size - (seq_chunk_size % page_size)
num_chunks = (max_prefill_seq_len + seq_chunk_size -
1) // seq_chunk_size

chunk_cu_seq_lens = []
chunk_seq_starts = []
chunk_iter_toks = []
chunk_max_seq_lens = []

for chunk in range(num_chunks):
chunk_starts = chunk * seq_chunk_size
chunk_starts = torch.tensor([chunk_starts] * self.num_prefills,
dtype=torch.int32,
device=device)
chunk_ends = seq_lens_tensor.clamp(max=(chunk + 1) *
seq_chunk_size)
_chunk_cu_seq_lens = (chunk_ends - chunk_starts).clamp(
min=0).cumsum(dim=0).to(torch.int32)
chunk_iter_toks.append(_chunk_cu_seq_lens.sum())
chunk_max_seq_lens.append(_chunk_cu_seq_lens.max().item())

zero = torch.zeros(1, dtype=torch.int32, device=device)
_chunk_cu_seq_lens = torch.cat([zero, _chunk_cu_seq_lens],
dim=0)

chunk_cu_seq_lens.append(_chunk_cu_seq_lens.contiguous())
chunk_seq_starts.append(chunk_starts.contiguous())
# if `seq_chunk_size = 256`, `num_chunks = 3`, and
# `num_prefills = 4`, create a tensor that looks like
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
chunk_seq_starts = \
torch.arange(num_chunks, device=device, dtype=torch.int32)\
.unsqueeze(1).expand(-1, self.num_prefills)\
* seq_chunk_size
chunk_ends = torch.min(seq_lens_tensor.unsqueeze(0),\
chunk_seq_starts + seq_chunk_size)
chunk_seq_lens = (chunk_ends - chunk_seq_starts).clamp(min=0)
_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(torch.int32)
zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\
.unsqueeze(0)
print(self.num_prefills, zero, _chunk_cu_seq_lens)
chunk_cu_seq_lens = torch.cat([zero, _chunk_cu_seq_lens], dim=1)
chunk_max_seq_lens = chunk_seq_lens.max(dim=1).values.tolist()
chunk_iter_toks = chunk_seq_lens.sum(dim=1).tolist()

print(chunk_seq_starts)
print(chunk_ends)
#print(_chunk_cu_seq_lens)
print(chunk_cu_seq_lens)
print(chunk_max_seq_lens)
print(chunk_iter_toks)

return TritonMLAMetadata(
num_prefills=self.num_prefills,
@@ -726,6 +729,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
use_cuda_graph=use_captured_graph,
num_kv_splits=4, # TODO(lucas) add heuristic
head_dim=self.runner.model_config.get_head_size(),
# Chunked prefill data
chunk_prefill_workspace_size=chunk_prefill_workspace_size,
chunk_cu_seq_lens=chunk_cu_seq_lens,
chunk_seq_starts=chunk_seq_starts,
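For reference, a self-contained sketch of the vectorized chunk bookkeeping computed in the hunk above, with concrete numbers so the resulting rows can be checked by eye. The sequence lengths, workspace size, and page size are illustrative; note that in this sketch the leading zero column is shaped [num_chunks, 1] so the dim=1 concatenation lines up:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Assumed inputs: 4 prefill sequences, a 1024-token workspace, 16-token pages.
seq_lens_tensor = torch.tensor([300, 700, 100, 520], dtype=torch.int32, device=device)
num_prefills = seq_lens_tensor.numel()
chunk_prefill_workspace_size, page_size = 1024, 16

seq_chunk_size = chunk_prefill_workspace_size // num_prefills
seq_chunk_size -= seq_chunk_size % page_size        # align down to the page size -> 256
max_prefill_seq_len = int(seq_lens_tensor.max())
num_chunks = (max_prefill_seq_len + seq_chunk_size - 1) // seq_chunk_size  # ceil -> 3

# One row per chunk: [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
chunk_seq_starts = (torch.arange(num_chunks, dtype=torch.int32, device=device)
                    .unsqueeze(1).expand(-1, num_prefills) * seq_chunk_size)
chunk_ends = torch.min(seq_lens_tensor.unsqueeze(0), chunk_seq_starts + seq_chunk_size)
chunk_seq_lens = (chunk_ends - chunk_seq_starts).clamp(min=0)

# Leading zero column, shaped [num_chunks, 1] so it concatenates along dim=1.
zero = torch.zeros(num_chunks, 1, dtype=torch.int32, device=device)
chunk_cu_seq_lens = torch.cat(
    [zero, chunk_seq_lens.cumsum(dim=1).to(torch.int32)], dim=1)
# [[0, 256, 512, 612, 868],
#  [0,  44, 300, 300, 556],
#  [0,   0, 188, 188, 196]]

chunk_max_seq_lens = chunk_seq_lens.max(dim=1).values.tolist()  # [256, 256, 188]
chunk_iter_toks = chunk_seq_lens.sum(dim=1).tolist()            # [868, 556, 196]
```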
@@ -794,17 +798,11 @@ def _forward_prefill(
)

output = None
for chunk_cu_seq_lens, chunk_seq_starts, toks, max_seq_len in \
zip(
attn_metadata.chunk_cu_seq_lens, \
attn_metadata.chunk_seq_starts, \
attn_metadata.chunk_iter_toks, \
attn_metadata.chunk_max_seq_lens):

print("cu_seq_lens", chunk_cu_seq_lens)
print("seq_starts", chunk_seq_starts)
print("toks", toks)
print("max_seq_len", max_seq_len)
for i in range(len(prefill_metadata.chunk_iter_toks)):
chunk_cu_seq_lens = prefill_metadata.chunk_cu_seq_lens[i]
chunk_seq_starts = prefill_metadata.chunk_seq_starts[i]
toks = prefill_metadata.chunk_iter_toks[i]
max_seq_len = prefill_metadata.chunk_max_seq_lens[i]

ops.gather_cache(
src_cache=kv_c_and_k_pe_cache,
Expand Down Expand Up @@ -859,8 +857,8 @@ def _forward_prefill(
output=output,
prefix_output=output,
prefix_lse=attn_softmax_lse,
new_output=attn_output,
new_lse=attn_softmax_lse,
suffix_output=attn_output,
suffix_lse=attn_softmax_lse,
)

output = output\
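To round out the picture, a hedged sketch of how a consumer iterates the batched chunk metadata in the spirit of the _forward_prefill hunk above: row i of chunk_cu_seq_lens and chunk_seq_starts drives one gather_cache call into the shared workspace, and each chunk's attention output and log-sum-exp are merged into the running prefix result. run_chunk_attention and merge_softmax_states are placeholder callables, not vLLM APIs; only ops.gather_cache matches the diff:

```python
import torch
from vllm import _custom_ops as ops

def chunked_prefill_sketch(kv_c_and_k_pe_cache, block_table, workspace, meta,
                           run_chunk_attention, merge_softmax_states):
    """meta carries the batched chunk_* fields built in the metadata hunk above."""
    output, lse = None, None
    for i in range(len(meta.chunk_iter_toks)):
        toks = meta.chunk_iter_toks[i]
        # Gather chunk i of every prefill sequence into the shared workspace buffer.
        ops.gather_cache(
            src_cache=kv_c_and_k_pe_cache,
            dst=workspace,
            block_table=block_table,
            cu_seq_lens=meta.chunk_cu_seq_lens[i],
            batch_size=meta.num_prefills,
            seq_starts=meta.chunk_seq_starts[i],
        )
        # Attend over the gathered tokens (placeholder), then fold this chunk's
        # output and log-sum-exp into the running result, as merge_attn_states does
        # with its prefix_*/suffix_* arguments.
        chunk_out, chunk_lse = run_chunk_attention(workspace[:toks],
                                                   meta.chunk_max_seq_lens[i])
        if output is None:
            output, lse = chunk_out, chunk_lse
        else:
            output, lse = merge_softmax_states(prefix_output=output, prefix_lse=lse,
                                               suffix_output=chunk_out,
                                               suffix_lse=chunk_lse)
    return output
```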
