Skip to content

Latest commit

 

History

History

flash_attention

Folders and files

NameName
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 

FlashAttention

Using tile-lang, we can define buffers at different memory layers. For instance, Q_shared, K_shared, and V_shared can be defined in shared memory, while acc_s and acc_o can be placed in registers. This flexibility allows us to represent a complex fusion pattern like FlashAttention in a simple way.

@T.prim_func
def flash_attention(
    Q: T.Buffer(shape, dtype),
    K: T.Buffer(shape, dtype),
    V: T.Buffer(shape, dtype),
    Output: T.Buffer(shape, dtype),
):
    # Launch a specialized T.Kernel with 3D mapping: (bx, by, bz)
    #   bx: block index in sequence dimension
    #   by: block index in "heads" dimension
    #   bz: block index in "batch" dimension
    # threads=thread_num means how many threads per block
    with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):
        # Allocate shared memory for Q, K, V to reduce global memory accesses
        Q_shared = T.alloc_shared([block_M, dim], dtype)
        K_shared = T.alloc_shared([block_N, dim], dtype)
        V_shared = T.alloc_shared([block_N, dim], dtype)
        # Allocate buffers on register
        # acc_s: buffer to hold intermediate attention scores
        acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
        # acc_s_cast: buffer for storing casted/adjusted scores
        acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
        # acc_o: partial accumulation of output
        acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
        # Buffers to track per-row maximum score and related stats
        scores_max = T.alloc_fragment([block_M], accum_dtype)
        scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
        scores_scale = T.alloc_fragment([block_M], accum_dtype)
        scores_sum = T.alloc_fragment([block_M], accum_dtype)
        logsum = T.alloc_fragment([block_M], accum_dtype)

        # Annotate layout for Q_shared, e.g., use a swizzled layout to optimize memory access
        T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)})

        # Copy a block of Q from global memory to Q_shared
        T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)

        # Initialize accumulators
        T.fill(acc_o, 0)
        T.fill(logsum, 0)
        T.fill(scores_max, -T.infinity(accum_dtype))
        loop_range = (
            T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
        )

        # Pipeline the loop to overlap copies/gemm stages
        for k in T.Pipelined(loop_range, num_stages=num_stages):
            # Copy K block into shared memory
            T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)

            if is_causal:
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.if_then_else(
                        bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)
                    )
            else:
                T.clear(acc_s)

            # Perform the Q*K^T multiplication, Here, transpose_B=True indicates that K_shared is transposed,
            # policy=T.GemmWarpPolicy.FullRow means each warp is responsible for computing an entire row
            # of acc_s, and the resulting acc_s is retained in registers.
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

            # Copy V block into shared memory
            T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
            for i, j in T.Parallel(block_M, dim):
                acc_s[i, j] *= scale

            # Save old scores_max, then reset scores_max
            T.copy(scores_max, scores_max_prev)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # Compute the maximum value per row on dimension 1 (block_N)
            T.reduce_max(acc_s, scores_max, dim=1, clear=False)

            # Compute the factor by which we need to rescale previous partial sums
            for i in T.Parallel(block_M):
                scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])

            # Rescale the partial output accumulation to keep exponents consistent
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] *= scores_scale[i]

            # Exponentiate (scores - max) for the new block
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])

            # Make a cast of acc_s to fp16 for the next GEMM
            T.copy(acc_s, acc_s_cast)

            # Multiply the attention acc_s_cast by V and add to partial output (acc_o)
            T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
            T.reduce_sum(acc_s, scores_sum, dim=1)
            # Update the "logsum" tracker with the newly accumulated sum
            for i in T.Parallel(block_M):
                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]

        # Final step: divide each partial output by logsum (completing the softmax)
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] /= logsum[i]

        # Write back the final output block from acc_o to the Output buffer
        T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])