Commit: [V1] Move more control of kv cache initialization from model_executor to EngineCore (vllm-project#11960)

Signed-off-by: Chen Zhang <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
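The title describes an inversion of control: instead of each worker sizing and allocating its own KV cache, EngineCore queries the workers for what they need and how much memory is free, then makes the sizing decision centrally. The toy sketch below illustrates that shape only; every class and method name in it is an assumption made for illustration, not the API added by this commit.

```python
from dataclasses import dataclass
from typing import Dict


# All names below are illustrative assumptions; the commit title only tells
# us that EngineCore, not the model executor, now drives KV cache sizing.

@dataclass
class KVCacheSpec:
    """Per-layer byte cost of one KV cache block, reported by the worker."""
    page_size_bytes: int


class Executor:
    def get_kv_cache_spec(self) -> Dict[str, KVCacheSpec]:
        # The worker knows its model's layers; here, four identical ones.
        return {f'layers.{i}.self_attn': KVCacheSpec(4096) for i in range(4)}

    def determine_available_memory(self) -> int:
        # Profile a dummy forward pass and report free GPU bytes (stubbed).
        return 1 << 30

    def initialize(self, num_blocks: int) -> None:
        print(f"allocating {num_blocks} KV cache blocks per layer")


class EngineCore:
    def __init__(self, executor: Executor):
        spec = executor.get_kv_cache_spec()
        free_bytes = executor.determine_available_memory()
        # The engine, not the executor, decides how many blocks fit.
        bytes_per_block = sum(s.page_size_bytes for s in spec.values())
        executor.initialize(free_bytes // bytes_per_block)


EngineCore(Executor())
```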
Showing 12 changed files with 515 additions and 104 deletions.
New file (+62 lines):

```python
from typing import List

import torch

from vllm.v1.utils import bind_kv_cache


def test_bind_kv_cache():
    from vllm.attention import Attention

    ctx = {
        'layers.0.self_attn': Attention(32, 128, 0.1),
        'layers.1.self_attn': Attention(32, 128, 0.1),
        'layers.2.self_attn': Attention(32, 128, 0.1),
        'layers.3.self_attn': Attention(32, 128, 0.1),
    }
    kv_cache = {
        'layers.0.self_attn': torch.zeros((1, )),
        'layers.1.self_attn': torch.zeros((1, )),
        'layers.2.self_attn': torch.zeros((1, )),
        'layers.3.self_attn': torch.zeros((1, )),
    }
    runner_kv_caches: List[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)
    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[
        'layers.0.self_attn']
    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[
        'layers.1.self_attn']
    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[
        'layers.2.self_attn']
    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[
        'layers.3.self_attn']

    assert runner_kv_caches[0] is kv_cache['layers.0.self_attn']
    assert runner_kv_caches[1] is kv_cache['layers.1.self_attn']
    assert runner_kv_caches[2] is kv_cache['layers.2.self_attn']
    assert runner_kv_caches[3] is kv_cache['layers.3.self_attn']


def test_bind_kv_cache_non_attention():
    from vllm.attention import Attention

    # example from Jamba PP=2
    ctx = {
        'model.layers.20.attn': Attention(32, 128, 0.1),
        'model.layers.28.attn': Attention(32, 128, 0.1),
    }
    kv_cache = {
        'model.layers.20.attn': torch.zeros((1, )),
        'model.layers.28.attn': torch.zeros((1, )),
    }

    runner_kv_caches: List[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[
        'model.layers.20.attn']
    assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[
        'model.layers.28.attn']

    assert runner_kv_caches[0] is kv_cache['model.layers.20.attn']
    assert runner_kv_caches[1] is kv_cache['model.layers.28.attn']
```
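The assertions pin down the contract of `bind_kv_cache` without showing its body: each `Attention` module's `kv_cache[0]` must alias the tensor allocated for its layer, and the runner's flat list must be filled in increasing layer-index order even when the indices are sparse (Jamba under PP=2 owns layers 20 and 28, which land in slots 0 and 1). Below is a minimal sketch that satisfies those assertions; it is inferred from the tests, not copied from vllm, and the `_layer_index` helper and the `bind_kv_cache_sketch` name are illustrative.

```python
import re
from typing import Dict, List

import torch


def _layer_index(layer_name: str) -> int:
    # Hypothetical helper: pull the numeric layer index out of names like
    # 'layers.0.self_attn' or 'model.layers.28.attn'.
    match = re.search(r'layers\.(\d+)\.', layer_name)
    assert match is not None, f"no layer index in {layer_name!r}"
    return int(match.group(1))


def bind_kv_cache_sketch(
    kv_caches: Dict[str, torch.Tensor],
    forward_context: Dict[str, "Attention"],
    runner_kv_caches: List[torch.Tensor],
) -> None:
    # Fill the runner's flat list in increasing layer-index order, so
    # sparse indices (layers 20 and 28 under PP=2) still map to the
    # consecutive slots 0 and 1 that the tests expect.
    for layer_name in sorted(kv_caches, key=_layer_index):
        runner_kv_caches.append(kv_caches[layer_name])

    # Point each attention module at the tensor allocated for its layer;
    # the tests index kv_cache[0], suggesting a list with one entry per
    # virtual engine.
    for layer_name, kv_cache in kv_caches.items():
        forward_context[layer_name].kv_cache = [kv_cache]
```

Keeping `runner_kv_caches` ordered by layer index lets the model runner keep indexing caches positionally (the i-th local layer reads `runner_kv_caches[i]`) while the engine-level configuration stays keyed by layer name.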