From c145acb888dd6b7af2a9c3e989cee9123f61060d Mon Sep 17 00:00:00 2001
From: "Wen-Heng (Jack) Chung"
Date: Sat, 1 Feb 2025 10:10:49 -0600
Subject: [PATCH 1/4] Tune paged attention parameters for AMD GPU.

Changes:
- num_kv_splits
- BLOCK
- num_warps
---
 .../srt/layers/attention/triton_ops/decode_attention.py | 4 ++--
 python/sglang/srt/server_args.py                        | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
index 2b4871af98c..86bc15c1cb5 100644
--- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -180,7 +180,7 @@ def _decode_att_m_fwd(
     sm_scale,
     logit_cap,
 ):
-    BLOCK = 64
+    BLOCK = 8
     NUM_KV_SPLITS = num_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
@@ -193,7 +193,7 @@ def _decode_att_m_fwd(
     if kv_group_num == 1:
         num_warps = 4
     else:
-        num_warps = 2
+        num_warps = 1
 
     BLOCK_DMODEL = triton.next_power_of_2(Lk)
     BLOCK_DV = triton.next_power_of_2(Lv)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index f9340e47764..6176f53d937 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -154,7 +154,7 @@ class ServerArgs:
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
-    triton_attention_num_kv_splits: int = 8
+    triton_attention_num_kv_splits: int = 16
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
     enable_memory_saver: bool = False
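Values like BLOCK and num_warps are typically chosen by sweeping a micro-benchmark on the target GPU. Below is a minimal, self-contained sketch of such a sweep using triton.testing.do_bench; the row-sum kernel, shapes, and parameter grid are illustrative stand-ins, not the sglang decode kernel:

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def _row_sum_kernel(x_ptr, out_ptr, n_cols, BLOCK: tl.constexpr):
        # One program per row; reduce the row in BLOCK-wide chunks.
        row = tl.program_id(0)
        acc = 0.0
        for start in range(0, n_cols, BLOCK):
            offs = start + tl.arange(0, BLOCK)
            mask = offs < n_cols
            vals = tl.load(x_ptr + row * n_cols + offs, mask=mask, other=0.0)
            acc += tl.sum(vals, axis=0)
        tl.store(out_ptr + row, acc)


    def sweep_launch_params():
        # device="cuda" also targets AMD GPUs on a ROCm build of PyTorch.
        x = torch.randn(4096, 4096, device="cuda", dtype=torch.float32)
        out = torch.empty(4096, device="cuda", dtype=torch.float32)
        grid = (x.shape[0],)
        for BLOCK in (8, 16, 32, 64, 128):
            for num_warps in (1, 2, 4):
                ms = triton.testing.do_bench(
                    lambda: _row_sum_kernel[grid](
                        x, out, x.shape[1], BLOCK=BLOCK, num_warps=num_warps
                    )
                )
                print(f"BLOCK={BLOCK:3d} num_warps={num_warps} {ms:.3f} ms")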
From 9c5980e8e2ec1d2db28031853b802093ba51a643 Mon Sep 17 00:00:00 2001
From: "Wen-Heng (Jack) Chung"
Date: Sat, 1 Feb 2025 11:24:34 -0600
Subject: [PATCH 2/4] Additional tuning for grouped paged attention kernel.

Changed:
- waves_per_eu
---
 .../sglang/srt/layers/attention/triton_ops/decode_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
index 86bc15c1cb5..4ccf38b0588 100644
--- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -436,7 +436,7 @@ def _decode_grouped_att_m_fwd(
     if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
-        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
 
     _fwd_grouped_kernel_stage1[grid](
         q,

From eb34954056eb1ada21f3ebcf10c1a27c807dbf7f Mon Sep 17 00:00:00 2001
From: "Wen-Heng (Jack) Chung"
Date: Sat, 1 Feb 2025 11:30:04 -0600
Subject: [PATCH 3/4] Additional tuning for grouped paged attention kernel

Changed:
- num_stages
---
 .../sglang/srt/layers/attention/triton_ops/decode_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
index 4ccf38b0588..25818b41dfa 100644
--- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -467,7 +467,7 @@ def _decode_grouped_att_m_fwd(
         NUM_KV_SPLITS=NUM_KV_SPLITS,
         logit_cap=logit_cap,
         num_warps=4,
-        num_stages=2,
+        num_stages=1,
         Lk=Lk,
         Lv=Lv,
         **extra_kargs,
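Patches 2 and 3 only touch values that feed the is_hip_ branch: waves_per_eu, matrix_instr_nonkdim, and kpack are AMD-backend compiler options that the kernel collects in a dict and splats into the launch, which keeps the call site portable across backends. A self-contained sketch of that pattern, with a trivial element-scale kernel standing in for _fwd_grouped_kernel_stage1 and is_hip implemented via torch.version.hip:

    import torch
    import triton
    import triton.language as tl


    def is_hip() -> bool:
        # PyTorch built against ROCm reports a non-None HIP version.
        return torch.version.hip is not None


    @triton.jit
    def _scale_kernel(x_ptr, y_ptr, n, scale, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask) * scale, mask=mask)


    def scale(x: torch.Tensor, s: float) -> torch.Tensor:
        y = torch.empty_like(x)
        n = x.numel()
        BLOCK = 1024

        extra_kargs = {}
        num_stages = 2
        if is_hip():
            # AMD-only compiler knobs; the NVIDIA backend does not accept them.
            extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
            num_stages = 1

        grid = (triton.cdiv(n, BLOCK),)
        _scale_kernel[grid](
            x, y, n, s, BLOCK=BLOCK, num_warps=4, num_stages=num_stages, **extra_kargs
        )
        return y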
From 5c095e5df1a91b3037746555c4981cff14f0d8a4 Mon Sep 17 00:00:00 2001
From: "Wen-Heng (Jack) Chung"
Date: Sat, 1 Feb 2025 17:06:52 -0600
Subject: [PATCH 4/4] Segregate AMD-specific tuning with is_hip checks

Make the tuned values of these knobs apply only on the AMD ROCm
platform via is_hip checks.
---
 .../layers/attention/triton_ops/decode_attention.py | 13 ++++++++++---
 python/sglang/srt/server_args.py                    |  6 +++++-
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
index 25818b41dfa..512900bd301 100644
--- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -180,7 +180,10 @@ def _decode_att_m_fwd(
     sm_scale,
     logit_cap,
 ):
-    BLOCK = 8
+    BLOCK = 64
+    # [TODO] work around SGPR limit on MI3xx
+    if is_hip_:
+        BLOCK = 8
     NUM_KV_SPLITS = num_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
@@ -193,7 +196,9 @@ def _decode_att_m_fwd(
     if kv_group_num == 1:
         num_warps = 4
     else:
-        num_warps = 1
+        num_warps = 2
+        if is_hip_:
+            num_warps = 1
 
     BLOCK_DMODEL = triton.next_power_of_2(Lk)
     BLOCK_DV = triton.next_power_of_2(Lv)
@@ -433,10 +438,12 @@ def _decode_grouped_att_m_fwd(
     )
 
     extra_kargs = {}
+    num_stages = 2
     if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
+        num_stages = 1
 
     _fwd_grouped_kernel_stage1[grid](
         q,
@@ -467,7 +474,7 @@ def _decode_grouped_att_m_fwd(
         NUM_KV_SPLITS=NUM_KV_SPLITS,
         logit_cap=logit_cap,
         num_warps=4,
-        num_stages=1,
+        num_stages=num_stages,
         Lk=Lk,
         Lv=Lv,
         **extra_kargs,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 6176f53d937..8c5ad0b96ec 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -154,7 +154,7 @@ class ServerArgs:
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
-    triton_attention_num_kv_splits: int = 16
+    triton_attention_num_kv_splits: int = 8
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
     enable_memory_saver: bool = False
@@ -273,6 +273,10 @@ def __post_init__(self):
         ) and check_gguf_file(self.model_path):
             self.quantization = self.load_format = "gguf"
 
+        # AMD-specific Triton attention KV splits default number
+        if is_hip():
+            self.triton_attention_num_kv_splits = 16
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
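The server_args half of patch 4 restores the NVIDIA-oriented field default and applies the AMD value after construction. That reduces to a small dataclass pattern; a minimal sketch, where ServerArgsSketch and its single field are an illustrative stand-in for the real ServerArgs:

    from dataclasses import dataclass

    import torch


    def is_hip() -> bool:
        # PyTorch built against ROCm reports a non-None HIP version.
        return torch.version.hip is not None


    @dataclass
    class ServerArgsSketch:
        # Declared default targets NVIDIA GPUs.
        triton_attention_num_kv_splits: int = 8

        def __post_init__(self):
            # Mirror of the patch: override the default on the ROCm platform.
            if is_hip():
                self.triton_attention_num_kv_splits = 16

On a ROCm build of PyTorch, ServerArgsSketch().triton_attention_num_kv_splits evaluates to 16; on any other build it stays 8.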