From 9192e6de683cd64d80437f567441042d0858a11a Mon Sep 17 00:00:00 2001 From: Hanzhi Zhou Date: Fri, 7 Mar 2025 14:33:28 -0800 Subject: [PATCH] Implements TPU decoding in Pallas (#1039) * Implements TPU decoding in Pallas * Add support for shorter kv len --- axlearn/common/attention.py | 50 ++- axlearn/common/flash_attention/common.py | 88 ++++++ .../common/flash_attention/decoding_test.py | 148 +++++++++ .../common/flash_attention/gpu_attention.py | 63 +--- .../gpu_attention_benchmark.py | 15 +- .../flash_attention/gpu_attention_test.py | 100 +----- .../common/flash_attention/gpu_decoding.py | 9 +- axlearn/common/flash_attention/layer_test.py | 47 ++- .../tpu_attention_benchmark.py | 108 +++---- .../common/flash_attention/tpu_decoding.py | 284 ++++++++++++++++++ axlearn/common/flash_attention/utils.py | 75 +++-- axlearn/common/test_utils.py | 46 ++- 12 files changed, 747 insertions(+), 286 deletions(-) create mode 100644 axlearn/common/flash_attention/common.py create mode 100644 axlearn/common/flash_attention/decoding_test.py create mode 100644 axlearn/common/flash_attention/tpu_decoding.py diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index f05ac1a70..5abbe0748 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -2090,6 +2090,44 @@ def default_key_scale_config() -> InstantiableConfig[ScaleFn]: return config_for_function(constant_scale_fn).set(value=1) +def compute_gqa_logits(q_proj: Tensor, k_proj: Tensor) -> Tensor: + """Compute attention logits. + + Args: + q_proj: query tensor, [batch, target_length, num_heads, per_head_dim]. + k_proj: key tensor, [batch, source_length, num_kv_heads, per_head_dim]. + + Returns: + logits: [batch, num_heads, target_length, source_length]. + """ + kv_heads = k_proj.shape[2] + num_head_group = q_proj.shape[2] // kv_heads + assert q_proj.shape[2] % kv_heads == 0 + q_proj = einops.rearrange(q_proj, "b t (k g) h -> b t k g h", k=kv_heads, g=num_head_group) + k_proj = einops.rearrange(k_proj, "b s k h -> b s k 1 h") + logits = jnp.einsum("btkgh,bsk1h->bkgts", q_proj, k_proj) + return einops.rearrange(logits, "b k g t s -> b (k g) t s") + + +def compute_gqa_context(probs: Tensor, v_proj: Tensor) -> Tensor: + """Compute attention context. + + Args: + probs: probs tensor, [batch, num_heads, target_length, source_length]. + v_proj: value tensor, [batch, source_length, num_kv_heads, per_head_dim]. + + Returns: + context: [batch, target_length, num_heads, per_head_dim]. + """ + kv_heads = v_proj.shape[2] + num_head_group = probs.shape[1] // kv_heads + assert probs.shape[1] % kv_heads == 0 + probs = einops.rearrange(probs, "b (k g) t s -> b k g t s", k=kv_heads, g=num_head_group) + v_proj = einops.rearrange(v_proj, "b s k h -> b s k 1 h") + context = jnp.einsum("bkgts,bsk1h->btkgh", probs, v_proj) + return einops.rearrange(context, "b t k g h -> b t (k g) h") + + class GroupedQueryAttention(MultiheadAttention): """A Grouped-Query Attention (GQA) layer. 
@@ -2128,12 +2166,7 @@ def _compute_logits(self, q_proj: Tensor, k_proj: Tensor) -> Tensor: if num_head_group == 1: return super()._compute_logits(q_proj=q_proj, k_proj=k_proj) - q_proj = self.scale_query(q_proj) - k_proj = self.scale_key(k_proj) - q_proj = einops.rearrange(q_proj, "b t (k g) h -> b t k g h", k=kv_heads, g=num_head_group) - k_proj = einops.rearrange(k_proj, "b s k h -> b s k 1 h") - logits = jnp.einsum("btkgh,bsk1h->bkgts", q_proj, k_proj) - return einops.rearrange(logits, "b k g t s -> b (k g) t s") + return compute_gqa_logits(self.scale_query(q_proj), self.scale_key(k_proj)) def _compute_context(self, probs: Tensor, v_proj: Tensor) -> Tensor: """Compute attention context. @@ -2150,10 +2183,7 @@ def _compute_context(self, probs: Tensor, v_proj: Tensor) -> Tensor: if num_head_group == 1: return super()._compute_context(probs=probs, v_proj=v_proj) - probs = einops.rearrange(probs, "b (k g) t s -> b k g t s", k=kv_heads, g=num_head_group) - v_proj = einops.rearrange(v_proj, "b s k h -> b s k 1 h") - context = jnp.einsum("bkgts,bsk1h->btkgh", probs, v_proj) - return einops.rearrange(context, "b t k g h -> b t (k g) h") + return compute_gqa_context(probs, v_proj) class SigmoidAttention(MultiheadAttention): diff --git a/axlearn/common/flash_attention/common.py b/axlearn/common/flash_attention/common.py new file mode 100644 index 000000000..46d1a5ba0 --- /dev/null +++ b/axlearn/common/flash_attention/common.py @@ -0,0 +1,88 @@ +# Copyright © 2025 Apple Inc. +"""Common utilities across backends.""" + +from typing import NamedTuple + +import jax +import jax.numpy as jnp +import numpy as np +from jax.experimental import pallas as pl + +from axlearn.common.attention_bias import MaskFn +from axlearn.common.utils import Tensor + + +def build_mask( + mask_fn: MaskFn, *, q_seq_len: int, kv_seq_len: int, block_q: int, block_k: int +) -> np.ndarray: + """Builds the block map where True means the block is not fully masked. + + Args: + mask_fn: The attention mask function. + q_seq_len: Query sequence length. + kv_seq_len: Key/Value sequence length. + block_q: Query block size. + block_k: Key/Value block size. + + Returns: + A boolean array of shape (num_q_blocks, num_kv_blocks) where True means the block is not + fully masked. num_q_blocks * block_q will be larger than q_seq_len if q_seq_len is not + divisible by block_q. The same holds true for kv blocks. + """ + # Initialize the iteration map where True means the block is not empty. + num_q_blocks = pl.cdiv(q_seq_len, block_q) + num_kv_blocks = pl.cdiv(kv_seq_len, block_k) + block_mask_map = np.ones(shape=(num_q_blocks, num_kv_blocks), dtype=np.bool_) + # # Initialize the scan begin and end indices. + rows = np.arange(q_seq_len, dtype=np.int32) + cols = np.arange(kv_seq_len, dtype=np.int32) + # Run a compile-time evaluation to get the mask array. + # TODO(kelvin-zou): use a block-wise mask function to avoid the compile-time + # high memory usage. + with jax.ensure_compile_time_eval(): + mask_array = np.asarray(mask_fn(rows[:, None], cols[None, :])) + for i in range(0, q_seq_len, block_q): + for j in range(0, kv_seq_len, block_k): + # Extract the block + block = mask_array[i : i + block_q, j : j + block_k] + # All empty means skipping + if not block.any(): + block_mask_map[i // block_q, j // block_k] = False + return block_mask_map + + +class KVOffsetInfo(NamedTuple): + """Records the block index of non-empty KV blocks. 
+ + Attributes: + kv_block_offset: A (num_q_blocks, num_kv_blocks) tensor where `kv_block_offset[i][j]` + stores the index of the jth non-empty KV block index for the ith query block. + This tensor may be padded at the end. + kv_block_offset_size: A (num_q_blocks,) tensor that stores the number of valid entries + for each row of `kv_block_offset`, i.e. the number of entries before padding. + """ + + kv_block_offset: Tensor + kv_block_offset_size: Tensor + + +def query_iterator_indices(block_mask_map: np.ndarray, *, padding: int = 0) -> KVOffsetInfo: + """Builds `KVOffsetInfo` for block-sparse attention computation in the forward pass. + + Returns: + A `KVOffsetInfo`. See the attributes of `KVOffsetInfo` for more info. + """ + num_q_blocks, num_kv_blocks = block_mask_map.shape + index_offset = np.full((num_q_blocks, num_kv_blocks), padding, dtype=np.int32) + index_offset_size = np.zeros(shape=(num_q_blocks), dtype=np.int32) + for i in range(num_q_blocks): + k = 0 + for j in range(num_kv_blocks): + if block_mask_map[i, j]: + index_offset[i, k] = j + k += 1 + index_offset_size[i] = k + return KVOffsetInfo( + kv_block_offset=jnp.asarray(index_offset), + kv_block_offset_size=jnp.asarray(index_offset_size), + ) diff --git a/axlearn/common/flash_attention/decoding_test.py b/axlearn/common/flash_attention/decoding_test.py new file mode 100644 index 000000000..8db6ebfbb --- /dev/null +++ b/axlearn/common/flash_attention/decoding_test.py @@ -0,0 +1,148 @@ +# Copyright © 2025 Apple Inc. +"""Tests GPU and TPU decoding.""" +from contextlib import nullcontext +from typing import Literal + +import jax +import jax.numpy as jnp +import pytest +from absl.testing import parameterized + +from axlearn.common.attention_bias import sliding_window_causal_mask +from axlearn.common.flash_attention.gpu_decoding import NEG_INF +from axlearn.common.flash_attention.gpu_decoding import flash_decoding as gpu_decoding +from axlearn.common.flash_attention.tpu_decoding import tpu_decoding +from axlearn.common.flash_attention.utils import mha_reference +from axlearn.common.test_utils import TestCase, Tolerance + +if jax.default_backend() == "gpu": + decoding_fns = [gpu_decoding] + dtypes = [jnp.float32, jnp.float16] +elif jax.default_backend() == "tpu": + decoding_fns = [tpu_decoding] + dtypes = [jnp.bfloat16] +elif jax.default_backend() == "cpu": + # CPU emulation of pallas kernels. 
+ decoding_fns = [gpu_decoding, tpu_decoding] + dtypes = [jnp.float32] +else: + pytest.skip(reason="Incompatible hardware", allow_module_level=True) + + +class DecodingTest(TestCase): + """Tests GPU and TPU decoding.""" + + @parameterized.product( + [ + dict(zip(["batch_size", "seq_len", "num_heads", "per_head_dim"], args)) + for args in [ + (1, 1024, 32, 64), + (4, 512, 48, 64), + (2, 1024, 16, 128), + (1, 4096, 8, 128), + (2, 734, 48, 64), + ] + ], + attention_bias_type=[None], + input_dtype=dtypes, + padding=[0, 111], + kv_head_factor=[1, 4, 8], + window_len=[-1, 16, 127], + decoding_fn=decoding_fns, + ) + def test_decode_against_ref( + self, + batch_size: int, + seq_len: int, + num_heads: int, + per_head_dim: int, + attention_bias_type: Literal["2d", "4d", None], + input_dtype: jnp.dtype, + padding: int, + kv_head_factor: int, + window_len: int, + decoding_fn, + ): + if seq_len % 512 != 0 and decoding_fn is tpu_decoding: + self.skipTest("TPU decoding doesn't support seq_len % block_size != 0") + self.assertEqual(num_heads % kv_head_factor, 0) + assert num_heads % kv_head_factor == 0 + k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4) + q = jax.random.normal(k1, (batch_size, 1, num_heads, per_head_dim), dtype=input_dtype) + k = jax.random.normal( + k2, + (batch_size, seq_len, num_heads // kv_head_factor, per_head_dim), + dtype=input_dtype, + ) + v = jax.random.normal( + k3, + (batch_size, seq_len, num_heads // kv_head_factor, per_head_dim), + dtype=input_dtype, + ) + + if attention_bias_type == "4d": + bias = jax.random.normal(k4, (batch_size, num_heads, 1, seq_len), dtype=input_dtype) + elif attention_bias_type == "2d": + bias = jax.random.normal(k4, (1, 1, 1, seq_len), dtype=input_dtype) + else: + bias = None + + softmax_scale = per_head_dim**0.5 + mask_fn = None + if window_len > 0: + mask_fn = sliding_window_causal_mask(window_len) + o = decoding_fn( + q, + k, + v, + bias=bias, + softmax_scale=softmax_scale, + kv_seq_len=seq_len - padding, + mask_fn=mask_fn, + interpret=(jax.default_backend() == "cpu"), + ) + if bias is not None: + bias = bias[:, :, :, : seq_len - padding] + if window_len > 0: + if bias is None: + bias = jnp.zeros((1, 1, 1, seq_len - padding), dtype=input_dtype) + bias = bias.at[:, :, :, : -window_len - 1].set(NEG_INF) + with jax.default_matmul_precision( + "highest" + ) if input_dtype is jnp.float32 else nullcontext(): + o_ref = mha_reference( + q, + k[:, : seq_len - padding], + v[:, : seq_len - padding], + bias, + None, + causal=False, + softmax_scale=softmax_scale, + ) + if input_dtype is jnp.float32: + self.assertNestedAllClose(o, o_ref, rtol=0.001, atol=0.0005) + # bfloat16 and float16 have occasional outliers that require relaxed tolerances. 
+ elif input_dtype is jnp.bfloat16: + self.assertAllCloseWithOutliers( + o, + o_ref, + tolerance_map={ + 1.0: Tolerance(rtol=0.05, atol=1.25), + 0.99: Tolerance(rtol=0.05, atol=0.4), + 0.95: Tolerance(rtol=0.05, atol=0.2), + 0.9: Tolerance(rtol=0.05, atol=0.1), + 0.8: Tolerance(rtol=0.05, atol=0.05), + }, + ) + elif input_dtype is jnp.float16: + self.assertAllCloseWithOutliers( + o, + o_ref, + tolerance_map={ + 1.0: Tolerance(rtol=0.01, atol=0.2), + 0.98: Tolerance(rtol=0.01, atol=0.05), + 0.9: Tolerance(rtol=0.01, atol=0.025), + }, + ) + else: + raise ValueError(f"Unsupported dtype {input_dtype}") diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py index 720b7200a..1581c7eec 100644 --- a/axlearn/common/flash_attention/gpu_attention.py +++ b/axlearn/common/flash_attention/gpu_attention.py @@ -43,6 +43,7 @@ from jax.experimental import pallas as pl from axlearn.common.attention_bias import NEG_INF, MaskFn +from axlearn.common.flash_attention.common import build_mask, query_iterator_indices from axlearn.common.flash_attention.remat import FLASH_ATTN_RESIDUAL_NAME from axlearn.common.layers import get_dropout_mask from axlearn.common.utils import Tensor @@ -74,60 +75,6 @@ def _segment_mask( return jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) -def _build_mask( - mask_fn: MaskFn, *, q_seq_len: int, kv_seq_len: int, block_q: int, block_k: int -) -> np.ndarray: - """build the iteration map where True means the block is not empty. - - Returns: - A boolean array of shape (num_q_blocks, num_kv_blocks) where True means - the block is not empty. - """ - # Initialize the iteration map where True means the block is not empty. - num_q_blocks = pl.cdiv(q_seq_len, block_q) - num_kv_blocks = pl.cdiv(kv_seq_len, block_k) - block_mask_map = np.ones(shape=(num_q_blocks, num_kv_blocks), dtype=np.bool_) - # # Initialize the scan begin and end indices. - rows = np.arange(q_seq_len, dtype=np.int32) - cols = np.arange(kv_seq_len, dtype=np.int32) - # Run a compile-time evaluation to get the mask array. - # TODO(kelvin-zou): use a block-wise mask function to avoid the compile-time - # high memory usage. - with jax.ensure_compile_time_eval(): - mask_array = np.asarray(mask_fn(rows[:, None], cols[None, :])) - for i in range(0, q_seq_len, block_q): - for j in range(0, kv_seq_len, block_k): - # Extract the block - block = mask_array[i : i + block_q, j : j + block_k] - # All empty means skipping - if not block.any(): - block_mask_map[i // block_q, j // block_k] = False - return block_mask_map - - -def _query_iterator_indices(block_mask_map: np.ndarray) -> Tuple[Tensor, Tensor]: - """build the iteration begin/end indices for the query dimension. - - Returns: - Index_offset (num_q_blocks, num_kv_blocks) tensor where index_offset[i][j] - to store the first jth available block index for ith query block, and the unused - blocks are padded with 0 at the very end. - Index_offset_size ((num_q_blocks) tensor to store the number of valid blocks - for each iteration. 
- """ - num_q_blocks, num_kv_blocks = block_mask_map.shape - index_offset = np.zeros(shape=(num_q_blocks, num_kv_blocks), dtype=np.int32) - index_offset_size = np.zeros(shape=(num_q_blocks), dtype=np.int32) - for i in range(num_q_blocks): - k = 0 - for j in range(num_kv_blocks): - if block_mask_map[i, j]: - index_offset[i, k] = j - k += 1 - index_offset_size[i] = k - return jnp.asarray(index_offset), jnp.asarray(index_offset_size) - - def _key_value_iterator_indices(block_mask_map: np.ndarray) -> Tuple[Tensor, Tensor]: """build the iteration begin/end indices for the key/value dimension. @@ -419,10 +366,10 @@ def _flash_attention_impl( in_specs.append(None) index_offset = index_offset_spec = index_offset_size = index_offset_size_spec = None if mask_fn is not None: - block_mask_array = _build_mask( + block_mask_array = build_mask( mask_fn, q_seq_len=q_seq_len, kv_seq_len=kv_seq_len, block_q=block_q, block_k=block_k ) - index_offset, index_offset_size = _query_iterator_indices(block_mask_array) + index_offset, index_offset_size = query_iterator_indices(block_mask_array) num_kv_blocks = pl.cdiv(kv_seq_len, block_k) index_offset_spec = pl.BlockSpec( index_map=(lambda i, _, k: (i, 0)), block_shape=((None, num_kv_blocks)) @@ -705,11 +652,11 @@ def _mha_backward( q_index_offset = q_index_offset_spec = q_index_offset_size = q_index_offset_size_spec = None kv_index_offset = kv_index_offset_spec = kv_index_offset_size = kv_index_offset_size_spec = None if mask_fn is not None: - block_mask_array = _build_mask( + block_mask_array = build_mask( mask_fn, q_seq_len=q_seq_len, kv_seq_len=kv_seq_len, block_q=block_q, block_k=block_k ) # Compute the dynamic indices for the query for dq. - q_index_offset, q_index_offset_size = _query_iterator_indices(block_mask_array) + q_index_offset, q_index_offset_size = query_iterator_indices(block_mask_array) q_index_offset_spec = pl.BlockSpec( index_map=(lambda i, _, k: (k, 0)), block_shape=((None, num_kv_blocks)) ) diff --git a/axlearn/common/flash_attention/gpu_attention_benchmark.py b/axlearn/common/flash_attention/gpu_attention_benchmark.py index 6178d8d03..44af0ad9b 100644 --- a/axlearn/common/flash_attention/gpu_attention_benchmark.py +++ b/axlearn/common/flash_attention/gpu_attention_benchmark.py @@ -234,13 +234,18 @@ def bench_flash_attention( ) # Bias is not supported in pallas, so we don't include it here. bias = None - if sw_sz != -1 and not is_decode: + if sw_sz != -1: mask_fn = sliding_window_causal_mask(sw_sz) # We convert mask into a bias tensor for jax and cudnn. 
assert bias is None - mask = mask_fn(jnp.arange(seq_len)[:, None], jnp.arange(seq_len)[None, :]) - bias = jnp.zeros((1, 1, seq_len, seq_len), dtype=jnp.float16) - bias = jnp.where(mask, bias, NEG_INF) + if not is_decode: + bias = jnp.zeros((1, 1, seq_len, seq_len), dtype=jnp.float16) + bias = jnp.where( + mask_fn(jnp.arange(seq_len)[:, None], jnp.arange(seq_len)[None, :]), bias, NEG_INF + ) + else: + bias = jnp.zeros((1, 1, 1, seq_len), dtype=jnp.float16) + bias = bias.at[:, :, :, :-sw_sz].set(NEG_INF) else: mask_fn = causal_mask if "axlearn" in library: @@ -287,8 +292,6 @@ def cudnn_fn(q, k, v, bias): else: fn = partial(cudnn_dot_product_attention, causal=not is_decode) else: - k = k.repeat(num_heads // num_kv_heads, axis=2) - v = v.repeat(num_heads // num_kv_heads, axis=2) args = (q, k, v, bias) if use_bwd: diff --git a/axlearn/common/flash_attention/gpu_attention_test.py b/axlearn/common/flash_attention/gpu_attention_test.py index 5b1262d5c..d393d1f7a 100644 --- a/axlearn/common/flash_attention/gpu_attention_test.py +++ b/axlearn/common/flash_attention/gpu_attention_test.py @@ -17,16 +17,14 @@ import jax import jax.numpy as jnp import pytest -from absl.testing import parameterized from axlearn.common.attention_bias import causal_mask, sliding_window_causal_mask from axlearn.common.flash_attention.gpu_attention import ( cudnn_dot_product_attention, flash_attention, ) -from axlearn.common.flash_attention.gpu_decoding import NEG_INF, flash_decoding -from axlearn.common.flash_attention.utils import _repeat_kv_heads, mha_reference -from axlearn.common.test_utils import TestCase +from axlearn.common.flash_attention.gpu_decoding import NEG_INF +from axlearn.common.flash_attention.utils import mha_reference if jax.default_backend() not in ("gpu", "cpu"): pytest.skip(reason="Incompatible hardware", allow_module_level=True) @@ -140,100 +138,6 @@ def call_flash(q, k, v, bias, segment_ids, k5): chex.assert_trees_all_close(o, o_ref, atol=0.03) -class FlashDecodingTest(TestCase): - """Tests FlashDecoding.""" - - @parameterized.product( - [ - dict(zip(["batch_size", "seq_len", "num_heads", "per_head_dim"], args)) - for args in [ - (1, 1024, 32, 64), - (1, 444, 16, 64), - (8, 1596, 48, 128), - (8, 4044, 64, 128), - ] - ], - softmax_scale=[1.0, 0.83], - attention_bias_type=["2d", "4d", None], - input_dtype=[jnp.float32, jnp.float16], - padding=[0, 111], - kv_head_factor=[1, 4, 8], - window_len=[-1, 16, 127], - ) - def test_decode_against_ref( - self, - batch_size: int, - seq_len: int, - num_heads: int, - per_head_dim: int, - softmax_scale: float, - attention_bias_type: Literal["2d", "4d", None], - input_dtype: jnp.dtype, - padding: int, - kv_head_factor: int, - window_len: int, - ): - if jax.default_backend() == "cpu" and seq_len >= 512: - pytest.skip(reason="Too slow on CPU.") - self.assertEqual(num_heads % kv_head_factor, 0) - assert num_heads % kv_head_factor == 0 - k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4) - q = jax.random.normal(k1, (batch_size, 1, num_heads, per_head_dim), dtype=input_dtype) - k = jax.random.normal( - k2, - (batch_size, seq_len + padding, num_heads // kv_head_factor, per_head_dim), - dtype=input_dtype, - ) - v = jax.random.normal( - k3, - (batch_size, seq_len + padding, num_heads // kv_head_factor, per_head_dim), - dtype=input_dtype, - ) - - if attention_bias_type == "4d": - bias = jax.random.normal( - k4, (batch_size, num_heads, 1, seq_len + padding), dtype=input_dtype - ) - elif attention_bias_type == "2d": - bias = jax.random.normal(k4, (1, 1, 1, seq_len + 
padding), dtype=input_dtype) - else: - bias = None - - mask_fn = None - if window_len > 0: - mask_fn = sliding_window_causal_mask(window_len) - o = flash_decoding( - q, - k, - v, - bias=bias, - softmax_scale=softmax_scale, - kv_seq_len=seq_len, - mask_fn=mask_fn, - interpret=(jax.default_backend() == "cpu"), - ) - if bias is not None: - bias = bias[:, :, :, :seq_len] - if window_len > 0: - if bias is None: - bias = jnp.zeros((1, 1, 1, seq_len), dtype=input_dtype) - bias = bias.at[:, :, :, : -window_len - 1].set(NEG_INF) - o_ref = mha_reference( - q, - _repeat_kv_heads(num_heads, k[:, :seq_len]), - _repeat_kv_heads(num_heads, v[:, :seq_len]), - bias, - None, - causal=False, - softmax_scale=softmax_scale, - ) - self.assertGreaterEqual(jnp.median(jnp.abs(o_ref)).item(), 0.25) - if input_dtype is jnp.float32: - self.assertNestedAllClose(o, o_ref, rtol=0.01, atol=0.01) - else: - self.assertNestedAllClose(o, o_ref, rtol=0.05, atol=0.05) - - @pytest.mark.parametrize( "batch_size,num_heads,seq_len,per_head_dim", [ diff --git a/axlearn/common/flash_attention/gpu_decoding.py b/axlearn/common/flash_attention/gpu_decoding.py index c8beee870..557c00b51 100644 --- a/axlearn/common/flash_attention/gpu_decoding.py +++ b/axlearn/common/flash_attention/gpu_decoding.py @@ -75,6 +75,11 @@ def _attn_forward_kernel( ): _, head_dim = q_ref.shape split_k_seq_len, _ = k_ref.shape + precision = ( + lax.Precision.HIGHEST + if jnp.float32 in (q_ref.dtype, k_ref.dtype, v_ref.dtype) + else lax.Precision.DEFAULT + ) prog_i, prog_j = pl.program_id(1), pl.program_id(2) q_mask = (block_h * prog_i + jnp.arange(block_h) < qhead_per_kvhead)[:, None] @@ -98,7 +103,7 @@ def body(start_k, carry): def compute(): curr_k_slice = pl.ds(start_k * block_k, block_k) k = pl.load(k_ref, (curr_k_slice, slice(None)), mask=mask[:, None], other=0.0) - qk = pl.dot(q, k.T, allow_tf32=False) # [block_h, block_k] + qk = pl.dot(q, k.T, precision=precision) # [block_h, block_k] if bias_ref is not None: qk += pl.load( bias_ref, (slice(None), curr_k_slice), mask=mask[None, :], other=0.0 @@ -115,7 +120,7 @@ def compute(): l_curr = s_curr.sum(axis=-1) l_next = l_prev_corr + l_curr v = pl.load(v_ref, (curr_k_slice, slice(None)), mask=mask[:, None], other=0.0) - o_curr = pl.dot(s_curr.astype(v.dtype), v, allow_tf32=False) + o_curr = pl.dot(s_curr.astype(v.dtype), v, precision=precision) # Flash2 unscaled_o. o_next = correction[:, None] * o_prev + o_curr diff --git a/axlearn/common/flash_attention/layer_test.py b/axlearn/common/flash_attention/layer_test.py index cf2d66f31..7d0a8fdaf 100644 --- a/axlearn/common/flash_attention/layer_test.py +++ b/axlearn/common/flash_attention/layer_test.py @@ -18,6 +18,8 @@ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" # pylint: enable=wrong-import-position +from functools import partial + import jax import jax.numpy as jnp import pytest @@ -750,15 +752,6 @@ def test_extend_step( mask=mask, inference=True, ) - tpu_block_size = test_layer.config.tpu_block_size - # pylint: disable-next=protected-access - if test_layer._backend() == "tpu" and seq_len % tpu_block_size != 0: - pytest.skip( - f"Sequence length {seq_len} is not divisible by configured block size for " - f"tpu {test_layer.config.tpu_block_size }. " - "This was unsupported (and the test failed) even prior to adding " - "this skip statement." 
- ) # Prepare inputs query = jax.random.normal( @@ -774,7 +767,7 @@ def test_extend_step( dtype=dtype, ) kv_state = None - return_aux = {"probs"} + return_aux = None inputs = dict( query=query, @@ -839,6 +832,18 @@ def test_extend_step( decoder_output = jnp.zeros(shape=[seq_len, batch, hidden_dim]).astype(dtype) ref_decoder_output = jnp.zeros(shape=[seq_len, batch, hidden_dim]).astype(dtype) + + @partial(jax.jit, static_argnames=["layer"]) + def extend_one_step(params, inputs, layer): + return F( + layer, + state=params, + is_training=False, + prng_key=jax.random.PRNGKey(5), + inputs=inputs, + method="extend_step", + ) + for t in range(seq_len): cur_query = jnp.expand_dims(query[:, t, :], axis=1) inputs["query"] = cur_query @@ -851,27 +856,13 @@ def test_extend_step( ref_inputs["attention_logit_biases"] = jnp.expand_dims( causal_bias[:, :, t, :], axis=2 ) - ref_extend_step_outputs, _ = F( - ref_layer, - state=params, - is_training=False, - prng_key=jax.random.PRNGKey(5), - inputs=ref_inputs, - method="extend_step", - ) + ref_extend_step_outputs, _ = extend_one_step(params, ref_inputs, ref_layer) ref_inputs["cached_states"] = ref_extend_step_outputs[0] ref_decoder_output = ref_decoder_output.at[t].set( jnp.squeeze(ref_extend_step_outputs[1].data, axis=1) ) - extend_step_outputs, _ = F( - test_layer, - state=params, - is_training=False, - prng_key=jax.random.PRNGKey(5), - inputs=inputs, - method="extend_step", - ) + extend_step_outputs, _ = extend_one_step(params, inputs, test_layer) inputs["cached_states"] = extend_step_outputs[0] decoder_output = decoder_output.at[t].set( jnp.squeeze(extend_step_outputs[1].data, axis=1) @@ -880,6 +871,7 @@ def test_extend_step( self.assertNestedAllClose( decoder_output[t], ref_decoder_output[t], + rtol=0.02, atol=2e-2, ) @@ -889,16 +881,19 @@ def test_extend_step( self.assertNestedAllClose( ref_out.data, ref_decoder_out_transposed, + rtol=0.02, atol=2e-2, ) self.assertNestedAllClose( decoder_out_transposed, ref_decoder_out_transposed, + rtol=0.02, atol=2e-2, ) self.assertNestedAllClose( ref_out.data, test_out.data, + rtol=0.02, atol=2e-2, ) jax.clear_caches() diff --git a/axlearn/common/flash_attention/tpu_attention_benchmark.py b/axlearn/common/flash_attention/tpu_attention_benchmark.py index 5a9dbc106..76161e3a6 100644 --- a/axlearn/common/flash_attention/tpu_attention_benchmark.py +++ b/axlearn/common/flash_attention/tpu_attention_benchmark.py @@ -39,7 +39,7 @@ TensorAttentionBias, sliding_window_causal_mask, ) -from axlearn.common.flash_attention.utils import flash_attention_implementation, mha_reference +from axlearn.common.flash_attention.utils import flash_attention_implementation _BENCHMARK_CONFIGS = { "1.2b": dict( @@ -51,7 +51,8 @@ per_head_dim=128, ), "29.6b": dict( - num_heads=56, + num_heads=8, + num_kv_heads=1, per_head_dim=128, ), "65.2b": dict( @@ -91,75 +92,77 @@ def _benchmark( block_size: int, num_heads: int, per_head_dim: int, + num_kv_heads: Optional[int] = None, + is_decoding: bool = False, causal: bool = True, use_bias: bool = False, - use_segment_ids: bool = False, sliding_window_size: Optional[int] = None, ): """Benchmarks TPU FlashAttention vs reference impl.""" - k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5) - q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16) - k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16) - v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16) - - bias = None - if 
use_bias: - bias = jax.random.normal(k4, (batch_size, num_heads, seq_len, seq_len), dtype=jnp.bfloat16) - segment_ids = None - if use_segment_ids: - segment_ids = jnp.cumsum( - jax.random.bernoulli(k5, shape=(batch_size, seq_len)).astype(jnp.int32), axis=1 - ) + k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4) + if num_kv_heads is None: + num_kv_heads = num_heads + q_seq_len = 1 if is_decoding else seq_len + q = jax.random.normal(k1, (batch_size, q_seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16) + k = jax.random.normal(k2, (batch_size, seq_len, num_kv_heads, per_head_dim), dtype=jnp.bfloat16) + v = jax.random.normal(k3, (batch_size, seq_len, num_kv_heads, per_head_dim), dtype=jnp.bfloat16) softmax_scale = per_head_dim**-0.5 - ref_fwd_time = _time_call( - lambda: mha_reference( - q, k, v, bias, segment_ids, causal=causal, softmax_scale=softmax_scale - ) - ) - - grad_fn = jax.jit( - jax.grad( - lambda q, k, v, b, s: mha_reference( - q, k, v, b, s, causal=causal, softmax_scale=softmax_scale - ).mean(), - argnums=(0, 1, 2), - ) - ) - ref_bwd_time = _time_call(lambda: grad_fn(q, k, v, bias, segment_ids)[0]) - - mask = None + mask = [] + if is_decoding: + target_positions = jnp.asarray([seq_len - 1])[None] + else: + target_positions = jnp.arange(seq_len)[None] if causal and sliding_window_size is None: - mask = CausalAttentionBias( - target_positions=jnp.arange(seq_len)[None], - source_positions=jnp.arange(seq_len)[None], + mask.append( + CausalAttentionBias( + target_positions=target_positions, + source_positions=jnp.arange(seq_len)[None], + ) ) elif causal: - mask = SlidingWindowAttentionBias( - sliding_window_causal_mask(sliding_window_size), - sliding_window_size=sliding_window_size, - target_positions=jnp.arange(seq_len)[None], - source_positions=jnp.arange(seq_len)[None], + mask.append( + SlidingWindowAttentionBias( + sliding_window_causal_mask(sliding_window_size), + sliding_window_size=sliding_window_size, + target_positions=target_positions, + source_positions=jnp.arange(seq_len)[None], + ) ) if use_bias: - bias = CompositeAttentionBias([mask, TensorAttentionBias(bias)]) - else: - bias = CompositeAttentionBias([mask]) + mask.append( + TensorAttentionBias( + jax.random.normal( + k4, (batch_size, num_heads, q_seq_len, seq_len), dtype=jnp.bfloat16 + ) + ) + ) + bias = CompositeAttentionBias(mask) # Get fwd & bwd timing information when softmax scaling applied before calling the kernel. 
+    ref_mha_impl = flash_attention_implementation(
+        "xla", softmax_scale=softmax_scale, block_size=block_size, is_decoding=is_decoding
+    )
     mha_impl = flash_attention_implementation(
-        "tpu", softmax_scale=softmax_scale, block_size=block_size
+        "tpu", softmax_scale=softmax_scale, block_size=block_size, is_decoding=is_decoding
     )
+    ref_fwd_time = _time_call(lambda: ref_mha_impl(q, k, v, bias))
     flash_fwd_time = _time_call(lambda: mha_impl(q, k, v, bias))
-    flash_grad_fn = jax.jit(
-        jax.grad(lambda q, k, v, b: mha_impl(q, k, v, b).mean(), argnums=(0, 1, 2))
-    )
-    flash_bwd_time = _time_call(lambda: flash_grad_fn(q, k, v, bias)[0])
+    if not is_decoding:
+        flash_grad_fn = jax.jit(
+            jax.grad(lambda q, k, v, b: ref_mha_impl(q, k, v, b).mean(), argnums=(0, 1, 2))
+        )
+        ref_bwd_time = _time_call(lambda: flash_grad_fn(q, k, v, bias)[0])
+        flash_grad_fn = jax.jit(
+            jax.grad(lambda q, k, v, b: mha_impl(q, k, v, b).mean(), argnums=(0, 1, 2))
+        )
+        flash_bwd_time = _time_call(lambda: flash_grad_fn(q, k, v, bias)[0])
 
     print(f"ref_fwd:{ref_fwd_time:.4f}s, flash_fwd:{flash_fwd_time:.4f}s")
-    print(f"ref_bwd:{ref_bwd_time:.4f}s, flash_bwd:{flash_bwd_time:.4f}s\n")
+    if not is_decoding:
+        print(f"ref_bwd:{ref_bwd_time:.4f}s, flash_bwd:{flash_bwd_time:.4f}s\n")
 
 
 if __name__ == "__main__":
@@ -169,8 +172,9 @@ def _benchmark(
         print(f"Benchmarking attention representative of {name} model layer on {device_kind}.")
         _benchmark(
             batch_size=2,
-            seq_len=1024 * 8,
+            seq_len=1024 * 128,
             block_size=4 * 128,
-            sliding_window_size=1024,
+            sliding_window_size=4096,
+            is_decoding=True,
             **cfg,
         )
diff --git a/axlearn/common/flash_attention/tpu_decoding.py b/axlearn/common/flash_attention/tpu_decoding.py
new file mode 100644
index 000000000..8d5f1c59a
--- /dev/null
+++ b/axlearn/common/flash_attention/tpu_decoding.py
@@ -0,0 +1,284 @@
+# Copyright © 2025 Apple Inc.
+"""Implements TPU decoding.
+
+Unlike GPU, TPU blocks are sequential (except when there are two cores). Therefore, unlike GPU
+decoding, there is no need to parallelize over the KV sequence length. As a result, it works
+very similarly to full attention. The grid dimensions are
+(batch_size, num_kv_heads, num_kv_blocks).
+
+The main reason to use this kernel is that it can take advantage of the fact that most KV blocks
+are padding in practical decoding scenarios. It can also take advantage of sparsity in
+`mask_fn`.
+
+Performance notes:
+1. When kv_seq_len == padded_kv_seq_len:
+   This kernel performs similarly to non-fused (i.e. XLA) attention, or is within 10% slower.
+2. When kv_seq_len < padded_kv_seq_len or `mask_fn` has sparsity:
+   This kernel provides a speedup roughly equal to padded_kv_seq_len / kv_seq_len, or to the
+   ratio of total KV blocks to non-masked KV blocks.
+
+The main reason why non-fused attention is faster when KV is not padded is that the non-fused
+matmuls can flatten the non-head dimensions, and thus have larger non-contracting dimensions.
+This leads to better utilization of the matrix and memory units.
+"""
+from functools import partial
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+from jax import lax
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+
+from axlearn.common.attention_bias import NEG_INF, MaskFn
+from axlearn.common.flash_attention.common import build_mask, query_iterator_indices
+from axlearn.common.utils import Tensor
+
+
+def _tpu_decoding_kernel(
+    # Scalars.
+    kv_seq_len_ref,
+    kv_block_offset,
+    kv_block_offset_size,
+    # Inputs.
+    q_ref,
+    k_ref,
+    v_ref,
+    b_ref,
+    # Outputs.
+    o_ref,
+    # Scratch.
+    m_i,
+    l_i,
+    o_scratch,
+    # Compile-time args.
+    softmax_scale: float,
+    mask_fn: Optional[MaskFn],
+):
+    batch_index = pl.program_id(0)
+    non_empty_kv_block_index = pl.program_id(2)
+    _, block_k = k_ref.shape
+    precision = (
+        lax.Precision.HIGHEST if jnp.float32 in (q_ref.dtype, k_ref.dtype, v_ref.dtype) else None
+    )
+
+    # o is the buffer where we accumulate the output on SRAM.
+    # m_i and l_i (see FlashAttention paper) are updated during the k,v loop.
+    @pl.when(non_empty_kv_block_index == 0)
+    def init():
+        m_i[...] = jnp.full_like(m_i, NEG_INF)
+        l_i[...] = jnp.zeros_like(l_i)
+        o_scratch[...] = jnp.zeros_like(o_scratch)
+
+    # Note: in CPU interpret mode, pl.program_id() cannot appear in functions decorated by
+    # pl.when.
+    kv_offset = kv_block_offset[batch_index, non_empty_kv_block_index] * block_k
+    kv_seq_len = kv_seq_len_ref[batch_index]
+    num_non_empty_kv_blocks = kv_block_offset_size[batch_index]
+
+    # Different batches may have different numbers of non-empty kv blocks.
+    @pl.when(non_empty_kv_block_index < num_non_empty_kv_blocks)
+    def compute():
+        q = q_ref[...]
+        k = k_ref[...]
+        qk = pl.dot(q, k, precision=precision)
+        if softmax_scale != 1.0:
+            qk *= softmax_scale
+        if b_ref is not None:
+            qk += b_ref[...]
+            qk = jnp.maximum(qk, NEG_INF)
+        # Note: Pallas TPU requires the use of lax.broadcasted_iota instead of jnp.arange as only
+        # 2D range is supported.
+        block_kv_indices = kv_offset + lax.broadcasted_iota(jnp.int32, qk.shape, 1)
+        kv_mask = block_kv_indices < kv_seq_len
+        if mask_fn is not None:
+            kv_mask = kv_mask & mask_fn(kv_seq_len - 1, block_kv_indices)
+        qk = jnp.where(kv_mask, qk, NEG_INF)
+
+        m_prev = m_i[...]
+        l_prev = l_i[...]
+        o_prev = o_scratch[...]
+
+        # We need to make sure each array has two dims, or we get TPU Mosaic lowering errors.
+        m_curr = qk.max(axis=-1, keepdims=True)
+        m_next = jnp.maximum(m_prev, m_curr)
+        correction = jnp.exp(m_prev - m_next)
+        l_prev_corr = correction * l_prev
+        # Use m_next instead of m_curr to avoid a correction on l_curr.
+        s_curr = jnp.exp(qk - m_next)
+        l_curr = s_curr.sum(axis=-1, keepdims=True)
+        l_next = l_prev_corr + l_curr
+        o_prev_corr = correction * o_prev
+        v = v_ref[...]
+        o_curr = pl.dot(s_curr.astype(v.dtype), v.T, precision=precision)
+
+        o_next = o_prev_corr + o_curr
+
+        m_i[...] = m_next
+        l_i[...] = l_next
+        o_scratch[...] = o_next
+
+    @pl.when(non_empty_kv_block_index == num_non_empty_kv_blocks - 1)
+    def final():
+        # We keep an unscaled version of o during the scan over kv_seq_len. Scaling it
+        # by the last l_i gives us the correct final output. See section 3.1.1 in the
+        # FlashAttention-2 paper: https://arxiv.org/pdf/2307.08691.
+        o_ref[...] = (o_scratch[...] / l_i[...]).astype(o_ref.dtype)
+
+
+@partial(
+    jax.jit,
+    static_argnames=[
+        "softmax_scale",
+        "mask_fn",
+        "block_size",
+        "interpret",
+    ],
+)
+def tpu_decoding(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kv_seq_len: Optional[Tensor],
+    bias: Optional[Tensor] = None,
+    *,
+    softmax_scale: float = 1.0,
+    mask_fn: Optional[MaskFn] = None,
+    block_size: int = 512,
+    interpret: bool = False,
+):
+    """Implements TPU decoding with GQA support.
+
+    The functionality of TPU decoding is similar to GPU FlashDecoding, except that
+    padded_kv_seq_len must be divisible by block_size.
+
+    Args:
+        q: Tensor of shape [batch_size, 1, num_q_heads, head_dim].
+        k: Tensor of shape [batch_size, padded_kv_seq_len, num_kv_heads, head_dim].
+        v: Tensor of shape [batch_size, padded_kv_seq_len, num_kv_heads, head_dim].
+        kv_seq_len: Tensor that can broadcast to [batch_size], indicating the actual kv sequence
+            length for each sequence in the batch. If None, assumes k and v are not padded in the
+            sequence dimension.
+        bias: Tensor that can broadcast to [batch_size, num_q_heads, 1, padded_kv_seq_len].
+            Defaults to None.
+        softmax_scale: Softmax scale.
+        mask_fn: Mask function to use. Preferred over bias.
+        block_size: Block dimension along the sequence dim. Defaults to 512.
+
+    Returns:
+        A tensor with the same shape and dtype as q.
+
+    Raises:
+        ValueError: If the shapes of q, k, and v don't satisfy the assumptions above.
+    """
+    if q.shape[1] != 1:
+        raise ValueError("Multi-step decoding is not supported yet.")
+    # Pallas TPU doesn't support pl.load(..., mask=xxx), so the kv length must divide the block
+    # size. However, we can reduce the block size to support the case where
+    # padded_kv_seq_len < block_size.
+    block_size = min(block_size, k.shape[1])
+    if k.shape[1] % block_size != 0:
+        raise ValueError(f"KV sequence length {k.shape[1]} must be divisible by {block_size=}.")
+    orig_q_shape = q.shape
+    q_seq_len = q.shape[1]
+    block_kv = block_size
+    q = q.squeeze(1)
+    # Convert to bnhs which is the native shape of KV in the kv cache. These two transposes should
+    # be elided by the compiler. See `BaseQKVLinear.init_states` from attention.py.
+    k = jnp.einsum("bsnh->bnhs", k)
+    v = jnp.einsum("bsnh->bnhs", v)
+    bs, kv_heads, head_dim, padded_kv_seq_len = k.shape
+    if kv_seq_len is not None:
+        kv_seq_len = jnp.broadcast_to(jnp.asarray(kv_seq_len), (bs,))
+    else:
+        kv_seq_len = jnp.full((bs,), padded_kv_seq_len, dtype=jnp.int32)
+
+    # Computes a full block map of shape (num_kv_blocks, num_kv_blocks).
+    # Use a padding value to ensure padding blocks aren't counted towards `kv_block_offset_size`.
+    padding = -1
+    with jax.ensure_compile_time_eval():
+        if mask_fn is not None:
+            bool_mask = build_mask(
+                mask_fn,
+                q_seq_len=padded_kv_seq_len,
+                kv_seq_len=padded_kv_seq_len,
+                block_q=block_size,
+                block_k=block_size,
+            )
+            offset, _ = query_iterator_indices(bool_mask, padding=padding)
+        else:
+            padded_num_kv_blocks = pl.cdiv(padded_kv_seq_len, block_size)
+            offset = lax.broadcasted_iota(
+                jnp.int32, (padded_num_kv_blocks, padded_num_kv_blocks), 1
+            )
+
+    # Dynamically slice the rows according to the query position (which is kv_seq_len - 1).
+    kv_block_offset = offset[(kv_seq_len - 1) // block_size]
+    # Count the number of blocks with position < kv_seq_len.
+    kv_block_offset_size = jnp.count_nonzero(
+        (kv_block_offset != padding) & (kv_block_offset * block_size < kv_seq_len[:, None]), axis=1
+    )
+    # Replace padding with the last valid kv block's index.
See + # https://docs.jax.dev/en/latest/pallas/tpu/sparse.html#sparse-access-patterns-on-dense-data + kv_block_offset = jnp.where( + kv_block_offset == padding, kv_block_offset.max(axis=1, keepdims=True), kv_block_offset + ) + + q = q.reshape(bs, kv_heads, -1, head_dim) + q_seq_head = q.shape[-2] # = q_seq_len * num_q_heads_per_kv_head + assert q_seq_head <= 512 + + def kv_index_map( + batch_idx, head_idx, kv_block_idx, kv_seq_len, kv_block_offset, kv_block_offset_size + ): + del kv_seq_len, kv_block_offset_size + return (batch_idx, head_idx, 0, kv_block_offset[batch_idx, kv_block_idx]) + + q_spec = pl.BlockSpec((None, None, q_seq_head, head_dim), lambda b, h, j, *args: (b, h, 0, 0)) + kv_spec = pl.BlockSpec((None, None, head_dim, block_kv), kv_index_map) + bias_spec = None + if bias is not None: + if bias.shape[0] == 1 and bias.shape[1] == 1: + + def bias_index_map( + batch_idx, + head_idx, + kv_block_idx, + kv_seq_len, + kv_block_offset, + kv_block_offset_size, + ): + del head_idx, kv_seq_len, kv_block_offset_size + return (0, 0, 0, kv_block_offset[batch_idx, kv_block_idx]) + + bias_spec = pl.BlockSpec((None, None, q_seq_len, block_kv), bias_index_map) + else: + bias = bias.reshape(bs, kv_heads, q_seq_head, padded_kv_seq_len) + bias_spec = pl.BlockSpec((None, None, q_seq_head, block_kv), kv_index_map) + + out: Tensor = pl.pallas_call( + partial(_tpu_decoding_kernel, softmax_scale=softmax_scale, mask_fn=mask_fn), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=3, + in_specs=[ + q_spec, + kv_spec, + kv_spec, + bias_spec, + ], + out_specs=q_spec, + scratch_shapes=[ + # VMEM requires 2D arrays. + pltpu.VMEM((q_seq_head, 1), jnp.float32), + pltpu.VMEM((q_seq_head, 1), jnp.float32), + pltpu.VMEM((q_seq_head, head_dim), jnp.float32), + ], + grid=(bs, kv_heads, kv_block_offset_size.max()), + ), + out_shape=jax.ShapeDtypeStruct(q.shape, q.dtype), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary") + ), + interpret=interpret, + )(kv_seq_len, kv_block_offset, kv_block_offset_size, q, k, v, bias) + return out.reshape(orig_q_shape) diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py index bc9204f62..561faf7c2 100644 --- a/axlearn/common/flash_attention/utils.py +++ b/axlearn/common/flash_attention/utils.py @@ -8,7 +8,7 @@ import jax.numpy as jnp from absl import logging -from axlearn.common.attention import softmax_with_biases +from axlearn.common.attention import compute_gqa_context, compute_gqa_logits, softmax_with_biases from axlearn.common.attention_bias import ( NEG_INF, BaseAttentionBias, @@ -16,19 +16,18 @@ CompositeAttentionBias, MaskFnAttentionBias, SegmentIdAttentionBias, - TensorAttentionBias, split, ) from axlearn.common.flash_attention.gpu_attention import cudnn_dot_product_attention from axlearn.common.flash_attention.gpu_attention import flash_attention as gpu_flash_attention from axlearn.common.flash_attention.gpu_decoding import flash_decoding from axlearn.common.flash_attention.tpu_attention import tpu_flash_attention +from axlearn.common.flash_attention.tpu_decoding import tpu_decoding from axlearn.common.layers import dropout from axlearn.common.utils import Tensor @functools.partial(jax.jit, static_argnames=["causal", "softmax_scale", "dropout_rate"]) -@jax.default_matmul_precision("bfloat16") def mha_reference( q: Tensor, k: Tensor, @@ -42,12 +41,12 @@ def mha_reference( dropout_rate: float = 0.0, dropout_mask: Optional[Tensor] = None, ) -> Tensor: - """Reference 
multi-headed attention implementation. + """Reference multi-headed attention implementation with GQA optimization. Args: q: query tensor with shape [batch_size, seq_len, num_heads, per_head_dim] - k: key tensor with shape [batch_size, seq_len, num_heads, per_head_dim] - v: value tensor with shape [batch_size, seq_len, num_heads, per_head_dim] + k: key tensor with shape [batch_size, seq_len, num_kv_heads, per_head_dim] + v: value tensor with shape [batch_size, seq_len, num_kv_heads, per_head_dim] bias: bias tensor with a shape that can broadcast to [batch_size, num_heads, seq_len, seq_len], e.g. [1, 1, seq_len, seq_len]. segment_ids: segment ids tensor with shape [batch_size, seq_len]. @@ -60,9 +59,10 @@ def mha_reference( """ # We apply the scale factor before the attention biases. q *= softmax_scale - logits = jnp.einsum("btnh,bsnh->bnts", q, k) + logits = compute_gqa_logits(q, k) - # Check if we need to build a segment id mask. + # TODO(hanzhi-zhou): Remove segment ids and causal here. Refactor unit tests that use them. + # We can construct masks directly. if segment_ids is not None: assert segment_ids.ndim == 2 # shape [batch_size, seq_len] target_segment_ids = jnp.expand_dims(segment_ids, -1) @@ -84,8 +84,7 @@ def mha_reference( if dropout_rate > 0: probs = dropout(probs, prng_key=prng_key, rate=dropout_rate, mask=dropout_mask) - context = jnp.einsum("bnts,bsnh->btnh", probs, v).astype(v.dtype) - return context + return compute_gqa_context(probs, v) def _repeat_kv_heads(num_q_heads: int, key_or_value: Tensor) -> Tensor: @@ -143,19 +142,20 @@ def jit_attn( *, backend: str = backend, ) -> Tensor: - # Fall back to plain MHA implementation when the seq_len is not be divisible by - # block size. - is_single_step_gpu_decoding = is_decoding and query.shape[1] == 1 and backend == "gpu" - # For non-GPU decoding, fall back to non-flash implementation and merge all biases - # into a dense floating point bias tensor since that implementation does not - # support target_positions. - if not is_single_step_gpu_decoding: + is_single_step_decoding = is_decoding and query.shape[1] == 1 + # TODO(hanzhi-zhou): Support multi-step GPU and TPU decoding. + if not is_single_step_decoding: if is_decoding: - # TODO(senyut): Support TPU decoding. + # If multi-step decoding, fall back to non-flash implementation. backend = "xla" - bias = TensorAttentionBias(bias.value()) + # Fall back to plain MHA implementation when the seq_len is not be divisible by + # block size. + # FIXME(hanzhi-zhou): This dispatch is not optimal. Backends like cuDNN have more + # relaxed constraints on the input shapes. if query.shape[1] % block_size != 0: backend = "xla" + if is_single_step_decoding and backend not in ("gpu", "tpu", "cpu"): + backend = "xla" if dropout_rate != 0.0 and backend not in ("gpu", "xla", "cpu"): raise NotImplementedError("Dropout is only implemented for GPU, CPU and XLA.") @@ -180,14 +180,12 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]: return segment_ids.segment_ids if backend == "gpu": - # TODO(hanzhi-zhou): supports small q sequence length for future use cases such as - # speculative decoding. - if is_single_step_gpu_decoding: + if is_single_step_decoding: # Decoding case. We should not repeat kv heads to match q heads for FlashDecoding. # Note: decoding is always causal. Discard the causal mask if present. 
mask, explicit_bias = split(bias, MaskFnAttentionBias) if mask is None or mask.target_positions is None: - raise RuntimeError("Cannot retrive MaskFnAttentionBias or target_positions.") + raise RuntimeError("Cannot retrieve MaskFnAttentionBias or target_positions.") mask_fn = mask.mask query_time_step = mask.target_positions[:, -1] kv_seq_len = query_time_step + 1 @@ -262,6 +260,26 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]: ) elif backend == "tpu": + if is_single_step_decoding: + mask, explicit_bias = split(bias, MaskFnAttentionBias) + if mask is None or mask.target_positions is None: + raise RuntimeError("Cannot retrieve MaskFnAttentionBias or target_positions.") + mask_fn = mask.mask + logging.info("Using mask_fn=%s for FlashDecoding.", mask_fn) + query_time_step = mask.target_positions[:, -1] + kv_seq_len = query_time_step + 1 + return tpu_decoding( + query, + key, + value, + bias=explicit_bias.value(), + mask_fn=mask_fn, + kv_seq_len=kv_seq_len, + softmax_scale=softmax_scale, + interpret=_interpret(backend), + block_size=block_size, + ) + # TODO(dhwang2): splash attention supports GQA natively, so don't repeat it. # https://github.com/jax-ml/jax/blob/7b9914d711593dca8725d46aa1dadb2194284519/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py#L934 key = _repeat_kv_heads(query.shape[2], key) @@ -313,25 +331,16 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]: ) elif backend in ("cpu", "xla"): - key = _repeat_kv_heads(query.shape[2], key) - value = _repeat_kv_heads(query.shape[2], value) if backend == "cpu": logging.info("Flash attention CPU backend is for testing only.") logging.info("Flash attention falling back using plain MHA implementation") - # `causal` is supported. - # `segment_ids` is supported. - causal, segment_ids, explicit_bias = split( - bias, CausalAttentionBias, SegmentIdAttentionBias - ) return mha_reference( query, key, value, - bias=explicit_bias.value(), - segment_ids=get_segment_ids(segment_ids), + bias=bias.value(), prng_key=prng_key, - causal=causal.has_value(), softmax_scale=softmax_scale, dropout_rate=dropout_rate, ) diff --git a/axlearn/common/test_utils.py b/axlearn/common/test_utils.py index 126602b3e..e828d0394 100644 --- a/axlearn/common/test_utils.py +++ b/axlearn/common/test_utils.py @@ -13,7 +13,7 @@ from collections.abc import Iterator, Sequence from functools import partial from tempfile import mkdtemp -from typing import Any, Optional, Protocol, TypeVar, Union +from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union from unittest.mock import patch import jax @@ -135,6 +135,11 @@ def __call__(self, src: Any, *, dst_layer: BaseLayer) -> NestedTensor: """Converts parameters from `src` to parameters for `dst_layer`.""" +class Tolerance(NamedTuple): + rtol: float = 0.001 + atol: float = 0.001 + + class TestCase(parameterized.TestCase): """Base test class.""" @@ -249,6 +254,45 @@ def assertNestedEqual(self, a, b): if hasattr(a_value, "dtype"): self.assertEqual(a_value.dtype, b_value.dtype) + def assertAllCloseWithOutliers(self, actual, desired, *, tolerance_map: dict[float, Tolerance]): + """Like np.testing.assert_allclose, but allows outlier percentiles to be specified. + + `tolerance_map` is mapping of percentile values (between 0 and 1) to `Tolerance` objects. + Each entry defines the acceptable tolerance for a certain percentile of elements in the + difference `abs(actual - desired)`. 
The specified tolerance should be met within the given + percentile of total elements in `actual` or `desired`. + + Example: + ```python + self.assertAllCloseWithOutliers(x, y, tolerance_map={ + 1.0: Tolerance(atol=0.2), + 0.95: Tolerance(atol=0.05), + }) + ``` + This example asserts 100% elements of `abs(x - y)` should be within atol=0.2, and 95% + elements of `abs(x - y)` should be within atol=0.05. + """ + assert len(tolerance_map) > 0 + self.assertEqual(actual.shape, desired.shape) + self.assertEqual(actual.dtype, desired.dtype) + actual = actual.astype(np.float32) + desired = desired.astype(np.float32) + diff = np.abs(actual - desired) + for percentile, tol in tolerance_map.items(): + percentile = 1 - percentile + tolerance = tol.atol + tol.rtol * np.abs(desired) + expected_num_ele = round(diff.size * percentile) + actual_num_ele = np.count_nonzero(diff > tolerance) + actual_percent = actual_num_ele / diff.size + self.assertLessEqual( + actual_num_ele, + expected_num_ele, + msg=f"Expected the number of elements over {tol} to be less than {percentile:.3%}" + f" of total elements (or {expected_num_ele}), but got {actual_percent:.3%} " + f"(or {actual_num_ele}). These differences are {diff[diff > tolerance]}. " + f"Max difference = {diff.max()}", + ) + # TODO(markblee): Move this to axlearn/experiments/test_utils.py, where it's used. class TrainerConfigTestCase(TestCase):
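
Usage sketch (not part of the patch): the snippet below shows one way to call the new `tpu_decoding` kernel directly, mirroring how `decoding_test.py` exercises it. All shapes, the sliding-window size, and the actual kv lengths are illustrative assumptions; in normal use the kernel is reached through `flash_attention_implementation(..., is_decoding=True)` in `axlearn/common/flash_attention/utils.py` on the TPU backend.

    import jax
    import jax.numpy as jnp

    from axlearn.common.attention_bias import sliding_window_causal_mask
    from axlearn.common.flash_attention.tpu_decoding import tpu_decoding

    batch, padded_kv_len, num_q_heads, num_kv_heads, head_dim = 2, 1024, 8, 2, 128
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
    # Single-step decoding: the query has target length 1; k/v are padded to padded_kv_len.
    q = jax.random.normal(k1, (batch, 1, num_q_heads, head_dim), dtype=jnp.bfloat16)
    k = jax.random.normal(k2, (batch, padded_kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
    v = jax.random.normal(k3, (batch, padded_kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
    # Actual (unpadded) kv length per sequence; positions beyond it are treated as padding.
    kv_seq_len = jnp.asarray([700, 700], dtype=jnp.int32)
    out = tpu_decoding(
        q,
        k,
        v,
        kv_seq_len,
        softmax_scale=head_dim**-0.5,
        mask_fn=sliding_window_causal_mask(256),  # Optional sparse mask; also prunes KV blocks.
        block_size=512,  # padded_kv_len must be divisible by the (possibly reduced) block size.
        interpret=(jax.default_backend() == "cpu"),  # CPU emulation, as in the tests.
    )
    assert out.shape == q.shape  # [batch, 1, num_q_heads, head_dim]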