-
-
Notifications
You must be signed in to change notification settings - Fork 5.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FP8][Kernel] Dynamic kv cache scaling factors computation #11906
[FP8][Kernel] Dynamic kv cache scaling factors computation #11906
Conversation
…ching (#317) * Changed _k_scale and _v_scale to tensors * fixed rocm paged attention with tensor kv scales * Added on the fly scale factor calculation * trying to fix attn metadata * fixed AttentionMetadata issue, updated description for calculate-kv-scales flag in arg_utils.py * Changed K and V scale constants * Removed unneeded comment * Changes to pass format.sh, also fixed lingering k_scale/v_scale : float * Fix for TP > 1 * Ran format.sh * Removed legacy kv_scale loading from the json file * Removed the outdated kv cache docs * Revert some unwanted changes --------- Co-authored-by: Gregory Shtrasberg <[email protected]> Signed-off-by: Gregory Shtrasberg <[email protected]>
* Using tensors in the explicit cache function calls from mllama implementation * Properly creating the tensor Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
🚀 |
This pull request has merge conflicts that must be resolved before it can be |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is a great feature to add (and removal of unused pathway). I do think it is a bit opinionated to only calculate the dynamic scales on the first inference - at the least we should document the recommended setup to properly prime the scales. Possibly we could have a mode where we always use dynamic scaling or require N tokens seen before stopping calibration.
The choice of when each attention backend is passing in hardcoded enable_kv_scales_calculation=True
vs enable_kv_scales_calculation=False
seems very unclear at a glance if you could add comments for that. We also should keep the error checking for backends that don't support quantization at all.
# Test FP16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json. | ||
("fp8_e4m3", "meta-llama/Llama-2-7b-chat-hf", | ||
"meta-llama/Llama-2-7b-chat-hf", | ||
"./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json") | ||
"meta-llama/Llama-2-7b-chat-hf") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this case can be removed
@@ -181,6 +182,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: | |||
num_decode_tokens=self.num_decode_tokens, | |||
slot_mapping=self.slot_mapping[self.num_prefill_tokens:], | |||
multi_modal_placeholder_index_maps=None, | |||
enable_kv_scales_calculation=True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this hardcoded to True?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This metadata field is used to disable the calculation on profile and graph capture stages (assuming it is already globally enabled), to not prime the scales with dummy data. So it's on by default (on platforms that support it), and is switched off during these stages.
It does not control the feature at large, for that there is a config parameter
This pull request has merge conflicts that must be resolved before it can be |
Signed-off-by: Gregory Shtrasberg <[email protected]>
Sorry for enabling the tests so late, but it seems there are several valid errors. Could you please look into updating the tests as well? |
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Head branch was pushed to by a user without write access
…ds that don't support tensors (Flashinfer), since on CUDA during graph capturing phase referencing tensor values is impossible Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
This pull request has merge conflicts that must be resolved before it can be |
This pull request has merge conflicts that must be resolved before it can be |
Signed-off-by: Gregory Shtrasberg <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM if green, thanks for dealing with all the conflicts!
This PR deprecates loading kv cache scales from json in favor of adding the option to dynamically compute them based on the first real input to the attention layer.
Our tests showed that the dynamic range computed based on the first input to each layer is representative of the entire model, and the accuracy is comparable with scaling factors computed using Quark quantizer (such as in HF amd/*-FP8-KV models)
Accuracy measured using the P3L benchmark that allows measuring accuracy on decode steps, using the data in the kv cache
K and V scale parameters are made on-device tensors in order to allow changing their values after the graph has been captured. This also lays the foundation to using per-channel quantization with tensor-like scales.
The effect is most visible on models with dynamic value ranges outside of the scope of fp8e4m3, such as Quen2 7B:
Using dynamic calculation reduces the PPL score from 34.84 to 22.62
On LLama based models the improvement is much smaller, due to the fact that identity scales work just as well, but still can be in single digit percents, on par with using the scales from a quantized model