Commit
workaround for open-clip-torch update
fixes the attn_mask shape error
kijai committed Jul 6, 2024
1 parent a8780cd commit 59ca341
Showing 2 changed files with 26 additions and 4 deletions.
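
For context on the "attn_mask shape error": newer open-clip-torch releases appear to run the text transformer batch-first, so activations permuted to sequence-first (LND) no longer line up with the [n_ctx, n_ctx] causal mask the way the old path assumed. A minimal shape sketch with assumed sizes (77-token context, width 1024), not taken from the commit itself:

import torch

batch, n_ctx, width = 2, 77, 1024
x_nld = torch.randn(batch, n_ctx, width)   # batch-first layout: [N, L, D]
x_lnd = x_nld.permute(1, 0, 2)             # sequence-first layout: [L, N, D]

# The causal mask is [n_ctx, n_ctx]; it only matches whichever axis the
# resblocks treat as the sequence dimension, hence the shape error when the
# library's expectation flips.
attn_mask = torch.full((n_ctx, n_ctx), float("-inf")).triu_(1)
print(x_nld.shape, x_lnd.shape, attn_mask.shape)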
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,8 +1,8 @@
transformers>=4.28.1
fsspec>=2023.4.0
kornia>=0.6.9
-open-clip-torch==2.24.0
+open-clip-torch>=2.24.0
Pillow>=9.4.0
pytorch-lightning>=2.2.1
omegaconf
-accelerate
+accelerate
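
With the pin loosened from == to >=, different environments may resolve to different open-clip-torch releases. A quick convenience check (not part of the commit) to see which version actually got installed:

from importlib.metadata import version
print(version("open-clip-torch"))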
26 changes: 24 additions & 2 deletions sgm/modules/encoders/modules.py
@@ -36,6 +36,9 @@
import comfy.model_management
device = comfy.model_management.get_torch_device()

+import comfy.ops
+ops = comfy.ops.manual_cast
+
class AbstractEmbModel(nn.Module):
    def __init__(self):
        super().__init__()
@@ -577,7 +580,10 @@ def encode_with_transformer(self, text):
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.model.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
-x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+try:
+    x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+except:
+    x = self.text_transformer_forward_batch_first(x, attn_mask=self.model.attn_mask)
if self.legacy:
    x = x[self.layer]
    x = self.model.ln_final(x)
@@ -612,6 +618,22 @@ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
x = r(x, attn_mask=attn_mask)
outputs["last"] = x.permute(1, 0, 2) # LND -> NLD
return outputs

+def text_transformer_forward_batch_first(self, x: torch.Tensor, attn_mask=None):
+    x = x.permute(1, 0, 2)  # LND -> NLD
+    outputs = {}
+    for i, r in enumerate(self.model.transformer.resblocks):
+        if i == len(self.model.transformer.resblocks) - 1:
+            outputs["penultimate"] = x
+        if (
+            self.model.transformer.grad_checkpointing
+            and not torch.jit.is_scripting()
+        ):
+            x = checkpoint(r, x, attn_mask)
+        else:
+            x = r(x, attn_mask=attn_mask)
+    outputs["last"] = x
+    return outputs

def encode(self, text):
    return self(text)
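
The bare except in encode_with_transformer simply retries with the batch-first path whenever the sequence-first call fails. If the installed open_clip exposes a batch_first flag on its text transformer (an assumption about newer releases, not something the commit relies on), a more explicit dispatch could look like this sketch:

import torch

def forward_text_transformer(embedder, x: torch.Tensor, attn_mask=None):
    # Hypothetical helper: `embedder` is assumed to be the FrozenOpenCLIP-style
    # object from this file, and `batch_first` existing on open_clip's
    # Transformer is an assumption to verify against the installed version.
    if getattr(embedder.model.transformer, "batch_first", False):
        return embedder.text_transformer_forward_batch_first(x, attn_mask=attn_mask)
    return embedder.text_transformer_forward(x, attn_mask=attn_mask)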
@@ -908,7 +930,7 @@ def __init__(
print(
    f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
)
-self.channel_mapper = nn.Conv2d(
+self.channel_mapper = ops.Conv2d(
    in_channels,
    out_channels,
    kernel_size=kernel_size,
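
The switch from nn.Conv2d to ops.Conv2d goes through comfy.ops.manual_cast, ComfyUI's wrapper layers that cast their weights to the input's dtype at call time. A rough illustration of that idea (my sketch, not ComfyUI's actual implementation):

import torch
import torch.nn as nn

class ManualCastConv2d(nn.Conv2d):
    # Illustrative stand-in for the manual-cast behaviour: parameters are cast
    # to the incoming tensor's dtype/device on each forward call, so fp16 or
    # bf16 inputs work even if the module was created in fp32.
    def forward(self, x):
        weight = self.weight.to(dtype=x.dtype, device=x.device)
        bias = self.bias.to(dtype=x.dtype, device=x.device) if self.bias is not None else None
        return self._conv_forward(x, weight, bias)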
