Add support for Flash attention (#725)
* Add support for Flash attention
* Fix attention type can be both sparse and flash
* Updates from running pre-commit on modified files
* Update README.md

Co-authored-by: Stella Biderman <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
Commit efd5911, 1 parent: 589c70a
Showing 7 changed files with 273 additions and 17 deletions.
@@ -0,0 +1,185 @@
# Based on: https://github.com/HazyResearch/flash-attention/blob/4a6eaa9f27df6fff7ffb2c24e894938a687dd870/flash_attn/flash_attn_interface.py

import torch
import torch.nn as nn
import torch.nn.functional as F

import flash_attn_cuda


def _flash_attn_forward(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    return_softmax,
    num_splits=0,
    generator=None,
):
    """
    num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
    it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking.
    Don't change it unless you know what you're doing.
    """
    softmax_lse, *rest = flash_attn_cuda.fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        return_softmax,
        num_splits,
        generator,
    )
    # if out.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    S_dmask = rest[0] if return_softmax else None
    return out, softmax_lse, S_dmask


def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    num_splits=0,
    generator=None,
):
    """
    num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
    not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
    Any value above 1 will call the same kernel (i.e. num_splits=2 would call the same kernel
    as num_splits=3), so effectively the choices are 0, 1, and 2.
    This hyperparameter can be tuned for performance, but the default value (heuristic) should work fine.
    """
    _, _, _, softmax_d = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        num_splits,
        generator,
    )
    # if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dq, dk, dv, softmax_d


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        return_softmax,
    ):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, softmax_lse, S_dmask = _flash_attn_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            torch.empty_like(qkv[:, 0]),
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax,
        )
        ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        dqkv = torch.empty_like(qkv)
        _flash_attn_backward(
            dout,
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None


def flash_attn_unpadded_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    return FlashAttnQKVPackedFunc.apply(
        qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs
    )
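For reference, a minimal usage sketch of the exported flash_attn_unpadded_qkvpacked_func is shown below. The import path is hypothetical (the diff does not show where this file sits in the package), and the example assumes a CUDA GPU with the flash-attn extension built, fp16 inputs, and the packed-QKV layout flash-attn expects: qkv of shape (total_tokens, 3, num_heads, head_dim), with cu_seqlens holding int32 cumulative sequence boundaries.

import torch

# Hypothetical import path; adjust to wherever the file above lives in the package.
from flash_attn_interface import flash_attn_unpadded_qkvpacked_func

# Two variable-length sequences (lengths 3 and 5) packed into one 8-token batch.
num_heads, head_dim = 8, 64
seqlens = [3, 5]
total_tokens = sum(seqlens)

qkv = torch.randn(
    total_tokens, 3, num_heads, head_dim, dtype=torch.float16, device="cuda"
)
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")  # cumulative boundaries
max_seqlen = max(seqlens)

out = flash_attn_unpadded_qkvpacked_func(
    qkv, cu_seqlens, max_seqlen, dropout_p=0.0, causal=True
)
print(out.shape)  # (total_tokens, num_heads, head_dim)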
@@ -34,6 +34,7 @@
    "bslongformer",
    "gmlp",
    "amlp",
    "flash",
]
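The "flash" entry added here extends the set of attention types a model configuration may request per layer. Purely as an illustration (the helper name below is an assumption and the choices list is truncated to the entries visible in this hunk), such a list is typically used to validate per-layer attention settings:

# Illustrative sketch only; the real choices list contains more entries than shown here.
ATTENTION_TYPE_CHOICES = ["bslongformer", "gmlp", "amlp", "flash"]


def check_attention_types(per_layer_attention):
    """Reject any layer whose requested attention type is not a recognized choice."""
    for layer_idx, attn_type in enumerate(per_layer_attention):
        if attn_type not in ATTENTION_TYPE_CHOICES:
            raise ValueError(f"layer {layer_idx}: unknown attention type {attn_type!r}")


check_attention_types(["flash"] * 4)  # e.g. a 4-layer model using Flash attention everywhere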
@@ -0,0 +1 @@
flash-attn==0.2.2
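To confirm the pinned dependency resolved in a given environment, a stdlib-only check (no GPU required) is enough; the expected output follows the pin above:

from importlib.metadata import version

print(version("flash-attn"))  # expected: 0.2.2, per the requirements pin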