Skip to content

Commit

Permalink
[kernel] Fix position ids in rope (#3173)
Browse files Browse the repository at this point in the history
  • Loading branch information
ByronHsu authored Jan 27, 2025
1 parent 52c03f1 commit 514f37c
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 3 deletions.
2 changes: 1 addition & 1 deletion sgl-kernel/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "sgl-kernel"
version = "0.0.2.post19"
version = "0.0.2.post20"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.9"
Expand Down
2 changes: 1 addition & 1 deletion sgl-kernel/src/sgl-kernel/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def apply_rope_with_cos_sin_cache_inplace(
raise ValueError("cos_sin_cache should be float32")

with query.device as device:
pos_ids = pos_ids.int()
positions = positions.int()
torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
q=query.view(query.shape[0], -1, head_size),
k=key.view(key.shape[0], -1, head_size),
Expand Down
4 changes: 4 additions & 0 deletions sgl-kernel/tests/test_rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,7 @@ def test_correctness(
query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
)
torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
pytest.main([__file__])
2 changes: 1 addition & 1 deletion sgl-kernel/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.2.post19"
__version__ = "0.0.2.post20"

0 comments on commit 514f37c

Please sign in to comment.