From 514f37c32b5b8335dc73860f038dab0e53cdb9fb Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Mon, 27 Jan 2025 01:09:51 -0800 Subject: [PATCH] [kernel] Fix position ids in rope (#3173) --- sgl-kernel/pyproject.toml | 2 +- sgl-kernel/src/sgl-kernel/ops/__init__.py | 2 +- sgl-kernel/tests/test_rotary_embedding.py | 4 ++++ sgl-kernel/version.py | 2 +- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index a85923a5a6f..8664fb09021 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.0.2.post19" +version = "0.0.2.post20" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.9" diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index 3543d7423d1..2fa1d957980 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -51,7 +51,7 @@ def apply_rope_with_cos_sin_cache_inplace( raise ValueError("cos_sin_cache should be float32") with query.device as device: - pos_ids = pos_ids.int() + positions = positions.int() torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache( q=query.view(query.shape[0], -1, head_size), k=key.view(key.shape[0], -1, head_size), diff --git a/sgl-kernel/tests/test_rotary_embedding.py b/sgl-kernel/tests/test_rotary_embedding.py index 901b692362d..b7a141404e6 100644 --- a/sgl-kernel/tests/test_rotary_embedding.py +++ b/sgl-kernel/tests/test_rotary_embedding.py @@ -196,3 +196,7 @@ def test_correctness( query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2 ) torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/version.py b/sgl-kernel/version.py index 3e080a673e1..45807e905cc 100644 --- a/sgl-kernel/version.py +++ b/sgl-kernel/version.py @@ -1 +1 @@ -__version__ = "0.0.2.post19" +__version__ = "0.0.2.post20"