diff --git a/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh b/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh index 51c3650..af3659c 100644 --- a/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh +++ b/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh @@ -20,7 +20,6 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, constexpr auto fill_mode = cp_async::SharedMemFillMode::kFillZero; const uint32_t problem_id = blockIdx.y; const uint32_t bx = blockIdx.x; - const uint32_t s_start = s[problem_id], s_end = s[problem_id + 1]; constexpr uint32_t num_stages = 2; constexpr uint32_t num_k_frags = 8; constexpr uint32_t num_cells_k = (num_k_frags * 16) / cell_capacity(); @@ -45,8 +44,9 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, uint32_t w_frag[num_k_frags][num_blocks_n][4]; float y_frag[num_blocks_n][8]; - for (uint32_t i = 0; - i < (s_end - s_start + (num_warps * 16 - 1)) / (num_warps * 16); ++i) { + const uint32_t s_start = s[problem_id], s_end = s[problem_id + 1]; + const uint32_t num_steps = (s_start < s_end) ? (s_end - s_start + (num_warps * 16 - 1)) / (num_warps * 16) : 0; + for (uint32_t i = 0; i < num_steps; ++i) { // init y_frag if (bx == 0) { if constexpr (num_blocks_n == 1) { @@ -335,6 +335,20 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, } } } + + // handle the case where one of the segments needs more steps than this one + // to avoid deadlock + if constexpr (cooperative) { + uint32_t max_segment_size = 0; + for (uint32_t i = 0; i < num_problems; ++i) { + max_segment_size = max(max_segment_size, s[i + 1] - s[i]); + } + + const uint32_t max_steps = (max_segment_size + (num_warps * 16 - 1)) / (num_warps * 16); + for (uint32_t i = 0; i < max_steps - num_steps; ++i) { + grid.sync(); + } + } } } // namespace sgmv diff --git a/tests/test_sgmv.py b/tests/test_sgmv.py index d2cb805..ad645cc 100644 --- a/tests/test_sgmv.py +++ b/tests/test_sgmv.py @@ -51,6 +51,12 @@ def get_lora_lens(bs: int, popularity: str) -> list[int]: a *= alpha lens.append(bs - sum(lens)) return sorted(lens, reverse=True) + if popularity.startswith("skewed"): + if bs < 3: + return [bs] + # Create a highly imbalanced distribution by setting the first segment + # length to 1 and the remainder to the second segment. + return [1, bs - 1] raise KeyError(popularity) @@ -81,7 +87,7 @@ def lora_ref_impl( pytest.param("expand", marks=pytest.mark.xfail(reason="TODO: sgmv expand")), ], ) -@pytest.mark.parametrize("popularity", ["distinct", "uniform", "zipf:1.5", "identical"]) +@pytest.mark.parametrize("popularity", ["distinct", "uniform", "zipf:1.5", "identical", "skewed"]) @pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 7, 10, 16, 32, 64, 133]) @torch.inference_mode() def test_sgmv_correctness(dtype_str, h, r, direction, popularity, batch_size):