Implements TPU decoding in Pallas (#1039)
* Implements TPU decoding in Pallas

* Add support for shorter kv len
hanzhi713 authored Mar 7, 2025
1 parent b1e7b37 commit 3abf229
Showing 12 changed files with 747 additions and 286 deletions.
50 changes: 40 additions & 10 deletions axlearn/common/attention.py
@@ -2090,6 +2090,44 @@ def default_key_scale_config() -> InstantiableConfig[ScaleFn]:
return config_for_function(constant_scale_fn).set(value=1)


def compute_gqa_logits(q_proj: Tensor, k_proj: Tensor) -> Tensor:
"""Compute attention logits.
Args:
q_proj: query tensor, [batch, target_length, num_heads, per_head_dim].
k_proj: key tensor, [batch, source_length, num_kv_heads, per_head_dim].
Returns:
logits: [batch, num_heads, target_length, source_length].
"""
kv_heads = k_proj.shape[2]
num_head_group = q_proj.shape[2] // kv_heads
assert q_proj.shape[2] % kv_heads == 0
q_proj = einops.rearrange(q_proj, "b t (k g) h -> b t k g h", k=kv_heads, g=num_head_group)
k_proj = einops.rearrange(k_proj, "b s k h -> b s k 1 h")
logits = jnp.einsum("btkgh,bsk1h->bkgts", q_proj, k_proj)
return einops.rearrange(logits, "b k g t s -> b (k g) t s")


def compute_gqa_context(probs: Tensor, v_proj: Tensor) -> Tensor:
"""Compute attention context.
Args:
probs: probs tensor, [batch, num_heads, target_length, source_length].
v_proj: value tensor, [batch, source_length, num_kv_heads, per_head_dim].
Returns:
context: [batch, target_length, num_heads, per_head_dim].
"""
kv_heads = v_proj.shape[2]
num_head_group = probs.shape[1] // kv_heads
assert probs.shape[1] % kv_heads == 0
probs = einops.rearrange(probs, "b (k g) t s -> b k g t s", k=kv_heads, g=num_head_group)
v_proj = einops.rearrange(v_proj, "b s k h -> b s k 1 h")
context = jnp.einsum("bkgts,bsk1h->btkgh", probs, v_proj)
return einops.rearrange(context, "b t k g h -> b t (k g) h")
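As a shape reference for the two helpers factored out above, a minimal usage sketch (the batch, length, and head sizes here are hypothetical; it assumes the helpers are importable from axlearn.common.attention as added in this diff):

import jax

from axlearn.common.attention import compute_gqa_context, compute_gqa_logits

batch, target, source, per_head_dim = 2, 4, 6, 16
num_heads, num_kv_heads = 8, 2  # 4 query heads share each KV head.

q = jax.random.normal(jax.random.PRNGKey(0), (batch, target, num_heads, per_head_dim))
k = jax.random.normal(jax.random.PRNGKey(1), (batch, source, num_kv_heads, per_head_dim))
v = jax.random.normal(jax.random.PRNGKey(2), (batch, source, num_kv_heads, per_head_dim))

logits = compute_gqa_logits(q, k)        # (2, 8, 4, 6): [batch, num_heads, target, source].
probs = jax.nn.softmax(logits, axis=-1)  # Normalize over the source axis.
context = compute_gqa_context(probs, v)  # (2, 4, 8, 16): [batch, target, num_heads, per_head_dim].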


class GroupedQueryAttention(MultiheadAttention):
"""A Grouped-Query Attention (GQA) layer.
@@ -2128,12 +2166,7 @@ def _compute_logits(self, q_proj: Tensor, k_proj: Tensor) -> Tensor:
if num_head_group == 1:
return super()._compute_logits(q_proj=q_proj, k_proj=k_proj)

q_proj = self.scale_query(q_proj)
k_proj = self.scale_key(k_proj)
q_proj = einops.rearrange(q_proj, "b t (k g) h -> b t k g h", k=kv_heads, g=num_head_group)
k_proj = einops.rearrange(k_proj, "b s k h -> b s k 1 h")
logits = jnp.einsum("btkgh,bsk1h->bkgts", q_proj, k_proj)
return einops.rearrange(logits, "b k g t s -> b (k g) t s")
return compute_gqa_logits(self.scale_query(q_proj), self.scale_key(k_proj))

def _compute_context(self, probs: Tensor, v_proj: Tensor) -> Tensor:
"""Compute attention context.
@@ -2150,10 +2183,7 @@ def _compute_context(self, probs: Tensor, v_proj: Tensor) -> Tensor:
if num_head_group == 1:
return super()._compute_context(probs=probs, v_proj=v_proj)

probs = einops.rearrange(probs, "b (k g) t s -> b k g t s", k=kv_heads, g=num_head_group)
v_proj = einops.rearrange(v_proj, "b s k h -> b s k 1 h")
context = jnp.einsum("bkgts,bsk1h->btkgh", probs, v_proj)
return einops.rearrange(context, "b t k g h -> b t (k g) h")
return compute_gqa_context(probs, v_proj)


class SigmoidAttention(MultiheadAttention):
88 changes: 88 additions & 0 deletions axlearn/common/flash_attention/common.py
@@ -0,0 +1,88 @@
# Copyright © 2025 Apple Inc.
"""Common utilities across backends."""

from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

from axlearn.common.attention_bias import MaskFn
from axlearn.common.utils import Tensor


def build_mask(
mask_fn: MaskFn, *, q_seq_len: int, kv_seq_len: int, block_q: int, block_k: int
) -> np.ndarray:
"""Builds the block map where True means the block is not fully masked.
Args:
mask_fn: The attention mask function.
q_seq_len: Query sequence length.
kv_seq_len: Key/Value sequence length.
block_q: Query block size.
block_k: Key/Value block size.
Returns:
A boolean array of shape (num_q_blocks, num_kv_blocks) where True means the block is not
fully masked. num_q_blocks * block_q will be larger than q_seq_len if q_seq_len is not
divisible by block_q. The same holds true for kv blocks.
"""
# Initialize the iteration map where True means the block is not empty.
num_q_blocks = pl.cdiv(q_seq_len, block_q)
num_kv_blocks = pl.cdiv(kv_seq_len, block_k)
block_mask_map = np.ones(shape=(num_q_blocks, num_kv_blocks), dtype=np.bool_)
    # Initialize the scan begin and end indices.
rows = np.arange(q_seq_len, dtype=np.int32)
cols = np.arange(kv_seq_len, dtype=np.int32)
# Run a compile-time evaluation to get the mask array.
# TODO(kelvin-zou): use a block-wise mask function to avoid the compile-time
# high memory usage.
with jax.ensure_compile_time_eval():
mask_array = np.asarray(mask_fn(rows[:, None], cols[None, :]))
for i in range(0, q_seq_len, block_q):
for j in range(0, kv_seq_len, block_k):
# Extract the block
block = mask_array[i : i + block_q, j : j + block_k]
# All empty means skipping
if not block.any():
block_mask_map[i // block_q, j // block_k] = False
return block_mask_map
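To illustrate, a small sketch of calling `build_mask` with the sliding-window mask used by the decoding test further below (the sequence and block sizes are hypothetical):

from axlearn.common.attention_bias import sliding_window_causal_mask
from axlearn.common.flash_attention.common import build_mask

# 512 query and KV positions tiled into 128x128 blocks -> a 4x4 block map.
block_mask_map = build_mask(
    sliding_window_causal_mask(128),
    q_seq_len=512,
    kv_seq_len=512,
    block_q=128,
    block_k=128,
)
# Blocks lying entirely outside the causal sliding window come back False and can be
# skipped when iterating over KV blocks; all other blocks are True.
print(block_mask_map.shape)  # (4, 4)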


class KVOffsetInfo(NamedTuple):
"""Records the block index of non-empty KV blocks.
Attributes:
kv_block_offset: A (num_q_blocks, num_kv_blocks) tensor where `kv_block_offset[i][j]`
stores the index of the jth non-empty KV block index for the ith query block.
This tensor may be padded at the end.
kv_block_offset_size: A (num_q_blocks,) tensor that stores the number of valid entries
for each row of `kv_block_offset`, i.e. the number of entries before padding.
"""

kv_block_offset: Tensor
kv_block_offset_size: Tensor


def query_iterator_indices(block_mask_map: np.ndarray, *, padding: int = 0) -> KVOffsetInfo:
"""Builds `KVOffsetInfo` for block-sparse attention computation in the forward pass.
Returns:
A `KVOffsetInfo`. See the attributes of `KVOffsetInfo` for more info.
"""
num_q_blocks, num_kv_blocks = block_mask_map.shape
index_offset = np.full((num_q_blocks, num_kv_blocks), padding, dtype=np.int32)
index_offset_size = np.zeros(shape=(num_q_blocks), dtype=np.int32)
for i in range(num_q_blocks):
k = 0
for j in range(num_kv_blocks):
if block_mask_map[i, j]:
index_offset[i, k] = j
k += 1
index_offset_size[i] = k
return KVOffsetInfo(
kv_block_offset=jnp.asarray(index_offset),
kv_block_offset_size=jnp.asarray(index_offset_size),
)
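A tiny worked example of `query_iterator_indices` on a hand-written block map (the padding value of 3, equal to the number of KV blocks here, is an illustrative choice; the commented outputs follow directly from the loop above):

import numpy as np

from axlearn.common.flash_attention.common import query_iterator_indices

# Two query blocks by three KV blocks: query block 0 only needs KV block 0,
# query block 1 needs KV blocks 0 and 2.
block_mask_map = np.array(
    [
        [True, False, False],
        [True, False, True],
    ]
)
offset_info = query_iterator_indices(block_mask_map, padding=3)
# offset_info.kv_block_offset      == [[0, 3, 3], [0, 2, 3]]  (3 marks padded entries)
# offset_info.kv_block_offset_size == [1, 2]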
148 changes: 148 additions & 0 deletions axlearn/common/flash_attention/decoding_test.py
@@ -0,0 +1,148 @@
# Copyright © 2025 Apple Inc.
"""Tests GPU and TPU decoding."""
from contextlib import nullcontext
from typing import Literal

import jax
import jax.numpy as jnp
import pytest
from absl.testing import parameterized

from axlearn.common.attention_bias import sliding_window_causal_mask
from axlearn.common.flash_attention.gpu_decoding import NEG_INF
from axlearn.common.flash_attention.gpu_decoding import flash_decoding as gpu_decoding
from axlearn.common.flash_attention.tpu_decoding import tpu_decoding
from axlearn.common.flash_attention.utils import mha_reference
from axlearn.common.test_utils import TestCase, Tolerance

if jax.default_backend() == "gpu":
decoding_fns = [gpu_decoding]
dtypes = [jnp.float32, jnp.float16]
elif jax.default_backend() == "tpu":
decoding_fns = [tpu_decoding]
dtypes = [jnp.bfloat16]
elif jax.default_backend() == "cpu":
# CPU emulation of pallas kernels.
decoding_fns = [gpu_decoding, tpu_decoding]
dtypes = [jnp.float32]
else:
pytest.skip(reason="Incompatible hardware", allow_module_level=True)


class DecodingTest(TestCase):
"""Tests GPU and TPU decoding."""

@parameterized.product(
[
dict(zip(["batch_size", "seq_len", "num_heads", "per_head_dim"], args))
for args in [
(1, 1024, 32, 64),
(4, 512, 48, 64),
(2, 1024, 16, 128),
(1, 4096, 8, 128),
(2, 734, 48, 64),
]
],
attention_bias_type=[None],
input_dtype=dtypes,
padding=[0, 111],
kv_head_factor=[1, 4, 8],
window_len=[-1, 16, 127],
decoding_fn=decoding_fns,
)
def test_decode_against_ref(
self,
batch_size: int,
seq_len: int,
num_heads: int,
per_head_dim: int,
attention_bias_type: Literal["2d", "4d", None],
input_dtype: jnp.dtype,
padding: int,
kv_head_factor: int,
window_len: int,
decoding_fn,
):
if seq_len % 512 != 0 and decoding_fn is tpu_decoding:
self.skipTest("TPU decoding doesn't support seq_len % block_size != 0")
self.assertEqual(num_heads % kv_head_factor, 0)
assert num_heads % kv_head_factor == 0
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4)
q = jax.random.normal(k1, (batch_size, 1, num_heads, per_head_dim), dtype=input_dtype)
k = jax.random.normal(
k2,
(batch_size, seq_len, num_heads // kv_head_factor, per_head_dim),
dtype=input_dtype,
)
v = jax.random.normal(
k3,
(batch_size, seq_len, num_heads // kv_head_factor, per_head_dim),
dtype=input_dtype,
)

if attention_bias_type == "4d":
bias = jax.random.normal(k4, (batch_size, num_heads, 1, seq_len), dtype=input_dtype)
elif attention_bias_type == "2d":
bias = jax.random.normal(k4, (1, 1, 1, seq_len), dtype=input_dtype)
else:
bias = None

softmax_scale = per_head_dim**0.5
mask_fn = None
if window_len > 0:
mask_fn = sliding_window_causal_mask(window_len)
o = decoding_fn(
q,
k,
v,
bias=bias,
softmax_scale=softmax_scale,
kv_seq_len=seq_len - padding,
mask_fn=mask_fn,
interpret=(jax.default_backend() == "cpu"),
)
if bias is not None:
bias = bias[:, :, :, : seq_len - padding]
if window_len > 0:
if bias is None:
bias = jnp.zeros((1, 1, 1, seq_len - padding), dtype=input_dtype)
bias = bias.at[:, :, :, : -window_len - 1].set(NEG_INF)
with jax.default_matmul_precision(
"highest"
) if input_dtype is jnp.float32 else nullcontext():
o_ref = mha_reference(
q,
k[:, : seq_len - padding],
v[:, : seq_len - padding],
bias,
None,
causal=False,
softmax_scale=softmax_scale,
)
if input_dtype is jnp.float32:
self.assertNestedAllClose(o, o_ref, rtol=0.001, atol=0.0005)
# bfloat16 and float16 have occasional outliers that require relaxed tolerances.
elif input_dtype is jnp.bfloat16:
self.assertAllCloseWithOutliers(
o,
o_ref,
tolerance_map={
1.0: Tolerance(rtol=0.05, atol=1.25),
0.99: Tolerance(rtol=0.05, atol=0.4),
0.95: Tolerance(rtol=0.05, atol=0.2),
0.9: Tolerance(rtol=0.05, atol=0.1),
0.8: Tolerance(rtol=0.05, atol=0.05),
},
)
elif input_dtype is jnp.float16:
self.assertAllCloseWithOutliers(
o,
o_ref,
tolerance_map={
1.0: Tolerance(rtol=0.01, atol=0.2),
0.98: Tolerance(rtol=0.01, atol=0.05),
0.9: Tolerance(rtol=0.01, atol=0.025),
},
)
else:
raise ValueError(f"Unsupported dtype {input_dtype}")
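For reference, a minimal single-token decoding call mirroring the keyword arguments exercised by `test_decode_against_ref` above (the shapes, softmax scale, valid KV length, and window size are illustrative only):

import jax
import jax.numpy as jnp

from axlearn.common.attention_bias import sliding_window_causal_mask
from axlearn.common.flash_attention.tpu_decoding import tpu_decoding

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
# One query token attending to a 1024-slot KV cache; 8 query heads share 2 KV heads.
q = jax.random.normal(k1, (1, 1, 8, 128), dtype=jnp.bfloat16)
k = jax.random.normal(k2, (1, 1024, 2, 128), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (1, 1024, 2, 128), dtype=jnp.bfloat16)

out = tpu_decoding(
    q,
    k,
    v,
    bias=None,
    softmax_scale=128**-0.5,
    kv_seq_len=1000,  # Only the first 1000 cache slots hold valid tokens.
    mask_fn=sliding_window_causal_mask(256),
    interpret=(jax.default_backend() == "cpu"),  # Emulate the Pallas kernel off-TPU.
)
print(out.shape)  # Expected: (1, 1, 8, 128), matching q.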