
Fix stride issues in flash_attn_interface #58

Closed · wants to merge 2 commits

Conversation

@clintg6 commented May 31, 2024

What:

Ensures tensors are contiguous in memory with matching strides during the backward pass.

Fixes #40

Why:

Multiple users have reported failures while training with the axolotl package; see #40.

`maybe_contiguous` does not adequately verify that tensors are contiguous in certain packing scenarios, so tensors with mismatched strides can reach the backward kernel.
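
For context, here is a minimal sketch of the helper as it appears in `flash_attn_interface.py` (paraphrased; the upstream definition may differ slightly). It only forces a copy when the last dimension is not unit-stride, so two tensors can each pass the check while still having different overall strides:

```python
def maybe_contiguous(x):
    # Copies only when the last dim is not unit-stride. A tensor whose
    # last dim is already unit-stride is returned as-is, even if its
    # other strides differ from a sibling tensor's (e.g. after packing).
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x
```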

Changes:

  • flash_attn_interface.py
    • _flash_attn_backward: add a stride check and a contiguity check (see the sketch after this list)
    • _flash_attn_varlen_backward: add a contiguity check
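
A minimal sketch of the kind of guard these changes add in the backward pass (illustrative only, not the exact diff; `dout` and `out` follow the names used in the FlashAttention interface):

```python
import torch

def _ensure_matching_strides(dout: torch.Tensor, out: torch.Tensor):
    # The backward kernel requires dout and out to share the same
    # sequence stride ("Expected dout_seq_stride == out_seq_stride");
    # forcing both contiguous guarantees their strides match.
    if dout.stride() != out.stride():
        dout, out = dout.contiguous(), out.contiguous()
    return dout, out
```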

Testing:

  • Manual testing: successfully fine-tuned Phi-2, StableLM, and TinyLlama in axolotl
  • Automated testing: benchmarking scripts ran as expected

@clintg6 self-assigned this May 31, 2024
@micmelesse (Collaborator) commented:

Should we close this PR? CK is upstreamed, and any changes should go upstream, right?

@clintg6 (Author) commented Jan 30, 2025

Yes, closing now.

@clintg6 closed this Jan 30, 2025
Linked issue: [Issue]: Expected dout_seq_stride == out_seq_stride to be true, but got false (#40)