diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index 2b4871af98c..86bc15c1cb5 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -180,7 +180,7 @@ def _decode_att_m_fwd( sm_scale, logit_cap, ): - BLOCK = 64 + BLOCK = 8 NUM_KV_SPLITS = num_kv_splits Lk = k_buffer.shape[-1] Lv = v_buffer.shape[-1] @@ -193,7 +193,7 @@ def _decode_att_m_fwd( if kv_group_num == 1: num_warps = 4 else: - num_warps = 2 + num_warps = 1 BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DV = triton.next_power_of_2(Lv) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f9340e47764..6176f53d937 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -154,7 +154,7 @@ class ServerArgs: enable_nan_detection: bool = False enable_p2p_check: bool = False triton_attention_reduce_in_fp32: bool = False - triton_attention_num_kv_splits: int = 8 + triton_attention_num_kv_splits: int = 16 num_continuous_decode_steps: int = 1 delete_ckpt_after_loading: bool = False enable_memory_saver: bool = False