diff --git a/requirements.txt b/requirements.txt
index 45fc648..4086dab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
 transformers>=4.28.1
 fsspec>=2023.4.0
 kornia>=0.6.9
-open-clip-torch==2.24.0
+open-clip-torch>=2.24.0
 Pillow>=9.4.0
 pytorch-lightning>=2.2.1
 omegaconf
-accelerate
+accelerate
\ No newline at end of file
diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py
index a9185e7..d88918f 100644
--- a/sgm/modules/encoders/modules.py
+++ b/sgm/modules/encoders/modules.py
@@ -36,6 +36,9 @@ import comfy.model_management
 device = comfy.model_management.get_torch_device()
 
 
+import comfy.ops
+ops = comfy.ops.manual_cast
+
 class AbstractEmbModel(nn.Module):
     def __init__(self):
         super().__init__()
@@ -577,7 +580,10 @@ def encode_with_transformer(self, text):
         x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
         x = x + self.model.positional_embedding
         x = x.permute(1, 0, 2)  # NLD -> LND
-        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+        try:
+            x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+        except:
+            x = self.text_transformer_forward_batch_first(x, attn_mask=self.model.attn_mask)
         if self.legacy:
             x = x[self.layer]
             x = self.model.ln_final(x)
@@ -612,6 +618,22 @@ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
                 x = r(x, attn_mask=attn_mask)
         outputs["last"] = x.permute(1, 0, 2)  # LND -> NLD
         return outputs
+
+    def text_transformer_forward_batch_first(self, x: torch.Tensor, attn_mask=None):
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        outputs = {}
+        for i, r in enumerate(self.model.transformer.resblocks):
+            if i == len(self.model.transformer.resblocks) - 1:
+                outputs["penultimate"] = x
+            if (
+                self.model.transformer.grad_checkpointing
+                and not torch.jit.is_scripting()
+            ):
+                x = checkpoint(r, x, attn_mask)
+            else:
+                x = r(x, attn_mask=attn_mask)
+        outputs["last"] = x
+        return outputs
 
     def encode(self, text):
         return self(text)
@@ -908,7 +930,7 @@ def __init__(
             print(
                 f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
             )
-            self.channel_mapper = nn.Conv2d(
+            self.channel_mapper = ops.Conv2d(
                 in_channels,
                 out_channels,
                 kernel_size=kernel_size,
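
Why the `try`/`except` in `encode_with_transformer`: loosening the pin to `open-clip-torch>=2.24.0` admits newer releases whose text transformer runs batch-first, so its resblocks expect NLD (batch, length, dim) input rather than the LND layout the original code permutes into. When the LND call fails, the patch retries via `text_transformer_forward_batch_first`, which undoes the permute up front and returns `outputs` without the trailing `LND -> NLD` permutes. Note the dispatch is a bare `except:`, so any error raised on the LND path, not only a layout mismatch, falls through to the batch-first branch. Below is a minimal sketch of the two layouts, using a toy block in place of open-clip's resblocks; all names in it are illustrative, not from the patch:

```python
import torch
import torch.nn as nn

# Toy residual attention block; `batch_first` mirrors the knob that newer
# open-clip builds flip on their text transformer. Illustrative only.
class ToyResBlock(nn.Module):
    def __init__(self, d_model: int, batch_first: bool):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=batch_first)

    def forward(self, x: torch.Tensor, attn_mask=None) -> torch.Tensor:
        out, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
        return x + out

batch, n_ctx, d_model = 2, 77, 64
mask = torch.full((n_ctx, n_ctx), float("-inf")).triu(1)  # additive causal mask
x = torch.randn(batch, n_ctx, d_model)  # NLD, as produced by token_embedding

# Old path: resblocks consume LND, so permute in and permute back out.
lnd_blocks = [ToyResBlock(d_model, batch_first=False) for _ in range(2)]
h = x.permute(1, 0, 2)  # NLD -> LND
for r in lnd_blocks:
    h = r(h, attn_mask=mask)
out_old = h.permute(1, 0, 2)  # LND -> NLD

# Fallback path: resblocks consume NLD directly; permuting back up front and
# skipping the final permute (as text_transformer_forward_batch_first does)
# nets out to feeding the blocks NLD and returning NLD.
nld_blocks = [ToyResBlock(d_model, batch_first=True) for _ in range(2)]
h = x  # already NLD
for r in nld_blocks:
    h = r(h, attn_mask=mask)
out_new = h

assert out_old.shape == out_new.shape == (batch, n_ctx, d_model)
```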
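
The `ops = comfy.ops.manual_cast` alias and the `nn.Conv2d` -> `ops.Conv2d` swap route the channel mapper through ComfyUI's casting ops, which keep parameters in their storage dtype and cast them to the input's dtype and device at forward time. A rough sketch of that pattern in plain PyTorch follows; it assumes nothing about ComfyUI's actual implementation beyond the cast-on-forward behavior, and it leans on the private `_conv_forward` helper of `nn.Conv2d`:

```python
import torch
import torch.nn as nn

class ManualCastConv2d(nn.Conv2d):
    """Conv2d that casts its parameters to the input's dtype/device on every
    forward, so weights stored in e.g. fp16 still work with fp32 activations.
    A simplified stand-in for a manual-cast ops.Conv2d, not ComfyUI's code."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight.to(device=x.device, dtype=x.dtype)
        bias = None if self.bias is None else self.bias.to(device=x.device, dtype=x.dtype)
        return self._conv_forward(x, weight, bias)

conv = ManualCastConv2d(4, 8, kernel_size=1).half()  # params stored in fp16
y = conv(torch.randn(1, 4, 16, 16))                  # fp32 input still works
assert y.dtype == torch.float32
```

Swapping the class at construction time, as the patch does through the module-level `ops` alias, lets every instantiation pick up the casting behavior without touching call sites.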