-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
avoid ublk tma out bound access #3917
base: main
Are you sure you want to change the base?
Conversation
…zation' into llu/tma_predicate
…nto llu/tma_predicate
Review updated until commit 48bedc1 Description
Changes walkthrough 📝
PR Reviewer Guide 🔍Here are some key observations to aid the review process:
|
!test |
!test |
!test |
!test |
Background
CpAsyncBulkTensorTile
handles out of bound accesses automatically in hardware, no need to predicate it.However,
CpAsyncBulk
leads to illegal memory access if out of bound access happens.ref
the memory range [srcMem, srcMem + size - 1] must not overflow the source memory space. Otherwise, the behavior is undefined
It happens in persistent kernel with non-divisible circular buffer stages and/or non-divisible CTA count. For example, tensor [M, N] is split as:
[sm_count, M/stages/sm_count, stages, N]
and parallelized as:[BIDx, Serial, Serial, Bulk]
. The TMA load is nested within two for-loops, one for[I0/stages/sm_count]
and the other for[stages]
, since predicate is not generated for TMA load, out of bound access happens if any of the split is not disvisible.Candidate Solutions
(1) Add a predicate for CpAsyncBulk:
I explored this approach in PR #3899. However, implementing it becomes challenging in cases involving a circular buffer, as we need to predicate both
CpAsyncBulk
andwaitParity
to avoid potential deadlocks.(2) Modulo the source address calculation for CpAsyncBulk (this PR)
When calculating the source address, we can apply a modulo operation using the tensor size. For a tensor of shape [M, N], each
CpAsyncBulk
loads one row. If the computed source address corresponds to(M + x) * N
, it causes an overflow. After mod(M×N), it changes to load fromx * N
. While this load is redundant, it occurs only in the last iteration of the loop on SMs that otherwise would stays idle, thus has minimal impact on other valid loads and computations across other SMs.