[Bug] how to solve illegal memory access in moe_align_block_size kernel optimization #3339

Closed
BBuf opened this issue Feb 6, 2025 · 9 comments

BBuf commented Feb 6, 2025

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

As mentioned in these lines of code, when attempting to optimize the most expensive write operation on sorted_token_ids in the moe_align_block_size kernel for DeepSeek V3, using multiple Thread Blocks instead of a single Block triggers an illegal global memory write access. Even directly replacing these lines with the Triton Stage4 kernel here results in the same illegal global memory write access within the Triton kernel. This issue has persisted for nearly a month and the cause has not yet been identified, so I am opening this issue to seek help.

Based on the kernel benchmark results on H200, the current CUDA kernel becomes slower than Triton when the number of tokens is >= 4096; on H100 it becomes slower than Triton when the number of tokens is >= 32768. The cause is the lines of code I pointed out earlier. If multiple Thread Blocks were used there, the kernel could achieve the fastest execution speed across all token counts.

Additionally, the original implementation of the VLLM kernel exhibits similar behavior, even though it uses shared memory for counting when the number of tokens is <= 65536. According to my benchmark results, its performance is still significantly slower than both the sgl-kernel CUDA operators and the Triton version, so I will not discuss the moe_align_block_size kernel in VLLM.

The code below is the result of modifying the performance-critical lines so that the element-wise writes are distributed across multiple Thread Blocks.

// Note: For the moe_align_kernel, the primary bottleneck lies in the atomic add and non-coalesced memory writes here.
  // If these operations can be performed using multiple blocks, similar to the Triton version, the performance of this
  // kernel can achieve state-of-the-art performance across all token cases. However, once multiple blocks are used,
  // illegal memory access occurs. Even replacing these lines of code with the stage 4 kernel from the Triton version
  // results in the same issue, and a correct solution has not yet been found.
  // for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
  //  int32_t expert_id = topk_ids[i];
  //  int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
  //  sorted_token_ids[rank_post_pad] = i;
  // }

template <typename scalar_t>
__global__ void sort_token_ids_kernel(scalar_t* __restrict__ topk_ids, 
                                    int32_t* sorted_token_ids,
                                    int32_t* token_cnts_buffer,
                                    const int32_t* cumsum,
                                    int32_t num_experts,
                                    size_t numel,
                                    int32_t tokens_per_block) {
    const size_t start_idx = blockIdx.x * tokens_per_block;
    const size_t end_idx = min(start_idx + tokens_per_block, numel);
    
    const size_t off_t = blockIdx.x * num_experts;
    
    for (size_t i = start_idx + threadIdx.x; i < end_idx; i += blockDim.x) {
        int expert_id = topk_ids[i];
        int token_cnt = atomicAdd(&token_cnts_buffer[off_t + expert_id], 1);
        int rank_post_pad = token_cnt + cumsum[expert_id];
        sorted_token_ids[rank_post_pad] = i;
    }
}

const int tokens_per_block = CEILDIV(topk_ids.numel(), num_experts);
sort_token_ids_kernel<<<num_experts, 256, 0, stream>>>(
    topk_ids.data_ptr<scalar_t>(),
    sorted_token_ids.data_ptr<int32_t>(),
    token_cnts_buffer.data_ptr<int32_t>(),
    cumsum_buffer.data_ptr<int32_t>(),
    num_experts,
    topk_ids.numel(),
    tokens_per_block);

The error message obtained with compute-sanitizer is: Invalid global write of size 4 bytes, and it occurs at the line sorted_token_ids[rank_post_pad] = i;.

Below are the benchmark results for the sgl-kernel moe_align_block_size and the Triton version on H100 and H200:

H100

python3 /opt/dlami/nvme/bbuf/sglang/benchmark/kernels/fused_moe_triton/benchmark_moe_align_blocks.py --save_path /opt/dlami/nvme/bbuf/configs
✅ CUDA and Triton implementations match
moe-align-block-size-performance:
     batch_size  seq_len         CUDA       Triton
0           1.0      1.0    24.224000    73.408000
1           1.0      2.0    24.192000    72.127998
2           1.0      4.0    24.256000    71.648002
3           1.0      8.0    24.192000    73.023997
4           1.0     16.0    24.224000    72.191998
5           1.0     32.0    24.192000    71.392000
6           1.0     64.0    24.288001    73.344000
7           1.0    128.0    24.672000    73.536001
8           1.0    256.0    24.800001    72.768003
9           1.0    512.0    25.152000    72.127998
10          1.0   1024.0    25.536001    74.239999
11          1.0   2048.0    26.176000    73.919997
12          1.0   4096.0    28.031999    72.319999
13          1.0   8192.0    32.687999    72.864003
14          1.0  16384.0    45.279998    73.760003
15          1.0  32768.0    **85.104004**    78.560002
16          2.0      1.0    24.192000    72.400004
17          2.0      2.0    24.192000    73.472001
18          2.0      4.0    24.256000    79.584002
19          2.0      8.0    24.288001    79.392001
20          2.0     16.0    24.256000    79.167999
21          2.0     32.0    24.320001    79.552002
22          2.0     64.0    24.704000    77.616006
23          2.0    128.0    24.831999    78.143999
24          2.0    256.0    25.119999    79.135999
25          2.0    512.0    25.536001    78.528002
26          2.0   1024.0    26.240001    79.183996
27          2.0   2048.0    28.031999    78.656003
28          2.0   4096.0    32.639999    78.863993
29          2.0   8192.0    45.184001    80.895998
30          2.0  16384.0    **85.280001**    79.903997
31          2.0  32768.0   **148.287997**   101.120003
32          4.0      1.0    24.288001    79.584002
33          4.0      2.0    24.288001    80.767997
34          4.0      4.0    24.320001    79.552002
35          4.0      8.0    24.288001    79.584002
36          4.0     16.0    24.320001    79.775997
37          4.0     32.0    24.672000    77.791996
38          4.0     64.0    24.800001    78.528002
39          4.0    128.0    25.184000    78.111999
40          4.0    256.0    25.664000    79.007998
41          4.0    512.0    26.272001    78.079998
42          4.0   1024.0    28.000001    79.360001
43          4.0   2048.0    32.559998    78.879997
44          4.0   4096.0    45.248002    77.904001
45          4.0   8192.0    **85.344002**    80.400005
46          4.0  16384.0   **148.256004**   101.120003
47          4.0  32768.0   **272.832006**   146.431997
48          8.0      1.0    23.584001    79.967998
49          8.0      2.0    23.664000    78.896001
50          8.0      4.0    23.568001    79.728007
51          8.0      8.0    23.552001    75.167999
52          8.0     16.0    23.615999    74.975997
53          8.0     32.0    23.680000    75.488001
54          8.0     64.0    24.000000    76.159999
55          8.0    128.0    24.512000    75.407997
56          8.0    256.0    25.664000    76.031998
57          8.0    512.0    27.807999    75.584002
58          8.0   1024.0    32.639999    75.231999
59          8.0   2048.0    44.960000    74.879996
60          8.0   4096.0    **85.472003**    77.951998
61          8.0   8192.0   **149.215996**    99.104002
62          8.0  16384.0   **273.791999**   146.752000
63          8.0  32768.0   **527.520001**   238.463998
64         16.0      1.0    23.552001    75.263999
65         16.0      2.0    23.552001    75.151995
66         16.0      4.0    23.600001    74.720003
67         16.0      8.0    23.712000    75.903997
68         16.0     16.0    23.744000    75.231999
69         16.0     32.0    24.064001    75.360000
70         16.0     64.0    24.351999    75.839996
71         16.0    128.0    25.632000    75.199999
72         16.0    256.0    27.775999    74.816003
73         16.0    512.0    32.671999    75.360000
74         16.0   1024.0    44.976000    75.520001
75         16.0   2048.0    **85.408002**    77.280000
76         16.0   4096.0   **147.440001**    99.072002
77         16.0   8192.0   **274.048001**   146.559998
78         16.0  16384.0   **525.600016**   238.399997
79         16.0  32768.0  **1022.639990**   413.792014
80         32.0      1.0    23.600001    76.800004
81         32.0      2.0    23.584001    77.119999
82         32.0      4.0    23.680000    76.959997
83         32.0      8.0    23.744000    75.312003
84         32.0     16.0    24.032000    76.736003
85         32.0     32.0    24.383999    74.975997
86         32.0     64.0    25.664000    75.920001
87         32.0    128.0    27.775999    75.776003
88         32.0    256.0    32.639999    76.991998
89         32.0    512.0    45.120001    74.720003
90         32.0   1024.0    **85.712001**    78.143999
91         32.0   2048.0   **147.456005**    99.200003
92         32.0   4096.0   **273.631990**   146.704003
93         32.0   8192.0   **525.983989**   238.527998
94         32.0  16384.0  **1023.823977**   413.856000
95         32.0  32768.0  **2023.808002**   769.919991
96         64.0      1.0    23.536000    75.999998
97         64.0      2.0    23.712000    75.135998
98         64.0      4.0    23.744000    75.648002
99         64.0      8.0    24.064001    75.648002
100        64.0     16.0    24.383999    75.552002
101        64.0     32.0    25.664000    75.712003
102        64.0     64.0    27.775999    75.135998
103        64.0    128.0    32.575998    75.263999
104        64.0    256.0    45.040000    76.672003
105        64.0    512.0    **85.327998**    77.344000
106        64.0   1024.0   **147.599995**    99.136002
107        64.0   2048.0   **274.080008**   146.623999
108        64.0   4096.0   **525.056005**   238.912001
109        64.0   8192.0  **1024.320006**   413.695991
110        64.0  16384.0  **2019.808054**   769.919991
111        64.0  32768.0  **4007.391930**  1482.624054
112       128.0      1.0    23.647999    77.696003
113       128.0      2.0    23.744000    78.143999
114       128.0      4.0    24.064001    79.616003
115       128.0      8.0    24.400000    78.240000
116       128.0     16.0    25.664000    78.720003
117       128.0     32.0    27.807999    78.847997
118       128.0     64.0    32.416001    78.624003
119       128.0    128.0    45.088001    79.167999
120       128.0    256.0    **85.376002**    79.744004
121       128.0    512.0   **147.295997**    99.104002
122       128.0   1024.0   **273.983985**   146.495998
123       128.0   2048.0   **526.848018**   238.591999
124       128.0   4096.0  **1023.839951**   413.807988
125       128.0   8192.0  **2023.328066**   769.343972
126       128.0  16384.0  **4007.056236**  1481.424093
127       128.0  32768.0  **8052.288055**  2958.911896

H200

✅ CUDA and Triton implementations match
moe-align-block-size-performance:
     batch_size  seq_len          CUDA        Triton
0           1.0      1.0     21.760000     40.800001
1           1.0      2.0     21.760000     40.895998
2           1.0      4.0     21.728000     40.991999
3           1.0      8.0     21.888001     41.312002
4           1.0     16.0     21.919999     41.280001
5           1.0     32.0     22.112001     41.536000
6           1.0     64.0     22.528000     42.656001
7           1.0    128.0     23.040000     42.240001
8           1.0    256.0     23.808001     44.064000
9           1.0    512.0     26.095999     48.448000
10          1.0   1024.0     30.928001     54.079998
11          1.0   2048.0     43.807998     62.240001
12          1.0   4096.0     83.903998     76.079994
13          1.0   8192.0    147.200003     99.232003
14          1.0  16384.0    273.247987    145.536005
15          1.0  32768.0    527.999997    240.400001
16          2.0      1.0     21.760000     40.927999
17          2.0      2.0     21.760000     40.959999
18          2.0      4.0     21.919999     41.216001
19          2.0      8.0     21.919999     41.216001
20          2.0     16.0     22.080000     41.503999
21          2.0     32.0     22.496000     42.624000
22          2.0     64.0     23.072001     42.304002
23          2.0    128.0     23.696000     44.032000
24          2.0    256.0     26.079999     48.735999
25          2.0    512.0     30.880000     54.175999
26          2.0   1024.0     43.807998     62.144000
27          2.0   2048.0     83.935998     76.143995
28          2.0   4096.0    146.752000     99.168003
29          2.0   8192.0    272.864014    145.311996
30          2.0  16384.0    528.352022    240.208000
31          2.0  32768.0   1028.576016    414.303988
32          4.0      1.0     21.760000     41.120000
33          4.0      2.0     21.919999     41.216001
34          4.0      4.0     21.919999     41.343998
35          4.0      8.0     22.112001     41.728001
36          4.0     16.0     22.592001     42.720001
37          4.0     32.0     23.072001     42.240001
38          4.0     64.0     23.808001     44.128001
39          4.0    128.0     26.112000     48.480000
40          4.0    256.0     30.912001     53.920001
41          4.0    512.0     43.776002     62.128000
42          4.0   1024.0     83.999999     76.095998
43          4.0   2048.0    146.975994     99.168003
44          4.0   4096.0    272.992015    145.344004
45          4.0   8192.0    527.487993    240.288004
46          4.0  16384.0   1028.576016    414.144009
47          4.0  32768.0   2034.463882    773.696005
48          8.0      1.0     21.856001     41.216001
49          8.0      2.0     21.919999     41.280001
50          8.0      4.0     22.080000     41.600000
51          8.0      8.0     22.496000     42.624000
52          8.0     16.0     23.008000     42.272002
53          8.0     32.0     23.744000     44.144001
54          8.0     64.0     26.079999     48.512001
55          8.0    128.0     31.040000     54.079998
56          8.0    256.0     43.807998     62.176000
57          8.0    512.0     84.063999     76.095998
58          8.0   1024.0    146.944001     99.072002
59          8.0   2048.0    272.576004    145.408005
60          8.0   4096.0    528.927982    240.768000
61          8.0   8192.0   1028.192043    414.112002
62          8.0  16384.0   2034.240007    774.047971
63          8.0  32768.0   4024.064064   1488.415956
64         16.0      1.0     21.919999     41.343998
65         16.0      2.0     22.112001     41.503999
66         16.0      4.0     22.496000     42.656001
67         16.0      8.0     23.104001     42.240001
68         16.0     16.0     23.808001     43.968000
69         16.0     32.0     26.079999     48.512001
70         16.0     64.0     30.912001     54.079998
71         16.0    128.0     43.807998     62.240001
72         16.0    256.0     84.160000     75.999998
73         16.0    512.0    146.975994     99.136002
74         16.0   1024.0    273.088008    145.503998
75         16.0   2048.0    528.447986    240.416005
76         16.0   4096.0   1028.671980    414.303988
77         16.0   8192.0   2033.983946    774.111986
78         16.0  16384.0   4021.408081   1488.991976
79         16.0  32768.0   8028.175354   2932.768106
80         32.0      1.0     22.080000     41.600000
81         32.0      2.0     22.528000     42.624000
82         32.0      4.0     23.072001     42.272002
83         32.0      8.0     23.808001     44.128001
84         32.0     16.0     26.079999     48.512001
85         32.0     32.0     31.008000     53.920001
86         32.0     64.0     43.807998     62.208001
87         32.0    128.0     84.096000     76.191999
88         32.0    256.0    147.487998     99.136002
89         32.0    512.0    272.639990    145.344004
90         32.0   1024.0    528.223991    240.799993
91         32.0   2048.0   1027.008057    414.175987
92         32.0   4096.0   2032.351971    773.151994
93         32.0   8192.0   4028.880119   1489.567995
94         32.0  16384.0   8026.767731   2924.767971
95         32.0  32768.0  16051.088333   6203.711987
96         64.0      1.0     22.496000     42.784002
97         64.0      2.0     23.104001     42.304002
98         64.0      4.0     23.808001     44.032000
99         64.0      8.0     26.112000     48.624001
100        64.0     16.0     30.975999     53.984001
101        64.0     32.0     43.807998     62.304001
102        64.0     64.0     84.096000     75.871997
103        64.0    128.0    146.880001     99.200003
104        64.0    256.0    272.704005    145.472005
105        64.0    512.0    528.128028    240.240008
106        64.0   1024.0   1029.024005    414.400011
107        64.0   2048.0   2034.080029    772.415996
108        64.0   4096.0   4032.608032   1490.015984
109        64.0   8192.0   8031.744003   2930.560112
110        64.0  16384.0  16027.168274   6208.992004
111        64.0  32768.0  32008.895874  12419.839859
112       128.0      1.0     23.056000     42.399999
113       128.0      2.0     23.808001     44.160001
114       128.0      4.0     26.112000     48.864000
115       128.0      8.0     30.944001     54.143999
116       128.0     16.0     43.807998     62.304001
117       128.0     32.0     83.792001     75.967997
118       128.0     64.0    148.031995     99.136002
119       128.0    128.0    273.600012    145.328000
120       128.0    256.0    527.872026    240.400001
121       128.0    512.0   1028.464079    414.207995
122       128.0   1024.0   2031.935930    772.800028
123       128.0   2048.0   4032.127857   1489.120007
124       128.0   4096.0   8032.719612   2931.519985
125       128.0   8192.0  16011.951447   6202.784061
126       128.0  16384.0  31969.087601  12423.839569

Reproduction

Apply the code change above and run python3 benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py.

Environment

No Limit.

tim-zou commented Feb 6, 2025

Have you memset the contents of token_cnts_buffer to zero? Is it a tensor constructed with zeros_like? Also, I could not find the script benchmark_deepseekv3_moe_align_blocks in the repo; could you kindly provide it?

BBuf commented Feb 6, 2025

> Have you memset the contents of token_cnts_buffer to zero? Is it a tensor constructed with zeros_like? Also, I could not find the script benchmark_deepseekv3_moe_align_blocks in the repo; could you kindly provide it?

I have tried that, but it failed too. The script exists here.

tim-zou commented Feb 6, 2025

In the benchmark script, both cumsum and token_cnts_buffer are allocated with torch.empty and are therefore uninitialized. Have you tried initializing both of them to zero? Because these tensors are uninitialized, the BFS allocator may reuse memory from previous operations, resulting in arbitrary initial values. Thus, the index rank_post_pad = token_cnt + cumsum[expert_id] may contain an invalid index. I'm not sure about the other buffers, but the same logic should apply.
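
For illustration, a minimal sketch of the suggested fix in the benchmark script, assuming the buffer names from the kernel snippet above and purely illustrative sizes (the real script derives them from topk_ids and the launch configuration):

import torch

# Illustrative sizes only; the benchmark computes these from topk_ids.
num_experts = 256
num_blocks = num_experts  # one thread block per expert in the launch shown above

# torch.empty leaves memory uninitialized, so cumsum[expert_id] and the
# per-block counters may contain arbitrary values from reused allocations,
# making rank_post_pad = token_cnt + cumsum[expert_id] an out-of-bounds index.
# token_cnts_buffer = torch.empty((num_blocks * num_experts,), dtype=torch.int32, device="cuda")
# cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device="cuda")

# Zero-initialize instead (or call .zero_() on the existing buffers) before launching the kernel.
token_cnts_buffer = torch.zeros((num_blocks * num_experts,), dtype=torch.int32, device="cuda")
cumsum_buffer = torch.zeros((num_experts + 1,), dtype=torch.int32, device="cuda")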

BBuf commented Feb 6, 2025

> token_cnts_buffer

I can give it a try, and I will submit a new temporary branch for your review.

BBuf commented Feb 6, 2025

> In the benchmark script, both cumsum and token_cnts_buffer are allocated with torch.empty and are therefore uninitialized. Have you tried initializing both of them to zero? Because these tensors are uninitialized, the BFS allocator may reuse memory from previous operations, resulting in arbitrary initial values. Thus, the index rank_post_pad = token_cnt + cumsum[expert_id] may contain an invalid index. I'm not sure about the other buffers, but the same logic should apply.

Thanks, you are right. I had overlooked initializing the tensors to zero; benchmark_deepseekv3_moe_align_blocks.py now works well with this change: https://github.com/sgl-project/sglang/compare/temp_1. I will run an accuracy test for DeepSeek V3 now.

zhyncs commented Feb 6, 2025

> In the benchmark script, both cumsum and token_cnts_buffer are allocated with torch.empty and are therefore uninitialized. Have you tried initializing both of them to zero? Because these tensors are uninitialized, the BFS allocator may reuse memory from previous operations, resulting in arbitrary initial values. Thus, the index rank_post_pad = token_cnt + cumsum[expert_id] may contain an invalid index. I'm not sure about the other buffers, but the same logic should apply.

Cool, thanks for the insight, and it works!

BBuf commented Feb 6, 2025

Solved in #3347.

BBuf closed this as completed Feb 6, 2025

yiakwy-xpu-ml-framework-team commented:

@BBuf great work! I will rebase my multi-block partition on this version and make it as simple as possible. Once it is completed, I hope you can help review it.

yiakwy-xpu-ml-framework-team commented:

@BBuf @zhyncs could you have a look at #3382? Besides token_cnts_buffer, I found last week, when submitting the multi-block execution PR, that expert_ids should also be zero-initialized for small data sizes.

A mismatch can easily be reproduced if expert_ids is not zero-initialized.
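
A minimal sketch of that additional change, assuming expert_ids is the int32 scratch buffer holding one expert index per block of sorted tokens (the name follows the discussion above; the size here is illustrative, not the exact one used in #3382):

import torch

# Illustrative size; the real value is derived from topk_ids.numel(), num_experts and block_size.
max_num_m_blocks = 1024

# Zero-initialize expert_ids so entries not written by the kernel for small
# inputs still hold a defined value instead of stale memory from the allocator.
expert_ids = torch.zeros((max_num_m_blocks,), dtype=torch.int32, device="cuda")
# If the buffer is reused across iterations, expert_ids.zero_() achieves the same effect.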
