update third-party packages
flymin committed Jan 25, 2024
1 parent fabe2ce commit f667df3
Showing 4,990 changed files with 1,021,147 additions and 23 deletions.
The diff you're trying to view is too large. We only load the first 3000 changed files.
1 change: 1 addition & 0 deletions third_party/diffusers/src/diffusers/models/attention.py
@@ -63,6 +63,7 @@ def __init__(
final_dropout: bool = False,
):
super().__init__()
self._args = {k: v for k, v in locals().items() if k != "self" and not k.startswith("_")}
self.only_cross_attention = only_cross_attention

self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
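The added line snapshots the constructor arguments by filtering locals(). The idiom in isolation, on a hypothetical class that is not part of the commit:

class Block:
    def __init__(self, dim=320, heads=8, dropout=0.0):
        # capture every argument except `self` and underscore-prefixed names
        self._args = {k: v for k, v in locals().items()
                      if k != "self" and not k.startswith("_")}

print(Block()._args)  # {'dim': 320, 'heads': 8, 'dropout': 0.0}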
154 changes: 140 additions & 14 deletions third_party/diffusers/src/diffusers/models/attention_processor.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union
import math
import os

import torch
import torch.nn.functional as F
Expand All @@ -27,6 +29,9 @@
if is_xformers_available():
import xformers
import xformers.ops
CHECK_XFORMERS = int(os.getenv("CHECK_XFORMERS", "0")) == 1
SPLIT_SIZE = int(os.getenv("SPLIT_SIZE", -1))
ERROR_TOLERANCE = 0.002
else:
xformers = None
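A minimal usage sketch for these switches, assuming xformers is installed and the module is importable as in upstream diffusers; the import path, sizes, and values below are illustrative, not part of the commit. The variables are read at import time, so they must be set before the module is loaded:

import os

os.environ["CHECK_XFORMERS"] = "1"  # enable the xformers-vs-standard consistency check
os.environ["SPLIT_SIZE"] = "32"     # chunk large batches in the xformers processor

from diffusers.models.attention_processor import Attention  # assumed import path

attn = Attention(query_dim=320, heads=8, dim_head=40)
attn.set_use_memory_efficient_attention_xformers(True)
attn.set_check_xformers()  # arm the per-forward check added in this commit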

@@ -160,6 +165,13 @@ def __init__(
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
)
self.set_processor(processor)
self._check_xformers = False

def set_check_xformers(self):
if CHECK_XFORMERS:
self._check_xformers = True
else:
self._check_xformers = False

def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
@@ -317,13 +329,55 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
# The `Attention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
run_check = False
if self._check_xformers:
use_xformers = hasattr(self, "processor") and isinstance(
self.processor, (
LoRAXFormersAttnProcessor,
CustomDiffusionXFormersAttnProcessor,
XFormersAttnAddedKVProcessor,
XFormersAttnProcessor,
)
)
if use_xformers:
run_check = True
else:
logger.warning("You set `CHECK_XFORMERS=1` but this attention module "
"does not use an xformers processor.")
if run_check:
xformers_out = self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
self.set_use_memory_efficient_attention_xformers(False)
with torch.no_grad():
normal_out = self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
abs_error = (xformers_out - normal_out).abs()
if (abs_error > ERROR_TOLERANCE).any():
msg = f"xformers error, max abs = {abs_error.max():.2e} "
msg += f"h size: {hidden_states.shape} "
if encoder_hidden_states is not None:
msg += f"e size: {encoder_hidden_states.shape}"
logger.error(msg)
self.set_use_memory_efficient_attention_xformers(True)
return xformers_out
else:
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)

def batch_to_head_dim(self, tensor):
head_size = self.heads
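The check added to forward() above runs the same inputs through the xformers processor and, with xformers temporarily disabled, through the standard processor, then compares the outputs against ERROR_TOLERANCE. The comparison idea in isolation, on stand-in tensors (hypothetical shapes; the hand-written attention stands in for the second path and assumes PyTorch >= 2.0 for F.scaled_dot_product_attention):

import torch
import torch.nn.functional as F

ERROR_TOLERANCE = 0.002  # same tolerance as in the commit
q = torch.randn(2, 8, 77, 40)  # (batch, heads, seq_len, head_dim), made-up sizes
k, v = torch.randn_like(q), torch.randn_like(q)

out_a = F.scaled_dot_product_attention(q, k, v)
# second path: the same attention written out by hand
weights = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
out_b = weights @ v

abs_error = (out_a - out_b).abs()
if (abs_error > ERROR_TOLERANCE).any():
    print(f"deviation detected, max abs = {abs_error.max():.2e}")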
@@ -998,6 +1052,38 @@ def __call__(
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
):
actual_size = hidden_states.shape[0]
if SPLIT_SIZE != -1 and actual_size > SPLIT_SIZE:
split_steps = math.ceil(actual_size / SPLIT_SIZE)
split_steps = min(split_steps, actual_size)
hidden_states_out = []
_hidden_states = hidden_states.chunk(split_steps)
if encoder_hidden_states is None:
_encoder_hidden_states = [None] * split_steps
else:
_encoder_hidden_states = encoder_hidden_states.chunk(
split_steps)
assert attention_mask is None
assert temb is None
for i in range(split_steps):
hidden_states_out.append(
self._real_call(
attn, _hidden_states[i], _encoder_hidden_states[i],
attention_mask, temb)
)
return torch.cat(hidden_states_out, dim=0)
else:
return self._real_call(
attn, hidden_states, encoder_hidden_states, attention_mask,
temb)

def _real_call(self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
):
residual = hidden_states
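The wrapper above splits an oversized batch into math.ceil(actual_size / SPLIT_SIZE) chunks, runs each chunk through _real_call, and concatenates the results along the batch dimension. The same chunk-and-concat pattern in isolation, with a hypothetical process() standing in for _real_call:

import math
import torch

SPLIT_SIZE = 4  # stand-in for the env-configured value

def process(x):  # hypothetical per-chunk computation
    return x * 2

def chunked_call(hidden_states):
    actual_size = hidden_states.shape[0]
    if SPLIT_SIZE != -1 and actual_size > SPLIT_SIZE:
        split_steps = min(math.ceil(actual_size / SPLIT_SIZE), actual_size)
        outs = [process(chunk) for chunk in hidden_states.chunk(split_steps)]
        return torch.cat(outs, dim=0)
    return process(hidden_states)

out = chunked_call(torch.randn(10, 77, 320))  # 10 > 4, so three chunks of 4, 4, 2
assert out.shape == (10, 77, 320)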

@@ -1038,15 +1124,55 @@ def __call__(
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
# NOTE: xformers induces large errors when the batch size is large, so we do
# not fold heads into the batch dimension, and we split the batch if necessary.

def _split_head(tensor):
head_size = attn.heads
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
return tensor

def _back_head(tensor):
batch_size, seq_len, head_size, dim = tensor.shape
tensor = tensor.reshape(batch_size, seq_len, head_size * dim)
return tensor

# query = attn.head_to_batch_dim(query).contiguous()
# key = attn.head_to_batch_dim(key).contiguous()
# value = attn.head_to_batch_dim(value).contiguous()
query = _split_head(query)
key = _split_head(key)
value = _split_head(value)

if attention_mask is not None:
# from cutlassF
# HINT: To use an `attn_bias` with a sequence length that is not a
# multiple of 8, you need to ensure memory is aligned by slicing a
# bigger tensor.
# Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]`
# instead of `torch.zeros([1, 1, 5, 5])`
b, l1, l2 = attention_mask.shape
if attention_mask.stride(-2) % 8 != 0:
l1_align = (l1 // 8 + 1) * 8
l2_align = (l2 // 8 + 1) * 8
attention_mask_align = torch.zeros(
(b, l1_align, l2_align), dtype=attention_mask.dtype,
device=attention_mask.device)
attention_mask_align[:, :l1, :l2] = attention_mask
attention_mask = attention_mask_align

hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask[:, :l1, :l2],
op=self.attention_op, scale=attn.scale)
else:
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask,
op=self.attention_op, scale=attn.scale)

hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# hidden_states = attn.batch_to_head_dim(_hidden_states)
hidden_states = _back_head(hidden_states)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
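The padding above follows the cutlassF hint quoted in the comment: the second-to-last stride of attn_bias must be a multiple of 8, so a larger zero tensor is allocated and a view of the original size is sliced from it. The padding step in isolation, with made-up sizes and no xformers call:

import torch

b, l1, l2 = 2, 77, 77  # made-up batch size and sequence lengths
attention_mask = torch.zeros(b, l1, l2)  # stride(-2) is 77, not a multiple of 8

if attention_mask.stride(-2) % 8 != 0:
    l1_align = (l1 // 8 + 1) * 8
    l2_align = (l2 // 8 + 1) * 8
    padded = torch.zeros(b, l1_align, l2_align)
    padded[:, :l1, :l2] = attention_mask
    attention_mask = padded[:, :l1, :l2]  # same shape, but stride(-2) is now 80

assert attention_mask.stride(-2) % 8 == 0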
3 changes: 2 additions & 1 deletion third_party/diffusers/src/diffusers/utils/import_utils.py
@@ -20,6 +20,7 @@
import sys
from collections import OrderedDict
from typing import Union
import warnings

from huggingface_hub.utils import is_jinja_available # noqa: F401
from packaging import version
@@ -222,7 +223,7 @@
import torch

if version.Version(torch.__version__) < version.Version("1.12"):
raise ValueError("PyTorch should be >= 1.12")
warnings.warn("PyTorch should be >= 1.12")
logger.debug(f"Successfully imported xformers version {_xformers_version}")
except importlib_metadata.PackageNotFoundError:
_xformers_available = False
2 changes: 1 addition & 1 deletion third_party/xformers/requirements.txt
@@ -1,5 +1,5 @@
# Example requirement, can be anything that pip knows
# install with `pip install -r requirements.txt`, and make sure that CI does the same
torch >= 1.12
# torch >= 1.12
numpy
pyre-extensions == 0.0.29
5 changes: 5 additions & 0 deletions third_party/xformers/setup.py
@@ -104,6 +104,11 @@ def get_cuda_version(cuda_dir) -> int:


def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
torch_version = torch.__version__.split('+')[0]
torch_version = tuple(int(v) for v in torch_version.split("."))
if torch_version < (1, 12):
print("You Pytorch version cannot load flash attention for xformers.")
return []
# Figure out default archs to target
DEFAULT_ARCHS_LIST = ""
if cuda_version > 1100:
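The guard strips any local-version tag before converting torch.__version__ into an integer tuple, so the comparison with (1, 12) is numeric rather than lexicographic. A tiny illustration with an assumed version string:

version_str = "1.13.1+cu117"  # assumed example value of torch.__version__
torch_version = tuple(int(v) for v in version_str.split("+")[0].split("."))
print(torch_version)             # (1, 13, 1)
print(torch_version >= (1, 12))  # True, so the flash attention extensions can be built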
1 change: 1 addition & 0 deletions third_party/xformers/third_party/cutlass
@@ -0,0 +1,23 @@
---
name: Bug report
about: Create a bug report to help us improve CUTLASS
title: "[BUG]"
labels: "? - Needs Triage, bug"
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**Steps/Code to reproduce bug**
Follow this guide http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports to craft a minimal bug report. This helps us reproduce the issue you're having and resolve the issue more quickly.

**Expected behavior**
A clear and concise description of what you expected to happen.

**Environment details (please complete the following information):**
- Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)]

**Additional context**
Add any other context about the problem here.
@@ -0,0 +1,35 @@
---
name: Documentation request
about: Report incorrect or needed documentation to improve CUTLASS
title: "[DOC]"
labels: "? - Needs Triage, documentation"
assignees: ''

---

## Report incorrect documentation

**Location of incorrect documentation**
Provide links and line numbers if applicable.

**Describe the problems or issues found in the documentation**
A clear and concise description of what you found to be incorrect.

**Steps taken to verify documentation is incorrect**
List any steps you have taken:

**Suggested fix for documentation**
Detail proposed changes to fix the documentation if you have any.

---

## Report needed documentation

**Report needed documentation**
A clear and concise description of what documentation you believe is needed and why.

**Describe the documentation you'd like**
A clear and concise description of what you want to happen.

**Steps taken to search for needed documentation**
List any steps you have taken:
@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for CUTLASS
title: "[FEA]"
labels: "? - Needs Triage, feature request"
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I wish I could use CUTLASS to do [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context, code examples, or references to existing implementations about the feature request here.
@@ -0,0 +1,10 @@
---
name: Submit question
about: Ask a general question about CUTLASS
title: "[QST]"
labels: "? - Needs Triage, question"
assignees: ''

---

**What is your question?**
@@ -0,0 +1,18 @@
# https://github.com/actions/labeler#common-examples

examples:
- examples/**

source:
- cmake/**
- include/cutlass/**

documentation:
- docs/**
- media/**

testing:
- test/**

tooling:
- tools/**
@@ -0,0 +1,12 @@
name: "Pull Request Labeler"
on:
- pull_request_target

jobs:
triage:
runs-on: ubuntu-latest
permissions: read-all|write-all
steps:
- uses: actions/labeler@master
with:
repo-token: "${{ secrets.GITHUB_TOKEN }}"
