diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
index 2b4871af98c..512900bd301 100644
--- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -181,6 +181,9 @@ def _decode_att_m_fwd(
     logit_cap,
 ):
     BLOCK = 64
+    # [TODO] work around SGPR limit on MI3xx
+    if is_hip_:
+        BLOCK = 8
     NUM_KV_SPLITS = num_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
@@ -194,6 +197,8 @@ def _decode_att_m_fwd(
         num_warps = 4
     else:
         num_warps = 2
+    if is_hip_:
+        num_warps = 1
 
     BLOCK_DMODEL = triton.next_power_of_2(Lk)
     BLOCK_DV = triton.next_power_of_2(Lv)
@@ -433,10 +438,12 @@ def _decode_grouped_att_m_fwd(
     )
 
     extra_kargs = {}
+    num_stages = 2
     if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
-        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
+        num_stages = 1
 
     _fwd_grouped_kernel_stage1[grid](
         q,
@@ -467,7 +474,7 @@ def _decode_grouped_att_m_fwd(
         NUM_KV_SPLITS=NUM_KV_SPLITS,
         logit_cap=logit_cap,
         num_warps=4,
-        num_stages=2,
+        num_stages=num_stages,
         Lk=Lk,
         Lv=Lv,
         **extra_kargs,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index f9340e47764..8c5ad0b96ec 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -273,6 +273,10 @@ def __post_init__(self):
         ) and check_gguf_file(self.model_path):
             self.quantization = self.load_format = "gguf"
 
+        # AMD-specific Triton attention KV splits default number
+        if is_hip():
+            self.triton_attention_num_kv_splits = 16
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args