Skip to content

Commit

Permalink
add type info for backward
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Feb 6, 2025
1 parent 156a9bc commit e6a67b3
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions flash_attn/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,11 @@ def _flash_attn_backward(
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None,
descale_do=None
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
descale_v: Optional[torch.Tensor] = None,
descale_p: Optional[torch.Tensor] = None,
descale_do: Optional[torch.Tensor] = None
) -> torch.Tensor:
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
Expand Down Expand Up @@ -379,11 +379,11 @@ def _flash_attn_varlen_backward(
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
descale_q=None,
descale_k=None,
descale_v=None,
descale_p=None,
descale_do=None
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
descale_v: Optional[torch.Tensor] = None,
descale_p: Optional[torch.Tensor] = None,
descale_do: Optional[torch.Tensor] = None
) -> torch.Tensor:
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
Expand Down

0 comments on commit e6a67b3

Please sign in to comment.