Commit 5c095e5
Segregate AMD-specific tuning with is_hip checks
Make the tuning knobs applicable only on the AMD ROCm platform, via is_hip checks.
whchung committed Feb 1, 2025
1 parent eb34954 · commit 5c095e5
Showing 2 changed files with 15 additions and 4 deletions.
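For background, is_hip is the runtime check this commit gates everything behind. A minimal standalone sketch of such a check, assuming ROCm builds of PyTorch expose torch.version.hip (the real helper lives in sglang.srt.utils; the fallback below is illustrative):

    import torch

    def is_hip() -> bool:
        # ROCm builds of PyTorch set torch.version.hip to a version string;
        # CUDA builds leave it as None.
        return torch.version.hip is not None

decode_attention.py presumably caches this once at import time as a module-level boolean, is_hip_, which the hunks below test.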
python/sglang/srt/layers/attention/triton_ops/decode_attention.py (13 changes: 10 additions & 3 deletions)
@@ -180,7 +180,10 @@ def _decode_att_m_fwd(
     sm_scale,
     logit_cap,
 ):
-    BLOCK = 8
+    BLOCK = 64
+    # [TODO] work around SGPR limit on MI3xx
+    if is_hip_:
+        BLOCK = 8
     NUM_KV_SPLITS = num_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
@@ -193,7 +196,9 @@
     if kv_group_num == 1:
         num_warps = 4
     else:
-        num_warps = 1
+        num_warps = 2
+        if is_hip_:
+            num_warps = 1
 
     BLOCK_DMODEL = triton.next_power_of_2(Lk)
     BLOCK_DV = triton.next_power_of_2(Lv)
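Taken together, the two hunks above make the tuning section of _decode_att_m_fwd read roughly as follows after the change (reconstructed from the diff; surrounding code elided):

    BLOCK = 64
    # [TODO] work around SGPR limit on MI3xx
    if is_hip_:
        BLOCK = 8

    # ... (unchanged code elided) ...

    if kv_group_num == 1:
        num_warps = 4
    else:
        num_warps = 2
        if is_hip_:
            num_warps = 1

Non-ROCm platforms now get the larger tile (BLOCK = 64) and two warps in the grouped case, while ROCm keeps the previous, smaller values.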
@@ -433,10 +438,12 @@ def _decode_grouped_att_m_fwd(
     )
 
     extra_kargs = {}
+    num_stages = 2
     if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
+        num_stages = 1
 
     _fwd_grouped_kernel_stage1[grid](
         q,
@@ -467,7 +474,7 @@ def _decode_grouped_att_m_fwd(
         NUM_KV_SPLITS=NUM_KV_SPLITS,
         logit_cap=logit_cap,
         num_warps=4,
-        num_stages=1,
+        num_stages=num_stages,
         Lk=Lk,
         Lv=Lv,
         **extra_kargs,
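waves_per_eu, matrix_instr_nonkdim, and kpack are AMD-backend Triton compile options (see the two URLs in the diff's comments), so extra_kargs must stay empty on other backends. A hedged sketch of the launch pattern with a trivial, hypothetical kernel — only the knob handling mirrors the diff:

    import torch
    import triton
    import triton.language as tl

    is_hip_ = torch.version.hip is not None  # assumption: same check as the repo's is_hip()

    @triton.jit
    def _copy_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
        # Trivial elementwise copy; stands in for the real attention kernel.
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elements
        tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

    def copy(x: torch.Tensor) -> torch.Tensor:
        y = torch.empty_like(x)
        extra_kargs = {}
        num_stages = 2
        if is_hip_:
            # AMD-only compiler knobs; other Triton backends reject these keys.
            extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
            num_stages = 1
        BLOCK = 1024
        grid = (triton.cdiv(x.numel(), BLOCK),)
        _copy_kernel[grid](
            x, y, x.numel(),
            BLOCK=BLOCK,
            num_warps=4,
            num_stages=num_stages,
            **extra_kargs,
        )
        return y

Without the is_hip_ guard, the non-AMD backends would reject the AMD-only keys, which is presumably why extra_kargs was already built conditionally; this commit extends the same guard to num_stages.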
python/sglang/srt/server_args.py (6 changes: 5 additions & 1 deletion)
@@ -154,7 +154,7 @@ class ServerArgs:
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
-    triton_attention_num_kv_splits: int = 16
+    triton_attention_num_kv_splits: int = 8
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
     enable_memory_saver: bool = False
@@ -273,6 +273,10 @@ def __post_init__(self):
         ) and check_gguf_file(self.model_path):
             self.quantization = self.load_format = "gguf"
 
+        # AMD-specific Triton attention KV splits default number
+        if is_hip():
+            self.triton_attention_num_kv_splits = 16
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
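The net effect on defaults: non-ROCm users now get triton_attention_num_kv_splits = 8, while ROCm users keep 16 via the __post_init__ override. A minimal sketch of this platform-dependent-default pattern with a simplified, hypothetical dataclass (the field name is taken from the diff):

    from dataclasses import dataclass

    import torch

    def is_hip() -> bool:
        # Assumption: mirrors sglang.srt.utils.is_hip (ROCm sets torch.version.hip).
        return torch.version.hip is not None

    @dataclass
    class Args:
        triton_attention_num_kv_splits: int = 8  # non-ROCm default

        def __post_init__(self):
            # AMD-specific Triton attention KV splits default number
            if is_hip():
                self.triton_attention_num_kv_splits = 16

Note that, as written, the override runs on every construction, so on ROCm it replaces the field even when a caller passed a value explicitly.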
