Refactor KV Cache Out of QKVLinear

Currently, `QKVLinear` is overly complex because it handles both the QKV
computation and KV cache management.

Although `QKVLinear`'s role and the KV cache strategy should be independent,
the current implementation forces `QKVLinear` to manage the KV cache, so every
`QKVLinear` subclass (`FusedQKVLinear`, `GroupedQKVLinear`, `RoFormerQKVLinear`,
etc.) must be carefully maintained to ensure it handles the cache correctly.

**Key Changes in This PR**

This PR removes the KV cache logic from `QKVLinear`, turning it into a pure
`forward`-only layer that only computes projections (QKV projection, and RoPE
where applicable) and no longer needs to handle decoding.

Instead, the attention layer now owns the KV cache directly, making it more
flexible for future KV cache strategies.
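
As a rough sketch of the new configuration surface (mirroring the test changes
in this diff; exact fields and defaults should be checked against
`axlearn/common/attention.py`), the cache dtype now hangs off the attention
layer's `kv_cache` child rather than off `input_linear`:

```python
from jax import numpy as jnp

from axlearn.common.attention import GroupedQueryAttention, KVCache

# Before this PR, the cache dtype was configured on the QKV projection, e.g.:
#   cfg.input_linear.set(dtype=jnp.bfloat16)
# After this PR, the attention layer owns a kv_cache child config:
cfg = GroupedQueryAttention.default_config().set(
    kv_cache=KVCache.default_config().set(cache_dtype=jnp.bfloat16),
)
```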

Currently, only one KV cache behavior is supported: a cache with a fixed
maximum length. In the near future, we will introduce more KV cache strategies,
such as the following (a hypothetical sketch appears after this list):

- Sliding Window Attention → Requires a sliding-window KV cache.
- Sparse Attention → Needs a KV cache that dynamically selects a sparse subset
  of keys and values (similar to DeepSeek, https://arxiv.org/abs/2502.11089).
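
Here is that hypothetical sketch of a sliding-window cache update (the function
name, signature, and shapes are illustrative assumptions, not part of this
commit); the key point is that only the cache knows which keys survive
eviction, and therefore what the key positions are:

```python
import jax.numpy as jnp


def sliding_window_update(
    cached_kv, cached_positions, step_kv, step_positions, *, window: int
):
    """Hypothetical sliding-window cache update (illustrative only).

    cached_kv:        [batch, window, num_heads, per_head_dim]
    cached_positions: [batch, window] absolute key positions
    step_kv:          [batch, steps, num_heads, per_head_dim]
    step_positions:   [batch, steps]
    """
    # Only the cache decides which keys are kept, so it also decides the key
    # positions that the attention layer (and its masking logic) will see.
    kv = jnp.concatenate([cached_kv, step_kv], axis=1)[:, -window:]
    positions = jnp.concatenate([cached_positions, step_positions], axis=1)[:, -window:]
    return kv, positions
```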

**Implementation Details**

A key aspect of this refactor is how query positions and key positions are
generated. Previously, the related logic was scattered across multiple places;
now each kind of position is computed in a single place:

- Query positions → Must be determined before RoPE, since RoPE requires them.
  The same query positions are then reused throughout the code.
- Key positions → Only the KV cache layer can determine them accurately, since
  the KV cache strategy directly affects which key positions exist. The KV
  cache is therefore now responsible for generating key positions.

In addition, `KVState` now carries the key positions alongside the projected
keys and values.
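
The `lora_test.py` changes below show the new field in use; here is a minimal,
self-contained sketch of building a `KVState` with explicit key positions
(shapes chosen arbitrarily for illustration):

```python
import jax
import jax.numpy as jnp

from axlearn.common.attention import KVState

batch, seq_len, num_heads, per_head_dim = 2, 4, 2, 3
k_proj = jax.random.normal(jax.random.PRNGKey(0), (batch, seq_len, num_heads, per_head_dim))
v_proj = jax.random.normal(jax.random.PRNGKey(1), (batch, seq_len, num_heads, per_head_dim))
# Key positions now travel with the projected KV; [1, seq_len] broadcasts over the batch.
key_positions = jnp.arange(seq_len)[None]
kv_state = KVState(k_proj=k_proj, v_proj=v_proj, key_positions=key_positions)
```
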
ds-hwang committed Mar 8, 2025
1 parent b1e7b37 commit 624bec8
Showing 89 changed files with 555 additions and 685 deletions.
442 changes: 224 additions & 218 deletions axlearn/common/attention.py


268 changes: 87 additions & 181 deletions axlearn/common/attention_test.py


16 changes: 11 additions & 5 deletions axlearn/common/flash_attention/layer_test.py
@@ -25,7 +25,13 @@
from jax.experimental import mesh_utils
from jax.sharding import Mesh

from axlearn.common.attention import Dropout, GroupedQKVLinear, GroupedQueryAttention, QKVLinear
from axlearn.common.attention import (
Dropout,
GroupedQKVLinear,
GroupedQueryAttention,
KVCache,
QKVLinear,
)
from axlearn.common.attention_bias import (
CausalAttentionBias,
CompositeAttentionBias,
@@ -118,7 +124,7 @@ def _prepare_layers(
ref_cfg = GroupedQueryAttention.default_config().set(**kwargs)

if inference:
ref_cfg.input_linear.set(dtype=jnp.bfloat16)
ref_cfg.set(kv_cache=KVCache.default_config().set(cache_dtype=jnp.bfloat16))
test_cfg = (
FlashAttention.default_config()
.set(**kwargs)
@@ -129,7 +135,7 @@
)
)
if inference:
test_cfg.input_linear.set(dtype=jnp.bfloat16)
test_cfg.set(kv_cache=KVCache.default_config().set(cache_dtype=jnp.bfloat16))

ref_cfg.set(mask=mask)
test_cfg.set(mask=mask)
@@ -820,8 +826,8 @@ def test_extend_step(
self.assertIsNone(initial_output)
self.assertIsNone(ref_inital_output)
for k in ["key", "value"]:
self.assertEqual(ref_initial_state["i_proj"][k].dtype, dtype)
self.assertEqual(initial_state["i_proj"][k].dtype, dtype)
self.assertEqual(ref_initial_state["kv_cache"][k].dtype, dtype)
self.assertEqual(initial_state["kv_cache"][k].dtype, dtype)

# Prepare decoding inputs.
inputs = dict(
1 change: 0 additions & 1 deletion axlearn/common/lora.py
@@ -495,7 +495,6 @@ def __init__(self, cfg: Config, *, parent: Module):
qkv_proj_cfg.value_dim = cfg.value_dim
qkv_proj_cfg.num_heads = cfg.num_heads
qkv_proj_cfg.per_head_dim = cfg.per_head_dim
qkv_proj_cfg.cache_dtype = cfg.cache_dtype
self._add_child("layer", qkv_proj_cfg)
self._add_child(
"adapter",
181 changes: 53 additions & 128 deletions axlearn/common/lora_test.py
@@ -17,11 +17,12 @@
from axlearn.common.attention import (
FusedQKVLinear,
KVState,
MultiheadAttention,
MultiheadOutputLinear,
QKVLinear,
QLinear,
RoFormerQKVLinear,
)
from axlearn.common.attention_bias import CausalAttentionBias
from axlearn.common.layers import Linear
from axlearn.common.lora import (
LoraFusedQKVAdapter,
@@ -33,7 +34,7 @@
from axlearn.common.module import functional as F
from axlearn.common.param_converter import as_torch_tensor
from axlearn.common.test_utils import TestCase, assert_allclose
from axlearn.common.utils import Tensor, TensorSpec
from axlearn.common.utils import Tensor


class LoraLinearTest(TestCase):
@@ -166,8 +167,12 @@ def test_forward(self, ref_layer_cfg):
external_value = jax.random.normal(
jax.random.PRNGKey(90), (batch_size, seq_len, num_heads, per_head_dim)
)
key_positions = jnp.arange(seq_len)[None]
inputs = dict(
query=inputs, kv_state=KVState(k_proj=external_key, v_proj=external_value)
query=inputs,
kv_state=KVState(
k_proj=external_key, v_proj=external_value, key_positions=key_positions
),
)
else:
inputs = (inputs,)
@@ -216,156 +221,76 @@ def test_forward(self, ref_layer_cfg):
# Expect the same output due to zero initialization of one of the LoRA weights.
assert_allclose(outputs, ref_outputs)

@parameterized.product(
layer=(
FusedQKVLinear.default_config(),
RoFormerQKVLinear.default_config().set(
rotary_value=False, input_linear=FusedQKVLinear.default_config()
),
),
@parameterized.parameters(
(QKVLinear.default_config(),),
(FusedQKVLinear.default_config(),),
)
def test_extend_step(self, layer):
model_dim = 16
def test_extend_step(self, ref_layer_cfg):
model_dim = 6
num_heads = 2
per_head_dim = 4 # change this to 4 to adapt the need of RoPE.
seq_len = 4
batch_size = 2
lora_linear = LoraFusedQKVLinear.default_config().set(layer=ref_layer_cfg)
rank = 2
alpha = 4
enable_lora = dict(query=True, key=False, value=True)
num_enabled = sum(enable_lora.values())
inputs = jax.random.normal(jax.random.PRNGKey(456), (batch_size, seq_len, model_dim))

layer_cfg = LoraFusedQKVLinear.default_config().set(
layer=layer,
lora_linear.adapter.set(rank=rank, alpha=alpha, enable_lora=enable_lora)
cfg = MultiheadAttention.default_config().set(
name="test",
input_linear=lora_linear,
query_dim=model_dim,
key_dim=model_dim,
value_dim=model_dim,
num_heads=num_heads,
per_head_dim=per_head_dim,
)
layer_cfg.adapter.set(rank=rank, alpha=alpha, enable_lora=enable_lora)
layer = layer_cfg.instantiate(parent=None)
state = layer.initialize_parameters_recursively(
prng_key=jax.random.PRNGKey(123), prebuilt=None
)
state["adapter"]["lora_up"]["weight"] = jax.random.normal(
jax.random.PRNGKey(1), (num_enabled, rank, num_heads, per_head_dim)
)
outputs, _ = jax.jit(partial(F, layer, is_training=False))(
state=state,
prng_key=jax.random.PRNGKey(456),
inputs=(inputs,),
mask=CausalAttentionBias.default_config(),
)
q_proj, k_proj, v_proj = outputs
forward_outputs = jnp.stack([q_proj, k_proj, v_proj])
layer = cfg.instantiate(parent=None)

initial_cache_state, init_output = layer.init_states(
time_step=None,
query=TensorSpec([batch_size, seq_len], dtype=q_proj.dtype),
)
self.assertIsNone(init_output)
# Initialize layer parameters.
prng_key = jax.random.PRNGKey(123)
prng_key, init_key = jax.random.split(prng_key)
layer_params = layer.initialize_parameters_recursively(init_key)

decoder_inputs = dict(cached_states=initial_cache_state)
decoder_outputs = jnp.zeros(shape=[seq_len, 3, batch_size, num_heads, per_head_dim])
for t in range(seq_len):
decoder_inputs["query"] = jnp.expand_dims(inputs[:, t, :], axis=1)
(updated_states, outputs), _ = F(
layer,
state=state,
is_training=False,
prng_key=jax.random.PRNGKey(456),
inputs=decoder_inputs,
method="extend_step",
)
decoder_inputs["cached_states"] = updated_states
q_proj, k_proj, v_proj = outputs
k_proj = jnp.expand_dims(k_proj[:, t, :, :], axis=1)
v_proj = jnp.expand_dims(v_proj[:, t, :, :], axis=1)
# Generate input sequences.
batch, seq_len = 2, 10
prng_key, data_key = jax.random.split(prng_key)
query = jax.random.uniform(data_key, [batch, seq_len, model_dim])

decoder_outputs = decoder_outputs.at[t].set(
jnp.squeeze(jnp.stack([q_proj, k_proj, v_proj]), axis=2)
)
decoder_out_transposed = jnp.transpose(decoder_outputs, [1, 2, 0, 3, 4])
assert_allclose(
decoder_out_transposed,
forward_outputs,
atol=1e-6,
)

def test_prefill_states(self):
model_dim = 16
num_heads = 2
per_head_dim = 3
seq_len = 4
batch_size = 2
rank = 2
alpha = 4
enable_lora = dict(query=True, key=False, value=True)
num_enabled = sum(enable_lora.values())
inputs = jax.random.normal(jax.random.PRNGKey(456), (batch_size, seq_len, model_dim))

layer_cfg = LoraFusedQKVLinear.default_config().set(
name="test",
query_dim=model_dim,
key_dim=model_dim,
value_dim=model_dim,
num_heads=num_heads,
per_head_dim=per_head_dim,
)
layer_cfg.adapter.set(rank=rank, alpha=alpha, enable_lora=enable_lora)
layer = layer_cfg.instantiate(parent=None)
state = layer.initialize_parameters_recursively(
prng_key=jax.random.PRNGKey(123), prebuilt=None
)
state["adapter"]["lora_up"]["weight"] = jax.random.normal(
jax.random.PRNGKey(1), (num_enabled, rank, num_heads, per_head_dim)
)
forward_outputs, _ = jax.jit(partial(F, layer, is_training=False))(
state=state,
prng_key=jax.random.PRNGKey(456),
inputs=(inputs,),
# Compute layer outputs.
fwd_outputs, _ = F(
layer,
inputs=dict(query=query),
is_training=False,
state=layer_params,
prng_key=prng_key,
)

time_step = jnp.arange(batch_size)
(initial_cache_states, initial_outputs), _ = F(
# Compute extend_step.
(cached_states, _), _ = F(
layer,
state=state,
inputs=dict(time_step=None, query=query, attention_logit_biases=None),
is_training=False,
prng_key=jax.random.PRNGKey(456),
inputs=dict(time_step=time_step, query=inputs),
state=layer_params,
prng_key=prng_key,
method="init_states",
)
time_step_mask = jnp.arange(seq_len) < time_step[:, None]
# [batch, tgt_len, num_heads, per_head_dim].
decoder_outputs = initial_outputs.query * time_step_mask[..., None, None]
decoder_inputs = dict(cached_states=initial_cache_states)
while jnp.any(time_step < seq_len):
decoder_inputs["query"] = jnp.take_along_axis(
inputs, time_step[:, None, None], axis=1, mode="clip"
step_data = []
for i in range(seq_len):
step_inputs = dict(
cached_states=cached_states,
query=query[:, i : i + 1],
attention_logit_biases=None,
)
(updated_states, outputs), _ = F(
(cached_states, step_outs), _ = F(
layer,
state=state,
prng_key=jax.random.PRNGKey(0),
state=layer_params,
inputs=step_inputs,
is_training=False,
prng_key=jax.random.PRNGKey(456),
inputs=decoder_inputs,
method="extend_step",
)
decoder_inputs["cached_states"] = updated_states
q_proj, _, _ = outputs

# [batch, tgt_len, 1, 1].
oh_indices = jax.nn.one_hot(time_step, seq_len)[:, :, None, None]
decoder_outputs = decoder_outputs + q_proj * oh_indices
time_step = time_step + 1

assert_allclose(
decoder_outputs,
forward_outputs.query,
atol=1e-6,
)
step_data.append(step_outs.data)
step_data = jnp.concatenate(step_data, axis=1)
self.assertEqual(step_data.dtype, fwd_outputs.data.dtype)
assert_allclose(step_data, fwd_outputs.data)


class LoraMultiheadOutputLinearTest(TestCase):
9 changes: 7 additions & 2 deletions axlearn/common/ssm_test.py
@@ -30,6 +30,7 @@
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig

from axlearn.common.attention import KVCache
from axlearn.common.attention_bias import make_causal_biases
from axlearn.common.config import InstantiableConfig
from axlearn.common.module import functional as F
@@ -930,7 +931,9 @@ def test_extend_step(self, dtype: jnp.dtype):
cfg.ssm_layer.mamba_layer.set(dtype=dtype, cache_dtype=None)
cfg.layer.feed_forward.hidden_dim = hidden_dim
cfg.layer.self_attention.attention.num_heads = num_heads
cfg.layer.self_attention.attention.input_linear.set(dtype=dtype, cache_dtype=None)
cfg.layer.self_attention.attention.set(
kv_cache=KVCache.default_config().set(cache_dtype=dtype)
)
_test_extend_step(cfg, model_dim=model_dim, dtype=dtype)

@parameterized.parameters(jnp.float32, jnp.bfloat16)
@@ -953,7 +956,9 @@ def test_prefill(self, dtype: jnp.dtype):
cfg.ssm_layer.mamba_layer.set(dtype=dtype, cache_dtype=None)
cfg.layer.feed_forward.hidden_dim = hidden_dim
cfg.layer.self_attention.attention.num_heads = num_heads
cfg.layer.self_attention.attention.input_linear.set(dtype=dtype, cache_dtype=None)
cfg.layer.self_attention.attention.set(
kv_cache=KVCache.default_config().set(cache_dtype=dtype)
)
_test_prefill_states(cfg, model_dim=model_dim, dtype=dtype)


@@ -306,8 +306,6 @@ model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
@@ -323,6 +321,8 @@ model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_e
model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False
model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention'
model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16'
model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp'
@@ -306,8 +306,6 @@ model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
@@ -323,6 +321,8 @@ model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_e
model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False
model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention'
model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16'
model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp'
@@ -306,8 +306,6 @@ model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
@@ -323,6 +321,8 @@ model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_e
model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False
model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention'
model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16'
model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache'
model.decoder.transformer.layer.self_attention.attention.num_heads: 32
model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False
model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'