Support logit_bias in v1 Sampler #13079
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset of tests runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run full CI to test the changes comprehensively before merging. 🚀
Thanks for the PR! Left some minor comments.
```python
for token_id, bias in logit_bias.items():
    logits[i, token_id] += bias
```
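For context, a minimal self-contained sketch of what this per-request loop does (toy shapes and a made-up `logit_bias_per_req` list, not the actual vLLM sampler code):

```python
import torch

# Toy batch: logits of shape [num_reqs, vocab_size] and one optional
# {token_id: bias} dict per request.
logits = torch.zeros(2, 10)
logit_bias_per_req = [{3: 5.0, 4: -5.0}, None]

for i, logit_bias in enumerate(logit_bias_per_req):
    if not logit_bias:
        continue
    # Python-level loop over every (token_id, bias) pair; correct but slow
    # when many requests carry large bias dicts.
    for token_id, bias in logit_bias.items():
        logits[i, token_id] += bias
```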
Can we add a comment like `TODO: This is extremely slow. Optimize this.`?
@njhill @robertgshaw2-redhat Although this implementation is a bit slow, I'm comfortable merging the PR since I haven't found a way to optimize it yet, and getting the functionality in is our top priority. What do you think?
Yeah, I agree. We could write a custom kernel in C++ to handle this. We could also change the representation of logit_bias from a dict to key-value pairs.
I can create some TODOs as a follow-up.
@houseroad Sounds good. Could you please update the PR?
@WoosukKwon I think that's fine since we need the functionality ASAP. But I think it should be simple to vectorize this without any custom kernel.
We just need to maintain three one-dimensional tensors of the same length in the batch:
- all the bias values concatenated (b)
- corresponding request indices (s)
- corresponding token ids (t)
We only need to update these when any requests with logit bias are added or removed from the batch.
Then we can just do logits[(s, t)] += b
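A rough sketch of that idea (the helper name `build_bias_tensors` and the toy shapes are illustrative, not vLLM code):

```python
import torch

def build_bias_tensors(logit_bias_per_req, device="cpu"):
    """Flatten per-request {token_id: bias} dicts into three 1-D tensors."""
    req_indices, token_ids, bias_values = [], [], []
    for req_idx, logit_bias in enumerate(logit_bias_per_req):
        if not logit_bias:
            continue
        for token_id, bias in logit_bias.items():
            req_indices.append(req_idx)
            token_ids.append(token_id)
            bias_values.append(bias)
    return (torch.tensor(req_indices, dtype=torch.long, device=device),
            torch.tensor(token_ids, dtype=torch.long, device=device),
            torch.tensor(bias_values, dtype=torch.float32, device=device))

# Rebuilt only when requests with logit_bias are added to / removed from the batch.
logit_bias_per_req = [{5: 2.0, 7: -1.0}, {}, {3: 0.5}]
s, t, b = build_bias_tensors(logit_bias_per_req)

logits = torch.zeros(3, 10)
logits[s, t] += b  # one vectorized update for the whole batch
```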
Updated the PR. Please let me know if you want me to jump directly to the optimized solution.
@njhill @houseroad Considering that the code is pretty isolated, I think we can merge the PR first and have a follow-up PR to optimize it.
@njhill Re your idea: each request's logit_bias has a different length. How do we handle that (with the persistent batch)?
@WoosukKwon I could be missing something... let's merge this and then I can open another PR :)
I think we can use a ragged-format representation with 3 tensors: tensor a for the length/offset of each batch entry, tensor b for token ids, and tensor c for biases. Then one torch op takes these 3 tensors as inputs, and we leverage C++ logic to handle it. It should be much faster than the current logic.
Maybe in the SamplingMeta or param, we should just preprocess things like this.
The additional overhead is that once we update the batch, we may need to generate new tensors, which should be acceptable.
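A possible sketch of that ragged (offset-based) layout in plain PyTorch, just to illustrate the data model such a C++ op could consume; names and shapes here are assumptions, not the actual follow-up implementation:

```python
import torch

# Per-request {token_id: bias} dicts, as they arrive in sampling params.
logit_bias_per_req = [{5: 2.0, 7: -1.0}, {}, {3: 0.5}]

# Ragged/CSR-style representation: offsets[i]:offsets[i+1] selects the
# token_ids/bias slice belonging to request i.
lengths = torch.tensor([len(d) for d in logit_bias_per_req])
offsets = torch.cat([torch.zeros(1, dtype=torch.long), lengths.cumsum(0)])
token_ids = torch.tensor([t for d in logit_bias_per_req for t in d])
bias = torch.tensor([v for d in logit_bias_per_req for v in d.values()])

# A custom C++/CUDA op could take (offsets, token_ids, bias) directly; in
# pure PyTorch the same data can be applied by expanding request indices.
req_idx = torch.repeat_interleave(torch.arange(len(lengths)), lengths)
logits = torch.zeros(len(lengths), 10)
logits[req_idx, token_ids] += bias
```

As noted above, these tensors only need to be regenerated when the batch composition changes, which keeps the steady-state cost low.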
LGTM. Thanks for the PR!
@houseroad Seems like we need to fix a unit test: https://buildkite.com/vllm/ci/builds/13363#0194fe70-1618-4cb6-ae2d-8380ee8a2041 😓
Introduce logit_bias support in v1 Sampler.
Tested with new test cases:
pytest tests/v1/sample/test_sampler.py -k "test_sampler_logit_bias"
pytest tests/v1/worker/test_gpu_input_batch.py