
[Encoder Decoder] Update Mllama to run with both FlashAttention and XFormers #9982

Merged
79 commits merged on Nov 12, 2024

Conversation

sroy745 (Collaborator) commented Nov 4, 2024

In this PR we update Mllama to run with both the xFormers and FlashAttention backends. Currently it runs only with the xFormers backend. This PR makes the following changes:

  1. Updates mllama.py to run with both xFormers and FlashAttention. Two changes were needed for this: (a) update attention_with_mask to update the KV cache appropriately depending on the backend being used, and (b) update the shape of the query when computing the attention (see the sketch after this list). Currently it works with xFormers only because the xFormers backend does not enforce the query to be of shape [num_tokens, hidden_size].
  2. Updates enc_dec_model_runner.py to no longer force the backend to xFormers for Mllama.
  3. Updates the test test_mllama.py to run with both xFormers and FlashAttention.
  4. Updates test_e2e_correctness.py to clear the backend cache at the beginning of each test run. Without this, the cached backend value gets reused across tests.
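For change 1(b), here is a minimal, self-contained sketch of the query-shape point (plain PyTorch with hypothetical shapes, not the actual mllama.py diff): the FlashAttention backend enforces a flattened [num_tokens, hidden_size] query, while the per-head view used inside the masked cross-attention path is [num_tokens, num_heads, head_size].

```python
# Minimal sketch only (plain PyTorch, hypothetical shapes), illustrating
# change 1(b): FlashAttention enforces a flattened [num_tokens, hidden_size]
# query, which the xFormers backend happened to tolerate in other shapes.
import torch

num_tokens, num_heads, head_size = 8, 32, 128

# Per-head view used while building the masked cross-attention.
q = torch.randn(num_tokens, num_heads, head_size)

# Flatten back to [num_tokens, hidden_size] before handing the query to the
# backend-agnostic attention call, so both backends see the shape they expect.
q_flat = q.view(num_tokens, num_heads * head_size)
assert q_flat.shape == (num_tokens, num_heads * head_size)
```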

sroy745 added 30 commits May 28, 2024 20:39
sroy745 (Collaborator, Author) commented Nov 6, 2024

Thanks @heheda12345 for pointing out the issue.

To summarize: if we run python3 examples/test_mllama.py on an H100, the output of the FlashAttention backend starts diverging from that of the xFormers backend, and we want to debug the reason for this.

I added some logging in mllama.py to print out the logits and debug this further. The examples/test_mllama.py script is also included in the PR. I will remove both once the debugging is done.
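The temporary logging is roughly of this form (a sketch of the idea; the exact lines added to mllama.py may differ): sort the logits so the top next-token candidates are easy to compare between the two backends.

```python
# Sketch of the kind of temporary debug logging described above (the exact
# lines in mllama.py may differ): sort the logits descending and log the top
# candidates alongside the sampled tokens.
import torch
from vllm.logger import init_logger

logger = init_logger(__name__)


def log_sampler_debug(logits: torch.Tensor, next_tokens) -> None:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    logger.info("sorted_logits %s", sorted_logits)
    logger.info("sorted_indices %s", sorted_indices)
    logger.info("next_tokens %s", next_tokens)
```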

Output for Flash Attention Run

INFO 11-06 07:03:40 mllama.py:1146] sorted_logits tensor([[ 13.8125,  12.7500,  12.3750,  ...,  -9.3750, -10.1250, -10.3125]],
INFO 11-06 07:03:40 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:03:40 mllama.py:1147] sorted_indices tensor([[  578,  1115,  9062,  ..., 98323, 48046, 89920]], device='cuda:0')
INFO 11-06 07:03:40 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=578, logprobs={578: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
INFO 11-06 07:03:40 mllama.py:1146] sorted_logits tensor([[18.1250, 16.5000, 15.8750,  ..., -8.6250, -8.6875, -9.0000]],
INFO 11-06 07:03:40 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:03:40 mllama.py:1147] sorted_indices tensor([[  2217,   1176,   5448,  ...,  44326, 116655,  90609]],
INFO 11-06 07:03:40 mllama.py:1147]        device='cuda:0')
INFO 11-06 07:03:40 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=2217, logprobs={2217: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
INFO 11-06 07:03:40 mllama.py:1146] sorted_logits tensor([[19.1250, 18.3750, 17.6250,  ..., -7.3125, -7.5000, -7.7812]],
INFO 11-06 07:03:40 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:03:40 mllama.py:1147] sorted_indices tensor([[  5039,  62991,    374,  ..., 108112, 111896, 123635]],
INFO 11-06 07:03:40 mllama.py:1147]        device='cuda:0')
INFO 11-06 07:03:40 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=5039, logprobs={5039: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
INFO 11-06 07:03:40 mllama.py:1146] sorted_logits tensor([[20.0000, 16.6250, 16.0000,  ..., -7.1250, -7.1875, -8.9375]],
INFO 11-06 07:03:40 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:03:40 mllama.py:1147] sorted_indices tensor([[   264,    279,   1403,  ...,  88885, 108602,  64170]],
INFO 11-06 07:03:40 mllama.py:1147]        device='cuda:0')
INFO 11-06 07:03:40 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=264, logprobs={264: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
INFO 11-06 07:03:40 mllama.py:1146] sorted_logits tensor([[18.8750, 16.2500, 15.9375,  ..., -9.0625, -9.1875, -9.8125]],
INFO 11-06 07:03:40 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:03:40 mllama.py:1147] sorted_indices tensor([[  8762,   3345,  40132,  ..., 124479,  82422,  83788]],
INFO 11-06 07:03:40 mllama.py:1147]        device='cuda:0')
INFO 11-06 07:03:40 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=8762, logprobs={8762: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
INFO 11-06 07:03:40 mllama.py:1146] sorted_logits tensor([[ 18.2500,  17.8750,  16.8750,  ...,  -9.1875,  -9.2500, -10.0000]],
INFO 11-06 07:03:40 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:03:40 mllama.py:1147] sorted_indices tensor([[ 34353,  40132,  32498,  ...,   1714, 118633,  63345]],
INFO 11-06 07:03:40 mllama.py:1147]        device='cuda:0')
INFO 11-06 07:03:40 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=34353, logprobs={34353: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
INFO 11-06 07:03:40 mllama.py:1146] sorted_logits tensor([[ 24.5000,  15.3750,  15.1250,  ..., -10.3750, -10.9375, -11.0625]],
INFO 11-06 07:03:40 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:03:40 mllama.py:1147] sorted_indices tensor([[   569,    329,   2402,  ...,  82107, 120381,  80088]],
INFO 11-06 07:03:40 mllama.py:1147]        device='cuda:0')
INFO 11-06 07:03:40 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=569, logprobs={569: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
INFO 11-06 07:03:40 mllama.py:1146] sorted_logits tensor([[19.5000, 16.5000, 16.2500,  ..., -9.5000, -9.5625, -9.8125]],
INFO 11-06 07:03:40 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:03:40 mllama.py:1147] sorted_indices tensor([[37085,   304,    11,  ..., 19811, 20133, 41077]], device='cuda:0')
INFO 11-06 07:03:40 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=37085, logprobs={37085: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
INFO 11-06 07:03:40 mllama.py:1146] sorted_logits tensor([[18.0000, 18.0000, 17.0000,  ..., -8.8750, -9.0625, -9.6875]],
INFO 11-06 07:03:40 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:03:40 mllama.py:1147] sorted_indices tensor([[   304,  24269,     13,  ..., 108602,  62785,  74818]],
INFO 11-06 07:03:40 mllama.py:1147]        device='cuda:0')
INFO 11-06 07:03:40 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=304, logprobs={304: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
Processed prompts: 100%|██████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.13it/s, est. speed input: 13.54 toks/s, output: 10.15 toks/s]
 The image shows a male mallard duck in

Output for xFormers Run

INFO 11-06 07:06:31 mllama.py:1146] sorted_logits tensor([[ 13.8125,  12.6875,  12.4375,  ...,  -9.3125, -10.0625, -10.2500]],
INFO 11-06 07:06:31 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:06:31 mllama.py:1147] sorted_indices tensor([[  578,  1115,  9062,  ..., 98323, 48046, 89920]], device='cuda:0')
INFO 11-06 07:06:31 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=578, logprobs={578: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
INFO 11-06 07:06:31 mllama.py:1146] sorted_logits tensor([[18.1250, 16.5000, 15.8750,  ..., -8.6250, -8.6250, -9.0625]],
INFO 11-06 07:06:31 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:06:31 mllama.py:1147] sorted_indices tensor([[  2217,   1176,   5448,  ...,  44326, 116655,  90609]],
INFO 11-06 07:06:31 mllama.py:1147]        device='cuda:0')
INFO 11-06 07:06:31 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=2217, logprobs={2217: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
INFO 11-06 07:06:31 mllama.py:1146] sorted_logits tensor([[19.1250, 18.3750, 17.6250,  ..., -7.3125, -7.5625, -7.7812]],
INFO 11-06 07:06:31 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:06:31 mllama.py:1147] sorted_indices tensor([[  5039,  62991,    374,  ..., 108112, 111896, 123635]],
INFO 11-06 07:06:31 mllama.py:1147]        device='cuda:0')
INFO 11-06 07:06:31 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=5039, logprobs={5039: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
INFO 11-06 07:06:31 mllama.py:1146] sorted_logits tensor([[20.0000, 16.6250, 16.0000,  ..., -7.1250, -7.1875, -8.9375]],
INFO 11-06 07:06:31 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:06:31 mllama.py:1147] sorted_indices tensor([[   264,    279,   1403,  ...,  88885, 108602,  64170]],
INFO 11-06 07:06:31 mllama.py:1147]        device='cuda:0')
INFO 11-06 07:06:31 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=264, logprobs={264: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
INFO 11-06 07:06:31 mllama.py:1146] sorted_logits tensor([[18.8750, 16.2500, 15.9375,  ..., -9.0625, -9.2500, -9.8125]],
INFO 11-06 07:06:31 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:06:31 mllama.py:1147] sorted_indices tensor([[  8762,   3345,  40132,  ..., 124479,  82422,  83788]],
INFO 11-06 07:06:31 mllama.py:1147]        device='cuda:0')
INFO 11-06 07:06:31 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=8762, logprobs={8762: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
INFO 11-06 07:06:31 mllama.py:1146] sorted_logits tensor([[ 18.2500,  18.0000,  16.7500,  ...,  -9.1875,  -9.3125, -10.0625]],
INFO 11-06 07:06:31 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:06:31 mllama.py:1147] sorted_indices tensor([[ 34353,  40132,  32498,  ...,  64460, 118633,  63345]],
INFO 11-06 07:06:31 mllama.py:1147]        device='cuda:0')
INFO 11-06 07:06:31 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=34353, logprobs={34353: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
INFO 11-06 07:06:31 mllama.py:1146] sorted_logits tensor([[ 24.5000,  15.3125,  15.0625,  ..., -10.3750, -11.0000, -11.0000]],
INFO 11-06 07:06:31 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:06:31 mllama.py:1147] sorted_indices tensor([[   569,    329,   2402,  ...,  82107,  80088, 120381]],
INFO 11-06 07:06:31 mllama.py:1147]        device='cuda:0')
INFO 11-06 07:06:31 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=569, logprobs={569: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
INFO 11-06 07:06:31 mllama.py:1146] sorted_logits tensor([[19.5000, 16.3750, 16.2500,  ..., -9.5000, -9.6250, -9.7500]],
INFO 11-06 07:06:31 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:06:31 mllama.py:1147] sorted_indices tensor([[37085,   304,    11,  ..., 19811, 20133, 41077]], device='cuda:0')
INFO 11-06 07:06:31 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=37085, logprobs={37085: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
INFO 11-06 07:06:31 mllama.py:1146] sorted_logits tensor([[18.1250, 18.0000, 17.0000,  ..., -8.8750, -9.0625, -9.6875]],
INFO 11-06 07:06:31 mllama.py:1146]        device='cuda:0', dtype=torch.bfloat16)
INFO 11-06 07:06:31 mllama.py:1147] sorted_indices tensor([[ 24269,    304,     13,  ..., 108602,  62785,  74818]],
INFO 11-06 07:06:31 mllama.py:1147]        device='cuda:0')
INFO 11-06 07:06:31 mllama.py:1149] next_tokens SamplerOutput(outputs=[CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=0, output_token=24269, logprobs={24269: Logprob(logprob=inf, rank=None, decoded_token=None)})], prompt_logprobs=None)], sampled_token_probs=None, sampled_token_ids=None, spec_decode_worker_metrics=None)
Processed prompts: 100%|███████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.09it/s, est. speed input: 13.10 toks/s, output: 9.82 toks/s]
 The image shows a male mallard duck swimming

The mismatch starts at the last token here ("in" vs. "swimming"). The logits of these two token IDs are very close: in the FlashAttention run both are 18.0, while in the xFormers run they are 18.125 vs. 18.0. I think this small difference in the logits causes the outputs to diverge at this token position and beyond. Overall the logits seem close for the two backends.
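As a standalone illustration of the precision point (not code from this PR): near 18 the bfloat16 grid spacing is 0.125, so two logits that differ by less than one step can collapse onto the same grid point under one backend's accumulation order but not the other's, and greedy decoding then picks a different token.

```python
# Standalone illustration, not from this PR: bfloat16 spacing near 18 is
# 0.125, so logits closer together than one step round to the same value,
# and a tiny backend-level numeric difference decides whether the top-2
# logits tie or not.
import torch

x = torch.tensor([18.00, 18.05, 18.07, 18.125], dtype=torch.float32)
print(x.to(torch.bfloat16))
# -> tensor([18.0000, 18.0000, 18.1250, 18.1250], dtype=torch.bfloat16)
```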

heheda12345 (Collaborator) left a comment


Thanks for your great work and the deep investigation into the output difference. Given that the logprobs of the two tokens are similar, I believe the difference is caused by a precision issue rather than a bug, and we can continue with this PR. Added some small suggestions.

Review threads (outdated, resolved):
tests/encoder_decoder/test_e2e_correctness.py
vllm/model_executor/models/mllama.py
vllm/worker/enc_dec_model_runner.py
sroy745 (Collaborator, Author) commented Nov 8, 2024

Thanks @heheda12345 for the review. Addressed your comments. PTAL.

heheda12345 (Collaborator) left a comment


LGTM. Thanks for the great work!
CC @ywang96

mergify bot commented Nov 9, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sroy745.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 9, 2024
ywang96 (Member) left a comment


@sroy745 LGTM! Could you please rebase this PR for the merge? Thanks

@mergify mergify bot removed the needs-rebase label Nov 11, 2024
@ywang96 ywang96 added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 12, 2024
sroy745 (Collaborator, Author) commented Nov 12, 2024

@heheda12345 / @ywang96 I resynced to head and the tests are passing. One thing to note: I had to update tests/test_config.py to skip test_is_encoder_decoder on ROCm, because the test now fails when it runs for "meta-llama/Llama-3.2-11B-Vision". The import of mllama.py fails when it tries to import the newly added FlashAttentionMetadata and XFormersMetadata. Skipping this should be fine, right, since encoder-decoder models are not supported on ROCm? PTAL
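For context, the skip is along these lines (a sketch; the exact condition and test body in tests/test_config.py may differ):

```python
# Hedged sketch of the ROCm skip described above; the actual condition and
# test body in tests/test_config.py may differ.
import pytest
import torch

# torch.version.hip is a version string on ROCm builds and None otherwise.
on_rocm = torch.version.hip is not None


@pytest.mark.skipif(
    on_rocm, reason="Encoder-decoder models are not supported on ROCm")
def test_is_encoder_decoder():
    ...
```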

@ywang96 ywang96 merged commit b41fb9d into vllm-project:main Nov 12, 2024
48 of 49 checks passed
dsikka pushed a commit to neuralmagic/vllm that referenced this pull request Nov 13, 2024
rickyyx pushed a commit to rickyyx/vllm that referenced this pull request Nov 13, 2024
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Labels
ready ONLY add when PR is ready to merge/full CI is needed
3 participants