-
-
Notifications
You must be signed in to change notification settings - Fork 5.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[V1] Move more control of kv cache initialization from model_executor to EngineCore #11960
Merged
Merged
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
0827ca8
can run
heheda12345 6025d5e
update comment
heheda12345 6024290
update comment
heheda12345 03130cd
format
heheda12345 1229600
format
heheda12345 e3764d4
bind kv cache to model runner
heheda12345 fec7d2d
determine_available_memory
heheda12345 4294435
update kv_cache_utils
heheda12345 f79dff2
Update vllm/v1/utils.py
heheda12345 97176da
Update vllm/v1/worker/gpu_model_runner.py
heheda12345 e6179a8
add some comments
heheda12345 eb37f0c
add more comments
heheda12345 88fd1b8
format
heheda12345 3493061
update comment
heheda12345 105814a
update comment
heheda12345 9ff57d0
small updates
heheda12345 044876e
update comments and function names
heheda12345 62f2c09
format
heheda12345 138a4ac
format
heheda12345 00f2bda
Merge branch 'main' of github.com:vllm-project/vllm into v1_kv_init
heheda12345 2aa7509
update docstring
heheda12345 e8a1eb0
Merge branch 'main' of github.com:vllm-project/vllm into v1_kv_init
heheda12345 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from typing import List | ||
|
||
import torch | ||
|
||
from vllm.v1.utils import bind_kv_cache | ||
|
||
|
||
def test_bind_kv_cache(): | ||
from vllm.attention import Attention | ||
|
||
ctx = { | ||
'layers.0.self_attn': Attention(32, 128, 0.1), | ||
'layers.1.self_attn': Attention(32, 128, 0.1), | ||
'layers.2.self_attn': Attention(32, 128, 0.1), | ||
'layers.3.self_attn': Attention(32, 128, 0.1), | ||
} | ||
kv_cache = { | ||
'layers.0.self_attn': torch.zeros((1, )), | ||
'layers.1.self_attn': torch.zeros((1, )), | ||
'layers.2.self_attn': torch.zeros((1, )), | ||
'layers.3.self_attn': torch.zeros((1, )), | ||
} | ||
runner_kv_caches: List[torch.Tensor] = [] | ||
bind_kv_cache(kv_cache, ctx, runner_kv_caches) | ||
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[ | ||
'layers.0.self_attn'] | ||
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[ | ||
'layers.1.self_attn'] | ||
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[ | ||
'layers.2.self_attn'] | ||
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[ | ||
'layers.3.self_attn'] | ||
|
||
assert runner_kv_caches[0] is kv_cache['layers.0.self_attn'] | ||
assert runner_kv_caches[1] is kv_cache['layers.1.self_attn'] | ||
assert runner_kv_caches[2] is kv_cache['layers.2.self_attn'] | ||
assert runner_kv_caches[3] is kv_cache['layers.3.self_attn'] | ||
|
||
|
||
def test_bind_kv_cache_non_attention(): | ||
from vllm.attention import Attention | ||
|
||
# example from Jamba PP=2 | ||
ctx = { | ||
'model.layers.20.attn': Attention(32, 128, 0.1), | ||
'model.layers.28.attn': Attention(32, 128, 0.1), | ||
} | ||
kv_cache = { | ||
'model.layers.20.attn': torch.zeros((1, )), | ||
'model.layers.28.attn': torch.zeros((1, )), | ||
} | ||
|
||
runner_kv_caches: List[torch.Tensor] = [] | ||
bind_kv_cache(kv_cache, ctx, runner_kv_caches) | ||
|
||
assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[ | ||
'model.layers.20.attn'] | ||
assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[ | ||
'model.layers.28.attn'] | ||
|
||
assert runner_kv_caches[0] is kv_cache['model.layers.20.attn'] | ||
assert runner_kv_caches[1] is kv_cache['model.layers.28.attn'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is heavy code duplication.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Running the same test as v0, but the detail workflow is very different due to the dict->list and list->dict difference.