Add paged attention support #1355
base: main
Conversation
    v_cache: torch.Tensor
        The value cache tensor containing previous and the current tokens
    """
    k_cache, v_cache = self.cache[layer_number]
Suggested change (guard the lookup before unpacking):

    assert layer_number in self.cache
    k_cache, v_cache = self.cache[layer_number]
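For context, a sketch of how such a per-layer cache might be populated so the assert has something to check. The method name, shapes, and dtype here are illustrative assumptions, not TE's actual API:

    # Illustrative only: one (k_cache, v_cache) pair per transformer layer,
    # preallocated up to max_batch_size x max_seqlen_kv ("bshd" layout assumed).
    def allocate_layer(self, layer_number, num_heads, head_dim, dtype=torch.float16):
        self.cache[layer_number] = (
            torch.zeros(self.max_batch_size, self.max_seqlen_kv, num_heads, head_dim, dtype=dtype),
            torch.zeros(self.max_batch_size, self.max_seqlen_kv, num_heads, head_dim, dtype=dtype),
        )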
    def __init__(
        self,
        max_batch_size: int,
        max_seqlen_kv: int,
The corresponding docstring says max_sequence_length; we should change one of those.
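A minimal sketch of one way to resolve the mismatch, renaming the docstring entry to match the parameter (the docstring wording below is assumed, not copied from the PR):

    def __init__(
        self,
        max_batch_size: int,
        max_seqlen_kv: int,
    ):
        """
        Parameters
        ----------
        max_batch_size: int
            Maximum batch size the KV cache is allocated for.
        max_seqlen_kv: int
            Maximum KV sequence length (previously documented as
            max_sequence_length; renamed here to match the parameter).
        """
        self.max_batch_size = max_batch_size
        self.max_seqlen_kv = max_seqlen_kv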
    seq_s = self.sequences[seq] - step_dict[seq]
    seq_e = self.sequences[seq]
    if qkv_format == "bshd":
        new_k_cache[i, seq_s:seq_e, :, :] = k[i, : step_dict[seq], :, :]
Regarding k[i, : step_dict[seq], :, :]: k isn't supposed to have any tokens beyond step_dict[seq], right?
Same for v.
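For context, a hedged reconstruction of the update loop under discussion; the surrounding batch loop and the padding behavior of k and v are assumptions, not code from the PR:

    # Hypothetical reconstruction: i indexes the batch, seq is the sequence id,
    # step_dict maps sequence id -> number of new tokens in this step.
    for i, seq in enumerate(step_dict):
        seq_s = self.sequences[seq] - step_dict[seq]  # first new token position
        seq_e = self.sequences[seq]                   # one past the last new token
        if qkv_format == "bshd":
            # If k/v are padded along dim 1 to the batch's max step count, the
            # upper bound step_dict[seq] is required; if each k[i] holds exactly
            # step_dict[seq] tokens, it is redundant (the reviewer's point).
            new_k_cache[i, seq_s:seq_e, :, :] = k[i, : step_dict[seq], :, :]
            new_v_cache[i, seq_s:seq_e, :, :] = v[i, : step_dict[seq], :, :]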
    seq_s = self.sequences[seq] - step_dict[seq]
    seq_e = self.sequences[seq]
These could potentially be moved into a method, since the logic could be reused from outside, e.g. when getting the start positions for applying RoPE embeddings.
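A minimal sketch of that refactor; the helper name is hypothetical, and self.sequences is assumed to map a sequence id to its total token count so far:

    def get_step_range(self, seq, step_dict):
        """Return (start, end) cache positions of the current step's tokens.

        Hypothetical helper; could also be reused externally, e.g. to find
        the start position for applying RoPE embeddings.
        """
        seq_e = self.sequences[seq]
        seq_s = seq_e - step_dict[seq]
        return seq_s, seq_e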
Description
This PR adds paged attention support for FusedAttention, FlashAttention, and UnfusedDotProductAttention.
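For readers unfamiliar with the technique: paged attention stores the KV cache in fixed-size blocks ("pages") addressed through a per-sequence page table, so sequences can grow without each reserving one large contiguous buffer. A minimal, library-agnostic sketch of the idea; all names are illustrative and this is not TE's actual API:

    import torch

    class PagedKVCache:
        """Toy paged KV cache: a shared page pool plus per-sequence page tables.

        Illustrative only; Transformer Engine's real implementation differs.
        """

        def __init__(self, num_pages, page_size, num_heads, head_dim, dtype=torch.float16):
            self.page_size = page_size
            # One pool of pages shared by all sequences.
            self.k_pool = torch.zeros(num_pages, page_size, num_heads, head_dim, dtype=dtype)
            self.v_pool = torch.zeros(num_pages, page_size, num_heads, head_dim, dtype=dtype)
            self.free_pages = list(range(num_pages))
            self.page_table = {}  # seq_id -> list of page indices
            self.seq_len = {}     # seq_id -> tokens stored so far

        def append(self, seq_id, k, v):
            """Append k, v of shape [num_tokens, num_heads, head_dim] to a sequence."""
            table = self.page_table.setdefault(seq_id, [])
            pos = self.seq_len.get(seq_id, 0)
            for t in range(k.shape[0]):
                page_idx, slot = divmod(pos + t, self.page_size)
                if page_idx == len(table):  # sequence needs a fresh page
                    table.append(self.free_pages.pop())
                self.k_pool[table[page_idx], slot] = k[t]
                self.v_pool[table[page_idx], slot] = v[t]
            self.seq_len[seq_id] = pos + k.shape[0]

    # Usage: two decode steps for sequence 0, three tokens then one.
    cache = PagedKVCache(num_pages=64, page_size=16, num_heads=8, head_dim=64)
    kv = lambda n: torch.randn(n, 8, 64, dtype=torch.float16)
    cache.append(0, kv(3), kv(3))
    cache.append(0, kv(1), kv(1))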