
Support logit_bias in v1 Sampler #13079

Merged: 5 commits merged into vllm-project:main on Feb 14, 2025

Conversation

@houseroad (Contributor) commented Feb 11, 2025

Introduce logit_bias support in v1 Sampler.

Tested with new test cases:
pytest tests/v1/sample/test_sampler.py -k "test_sampler_logit_bias"
pytest tests/v1/worker/test_gpu_input_batch.py
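
For context: logit_bias maps token ids to additive adjustments that are applied to the logits before sampling. A tiny standalone illustration of the semantics (not code from this PR; the values are made up):

import torch

# Toy example: bias token 42 up and token 7 down for a single request.
logit_bias = {42: 5.0, 7: -100.0}

logits = torch.zeros(1, 128)               # [num_requests, vocab_size]
for token_id, bias in logit_bias.items():
    logits[0, token_id] += bias

probs = torch.softmax(logits, dim=-1)
print(probs[0, 42], probs[0, 7])           # token 42 dominates, token 7 is effectively masked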

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@WoosukKwon (Collaborator) left a comment

Thanks for the PR! Left some minor comments.

vllm/v1/worker/gpu_input_batch.py (outdated review comment, resolved)
vllm/v1/sample/sampler.py (outdated review comment, resolved)
Comment on lines +177 to +181
for token_id, bias in logit_bias.items():
    logits[i, token_id] += bias

Collaborator

Can we add a comment like TODO: This is extremely slow. Optimize this.?

Collaborator

@njhill @robertgshaw2-redhat Although this implementation is a bit slow, I'm comfortable merging the PR since I haven't found a way to optimize it yet, and getting the functionality in is our top priority. What do you think?

Contributor Author

Yeah, I agree. We could write a custom C++ kernel to handle such things, and we could also change the representation of logit_bias from a dict to key-value pairs.

I can add a TODO as a follow-up.

Collaborator

@houseroad Sounds good. Could you please update the PR?

@njhill (Member) commented Feb 12, 2025

@WoosukKwon I think that's fine since we need the functionality asap. But I think it should be simple to vectorize this without any custom kernel.

We just need to maintain three one-dimensional tensors of the same length in the batch:

  • all the bias values concatenated (b)
  • corresponding request indices (s)
  • corresponding token ids (t)

We only need to update these when requests with logit bias are added to or removed from the batch.

Then we can just do logits[(s, t)] += b
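
A minimal sketch of how that vectorized form could look (illustrative only; the b, s, t names follow the comment above, and everything else is an assumption, not code from this PR):

import torch

num_reqs, vocab_size = 4, 32
logits = torch.zeros(num_reqs, vocab_size)

# Hypothetical per-request logit_bias dicts, known when requests join the batch.
logit_bias_per_req = {0: {3: 2.0, 5: -1.0}, 2: {7: 4.0}}

# Built once per batch update (request added/removed), not on every sampling step.
s = torch.tensor([req for req, d in logit_bias_per_req.items() for _ in d])           # request indices
t = torch.tensor([tok for d in logit_bias_per_req.values() for tok in d])             # token ids
b = torch.tensor([bias for d in logit_bias_per_req.values() for bias in d.values()])  # bias values

# One indexed add replaces the Python loop over every (token_id, bias) pair.
logits[s, t] += b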

Contributor Author

Updated the PR. Please let me know if you want me to jump directly to the optimized solution.

Collaborator

@njhill @houseroad Considering that the code is pretty isolated, I think we can merge the PR first and have a follow-up PR to optimize it.

Collaborator

@njhill Re your idea: each request's logit_bias has a different length. How do we handle that (with the persistent batch)?

Member

@WoosukKwon I could be missing something... let's merge this and then I can open another PR :)

Contributor Author

I think we can use a ragged-format representation with 3 tensors: tensor a for the length/offset of each batch entry, tensor b for the token ids, and tensor c for the biases. A single torch op would take these 3 tensors as inputs, and we could leverage C++ logic to handle it. It should be much faster than the current logic.

Maybe we should just preprocess things like this in the SamplingMeta or params.

The additional overhead is that once we update the batch, we may need to generate new tensors, which should be acceptable.
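
A rough sketch of that ragged layout (hypothetical names, shapes, and values; the eventual follow-up may differ, e.g. by moving the expansion into a C++ op):

import torch

# Batch of 3 requests: request 0 biases tokens {3: +2.0, 5: -1.0},
# request 1 has no logit_bias, request 2 biases token {7: +4.0}.
offsets = torch.tensor([0, 2, 2, 3])        # tensor a: per-request offsets into the flat arrays
token_ids = torch.tensor([3, 5, 7])         # tensor b: biased token ids, concatenated
biases = torch.tensor([2.0, -1.0, 4.0])     # tensor c: bias values, concatenated

logits = torch.zeros(3, 32)

# Expand the offsets into one request index per (token_id, bias) entry,
# then apply everything with a single indexed add.
counts = offsets[1:] - offsets[:-1]
req_idx = torch.repeat_interleave(torch.arange(counts.numel()), counts)
logits[req_idx, token_ids] += biases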

@houseroad force-pushed the v1_logit_bias branch 2 times, most recently from de10bc2 to 05acc1f on February 12, 2025 08:42

@WoosukKwon (Collaborator) left a comment

LGTM. Thanks for the PR!

@WoosukKwon added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Feb 13, 2025

@WoosukKwon (Collaborator)

@houseroad Seems like we need to fix a unit test: https://buildkite.com/vllm/ci/builds/13363#0194fe70-1618-4cb6-ae2d-8380ee8a2041 😓

Signed-off-by: Lu Fang <[email protected]>
@DarkLight1337 enabled auto-merge (squash) on February 14, 2025 12:34
@simon-mo merged commit 6224a9f into vllm-project:main on Feb 14, 2025
31 of 33 checks passed
Labels: ready (ONLY add when PR is ready to merge/full CI is needed), v1

4 participants