From ee63998c1abcdaea144605ab4365ce6db1dceb1c Mon Sep 17 00:00:00 2001 From: Ionite Date: Sun, 24 Mar 2024 20:37:59 -0400 Subject: [PATCH 1/4] Add layer-diffuse --- pyproject.toml | 3 +- .../layer_diffuse/LICENSE | 201 ++++++ .../layer_diffuse/README.md | 7 + .../layer_diffuse/__init__.py | 3 + .../layer_diffuse/layered_diffusion.py | 683 ++++++++++++++++++ .../lib_layerdiffusion/__init__.py | 0 .../lib_layerdiffusion/attention_sharing.py | 360 +++++++++ .../layer_diffuse/lib_layerdiffusion/enums.py | 23 + .../lib_layerdiffusion/models.py | 318 ++++++++ .../layer_diffuse/lib_layerdiffusion/utils.py | 135 ++++ 10 files changed, 1732 insertions(+), 1 deletion(-) create mode 100644 src/inference_core_nodes/layer_diffuse/LICENSE create mode 100644 src/inference_core_nodes/layer_diffuse/README.md create mode 100644 src/inference_core_nodes/layer_diffuse/__init__.py create mode 100644 src/inference_core_nodes/layer_diffuse/layered_diffusion.py create mode 100644 src/inference_core_nodes/layer_diffuse/lib_layerdiffusion/__init__.py create mode 100644 src/inference_core_nodes/layer_diffuse/lib_layerdiffusion/attention_sharing.py create mode 100644 src/inference_core_nodes/layer_diffuse/lib_layerdiffusion/enums.py create mode 100644 src/inference_core_nodes/layer_diffuse/lib_layerdiffusion/models.py create mode 100644 src/inference_core_nodes/layer_diffuse/lib_layerdiffusion/utils.py diff --git a/pyproject.toml b/pyproject.toml index 9e492ea..ee93c6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,8 @@ dependencies = [ "yacs", "trimesh[easy]", "albumentations", - "scikit-learn" + "scikit-learn", + "diffusers>=0.25.0" ] [project.optional-dependencies] diff --git a/src/inference_core_nodes/layer_diffuse/LICENSE b/src/inference_core_nodes/layer_diffuse/LICENSE new file mode 100644 index 0000000..b09cd78 --- /dev/null +++ b/src/inference_core_nodes/layer_diffuse/LICENSE @@ -0,0 +1,201 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ 
+ + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/src/inference_core_nodes/layer_diffuse/README.md b/src/inference_core_nodes/layer_diffuse/README.md new file mode 100644 index 0000000..0544bf3 --- /dev/null +++ b/src/inference_core_nodes/layer_diffuse/README.md @@ -0,0 +1,7 @@ +## Layer Diffuse + +Based on or modified from: [huchenlei/ComfyUI-layerdiffuse](https://github.com/huchenlei/ComfyUI-layerdiffuse) @ 151f7460bbc9d7437d4f0010f21f80178f7a84a6 + +License: Apache-2.0 + + diff --git a/src/inference_core_nodes/layer_diffuse/__init__.py b/src/inference_core_nodes/layer_diffuse/__init__.py new file mode 100644 index 0000000..0dc1069 --- /dev/null +++ b/src/inference_core_nodes/layer_diffuse/__init__.py @@ -0,0 +1,3 @@ +from .layered_diffusion import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS + +__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] \ No newline at end of file diff --git a/src/inference_core_nodes/layer_diffuse/layered_diffusion.py b/src/inference_core_nodes/layer_diffuse/layered_diffusion.py new file mode 100644 index 0000000..f18a158 --- /dev/null +++ b/src/inference_core_nodes/layer_diffuse/layered_diffusion.py @@ -0,0 +1,683 @@ +import os +from enum import Enum +import torch +import functools +import copy +from typing import Optional, List +from dataclasses import dataclass + +import folder_paths +import comfy.model_management +import comfy.model_base +import comfy.supported_models +import comfy.supported_models_base +from comfy.model_patcher import ModelPatcher +from folder_paths import get_folder_paths +from comfy.utils import load_torch_file +from comfy_extras.nodes_compositing import JoinImageWithAlpha +from comfy.conds import CONDRegular +from .lib_layerdiffusion.utils import ( + load_file_from_url, + to_lora_patch_dict, +) +from .lib_layerdiffusion.models import TransparentVAEDecoder +from .lib_layerdiffusion.attention_sharing import AttentionSharingPatcher +from .lib_layerdiffusion.enums import StableDiffusionVersion + +if "layer_model" in 
def calculate_weight_adjust_channel(func):
    """Patches ComfyUI's LoRA weight application to accept multi-channel inputs.

    The returned wrapper first delegates to ``func`` (the original
    ``calculate_weight``) and then re-visits every patch. For a 4-D "diff"
    patch whose shape differs from the weight, the weight is zero-padded up to
    the element-wise maximum of both shapes and the scaled diff is added into
    the leading slice.

    Args:
        func: The original ``calculate_weight(self, patches, weight, key)``.

    Returns:
        A wrapper with the same signature as ``func``.
    """

    @functools.wraps(func)
    def calculate_weight(
        self: "ModelPatcher", patches, weight: torch.Tensor, key: str
    ) -> torch.Tensor:
        weight = func(self, patches, weight, key)

        for p in patches:
            alpha = p[0]
            v = p[1]

            # The recursion call should be handled in the main func call.
            if isinstance(v, list):
                continue

            # Resolve the patch type afresh for every patch. Previously a
            # payload whose length was neither 1 nor 2 left ``patch_type``
            # unbound (UnboundLocalError) or silently reused the value of the
            # preceding patch.
            if len(v) == 1:
                patch_type = "diff"
            elif len(v) == 2:
                patch_type, v = v[0], v[1]
            else:
                continue

            if patch_type != "diff":
                continue

            w1 = v[0]
            needs_resize = (
                alpha != 0.0
                and w1.shape != weight.shape
                and w1.ndim == weight.ndim == 4
            )
            if not needs_resize:
                continue

            new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
            print(
                f"Merged with {key} channel changed from {weight.shape} to {new_shape}"
            )
            new_diff = alpha * comfy.model_management.cast_to_device(
                w1, weight.device, weight.dtype
            )
            # Zero-filled canvas with the dtype/device of the current weight.
            new_weight = torch.zeros(size=new_shape).to(weight)
            new_weight[
                : weight.shape[0],
                : weight.shape[1],
                : weight.shape[2],
                : weight.shape[3],
            ] = weight
            new_weight[
                : new_diff.shape[0],
                : new_diff.shape[1],
                : new_diff.shape[2],
                : new_diff.shape[3],
            ] += new_diff
            weight = new_weight.contiguous().clone()
        return weight

    return calculate_weight
class LayeredDiffusionDecode:
    """
    Decode alpha channel value from pixel value.
    [B, C=3, H, W] => [B, C=4, H, W]
    Outputs RGB image + Alpha mask.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "samples": ("LATENT",),
                "images": ("IMAGE",),
                "sd_version": (
                    [
                        StableDiffusionVersion.SD1x.value,
                        StableDiffusionVersion.SDXL.value,
                    ],
                    {
                        "default": StableDiffusionVersion.SDXL.value,
                    },
                ),
                "sub_batch_size": (
                    "INT",
                    {"default": 16, "min": 1, "max": 4096, "step": 1},
                ),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "decode"
    CATEGORY = "layer_diffuse"

    def __init__(self) -> None:
        # Loaded decoders, keyed by StableDiffusionVersion, so each weight
        # file is downloaded/loaded at most once per node instance.
        self.vae_transparent_decoder = {}

    def decode(self, samples, images, sd_version: str, sub_batch_size: int):
        """
        Decode RGB pixels plus latents into an image and a separate channel.

        sub_batch_size: How many images to decode in a single pass.
        See https://github.com/huchenlei/ComfyUI-layerdiffuse/pull/4 for more
        context.

        Returns:
            ``(image, alpha)`` where ``alpha`` is channel 0 of the decoder
            output and ``image`` holds the remaining channels.

        Raises:
            ValueError: If ``sd_version`` has no transparent decoder.
            AssertionError: If the image height or width is not 64-aligned.
        """
        sd_version = StableDiffusionVersion(sd_version)
        if sd_version == StableDiffusionVersion.SD1x:
            url = "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_vae_transparent_decoder.safetensors"
            file_name = "layer_sd15_vae_transparent_decoder.safetensors"
        elif sd_version == StableDiffusionVersion.SDXL:
            url = "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/vae_transparent_decoder.safetensors"
            file_name = "vae_transparent_decoder.safetensors"
        else:
            # Previously fell through and failed later with an unbound `url`.
            raise ValueError(f"No transparent VAE decoder for {sd_version}.")

        if not self.vae_transparent_decoder.get(sd_version):
            model_path = load_file_from_url(
                url=url, model_dir=layer_model_root, file_name=file_name
            )
            self.vae_transparent_decoder[sd_version] = TransparentVAEDecoder(
                load_torch_file(model_path),
                device=comfy.model_management.get_torch_device(),
                dtype=(
                    torch.float16
                    if comfy.model_management.should_use_fp16()
                    else torch.float32
                ),
            )
        decoder = self.vae_transparent_decoder[sd_version]
        pixel = images.movedim(-1, 1)  # [B, H, W, C] => [B, C, H, W]

        # Decoder requires dimension to be 64-aligned.
        B, C, H, W = pixel.shape
        assert H % 64 == 0, f"Height({H}) is not multiple of 64."
        assert W % 64 == 0, f"Width({W}) is not multiple of 64."

        latents = samples["samples"]
        decoded = []
        for start_idx in range(0, latents.shape[0], sub_batch_size):
            end_idx = start_idx + sub_batch_size
            decoded.append(
                decoder.decode_pixel(
                    pixel[start_idx:end_idx],
                    latents[start_idx:end_idx],
                )
            )
        pixel_with_alpha = torch.cat(decoded, dim=0)

        # [B, C, H, W] => [B, H, W, C]
        pixel_with_alpha = pixel_with_alpha.movedim(1, -1)
        image = pixel_with_alpha[..., 1:]
        alpha = pixel_with_alpha[..., 0]
        return (image, alpha)
@dataclass
class LayeredDiffusionBase:
    """One downloadable layer-diffusion weight file and the ways to apply it."""

    model_file_name: str
    model_url: str
    sd_version: StableDiffusionVersion
    attn_sharing: bool = False
    injection_method: Optional[LayerMethod] = None
    cond_type: Optional[LayerType] = None
    # Number of output images per run.
    frames: int = 1

    @property
    def config_string(self) -> str:
        """Label built from the non-empty settings, joined with ", "."""
        injection_method = self.injection_method.value if self.injection_method else ""
        cond_type = self.cond_type.value if self.cond_type else ""
        attn_sharing = "attn_sharing" if self.attn_sharing else ""
        frames = f"Batch size ({self.frames}N)" if self.frames != 1 else ""
        parts = (
            self.sd_version.value,
            injection_method,
            cond_type,
            attn_sharing,
            frames,
        )
        return ", ".join(part for part in parts if part)

    def _load_state_dict(self):
        """Fetch the weight file into ``layer_model_root`` and load it."""
        model_path = load_file_from_url(
            url=self.model_url,
            model_dir=layer_model_root,
            file_name=self.model_file_name,
        )
        return load_layer_model_state_dict(model_path)

    def apply_c_concat(self, cond, uncond, c_concat):
        """Set foreground/background concat condition.

        Returns new ``(cond, uncond)`` lists; every entry's options dict is
        shallow-copied before ``model_conds["c_concat"]`` is written.
        """

        def write_c_concat(entries):
            new_cond = []
            for t in entries:
                n = [t[0], t[1].copy()]
                if "model_conds" not in n[1]:
                    n[1]["model_conds"] = {}
                n[1]["model_conds"]["c_concat"] = CONDRegular(c_concat)
                new_cond.append(n)
            return new_cond

        return (write_c_concat(cond), write_c_concat(uncond))

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        weight: float,
    ):
        """Patch model: add the weights as LoRA patches to a clone of ``model``."""
        layer_lora_patch_dict = to_lora_patch_dict(self._load_state_dict())
        work_model = model.clone()
        work_model.add_patches(layer_lora_patch_dict, weight)
        return (work_model,)

    def apply_layered_diffusion_attn_sharing(
        self,
        model: ModelPatcher,
        control_img: Optional[torch.Tensor] = None,
    ):
        """Patch model with attn sharing.

        ``control_img`` is optional; when given it is handed to the patcher
        via ``set_control`` and the patcher is built with ``use_control=True``.
        """
        layer_lora_state_dict = self._load_state_dict()
        work_model = model.clone()
        patcher = AttentionSharingPatcher(
            work_model, self.frames, use_control=control_img is not None
        )
        patcher.load_state_dict(layer_lora_state_dict, strict=True)
        if control_img is not None:
            patcher.set_control(control_img)
        return (work_model,)
def get_model_sd_version(model: ModelPatcher) -> StableDiffusionVersion:
    """Get model's StableDiffusionVersion.

    Inspects ``model.model.model_config`` and maps the ComfyUI config class
    to the layer-diffusion version enum.

    Raises:
        ValueError: If the config is none of SDXL, SD15 or SD20.
    """
    model_config = model.model.model_config
    if isinstance(model_config, comfy.supported_models.SDXL):
        return StableDiffusionVersion.SDXL
    if isinstance(
        model_config, (comfy.supported_models.SD15, comfy.supported_models.SD20)
    ):
        # SD15 and SD20 are compatible with each other.
        return StableDiffusionVersion.SD1x
    # ValueError is still an Exception, so broad handlers keep working.
    raise ValueError(f"Unsupported SD Version: {type(model_config)}.")
class LayeredDiffusionJoint:
    """Generate FG + BG + Blended in one inference batch. Batch size = 3N."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "config": ([c.config_string for c in s.MODELS],),
            },
            "optional": {
                "fg_cond": ("CONDITIONING",),
                "bg_cond": ("CONDITIONING",),
                "blended_cond": ("CONDITIONING",),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_layered_diffusion"
    CATEGORY = "layer_diffuse"
    MODELS = (
        LayeredDiffusionBase(
            model_file_name="layer_sd15_joint.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_joint.safetensors",
            sd_version=StableDiffusionVersion.SD1x,
            attn_sharing=True,
            frames=3,
        ),
    )

    @classmethod
    def _select_model(cls, config: str):
        """Return the MODELS entry whose ``config_string`` equals ``config``."""
        for candidate in cls.MODELS:
            if candidate.config_string == config:
                return candidate
        raise ValueError(f"Unknown layer diffusion config: {config!r}.")

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        config: str,
        fg_cond: Optional[List[List[torch.Tensor]]] = None,
        bg_cond: Optional[List[List[torch.Tensor]]] = None,
        blended_cond: Optional[List[List[torch.Tensor]]] = None,
    ):
        """Return a patched clone of ``model`` with per-frame cond overrides.

        ``transformer_options["cond_overwrite"]`` is set to
        ``[fg, bg, blended]``, each being ``cond[0][0]`` of the matching
        input or ``None`` when that input was not supplied.

        Raises:
            ValueError: Unknown ``config`` or mismatching SD version.
        """
        ld_model = self._select_model(config)
        model_version = get_model_sd_version(model)
        # Explicit raise instead of assert so the check survives `python -O`.
        if model_version != ld_model.sd_version:
            raise ValueError(
                f"Config {config!r} requires {ld_model.sd_version}, "
                f"but the model is {model_version}."
            )
        assert ld_model.attn_sharing
        work_model = ld_model.apply_layered_diffusion_attn_sharing(model)[0]
        transformer_options = work_model.model_options.setdefault(
            "transformer_options", {}
        )
        transformer_options["cond_overwrite"] = [
            c[0][0] if c is not None else None
            for c in (fg_cond, bg_cond, blended_cond)
        ]
        return (work_model,)
class LayeredDiffusionCond:
    """Generate foreground + background given background / foreground.
    - FG => Blended
    - BG => Blended
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "cond": ("CONDITIONING",),
                "uncond": ("CONDITIONING",),
                "latent": ("LATENT",),
                "config": ([c.config_string for c in s.MODELS],),
                "weight": (
                    "FLOAT",
                    {"default": 1.0, "min": -1, "max": 3, "step": 0.05},
                ),
            },
        }

    RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING")
    FUNCTION = "apply_layered_diffusion"
    CATEGORY = "layer_diffuse"
    MODELS = (
        LayeredDiffusionBase(
            model_file_name="layer_xl_fg2ble.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_fg2ble.safetensors",
            sd_version=StableDiffusionVersion.SDXL,
            cond_type=LayerType.FG,
        ),
        LayeredDiffusionBase(
            model_file_name="layer_xl_bg2ble.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_bg2ble.safetensors",
            sd_version=StableDiffusionVersion.SDXL,
            cond_type=LayerType.BG,
        ),
    )

    @classmethod
    def _select_model(cls, config: str):
        """Return the MODELS entry whose ``config_string`` equals ``config``."""
        for candidate in cls.MODELS:
            if candidate.config_string == config:
                return candidate
        raise ValueError(f"Unknown layer diffusion config: {config!r}.")

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        cond,
        uncond,
        latent,
        config: str,
        weight: float,
    ):
        """Return ``(patched_model, cond, uncond)`` with ``c_concat`` attached.

        ``c_concat`` is ``latent["samples"]`` passed through the model's
        ``latent_format.process_in``.

        Raises:
            ValueError: Unknown ``config`` or mismatching SD version.
        """
        ld_model = self._select_model(config)
        model_version = get_model_sd_version(model)
        # Explicit raise instead of assert so the check survives `python -O`.
        if model_version != ld_model.sd_version:
            raise ValueError(
                f"Config {config!r} requires {ld_model.sd_version}, "
                f"but the model is {model_version}."
            )
        c_concat = model.model.latent_format.process_in(latent["samples"])
        patched = ld_model.apply_layered_diffusion(model, weight)
        return patched + ld_model.apply_c_concat(cond, uncond, c_concat)
class LayeredDiffusionCondJoint:
    """Generate fg/bg + blended given fg/bg.
    - FG => Blended + BG
    - BG => Blended + FG
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "image": ("IMAGE",),
                "config": ([c.config_string for c in s.MODELS],),
            },
            "optional": {
                "cond": ("CONDITIONING",),
                "blended_cond": ("CONDITIONING",),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_layered_diffusion"
    CATEGORY = "layer_diffuse"
    MODELS = (
        LayeredDiffusionBase(
            model_file_name="layer_sd15_fg2bg.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_fg2bg.safetensors",
            sd_version=StableDiffusionVersion.SD1x,
            attn_sharing=True,
            frames=2,
            cond_type=LayerType.FG,
        ),
        LayeredDiffusionBase(
            model_file_name="layer_sd15_bg2fg.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_bg2fg.safetensors",
            sd_version=StableDiffusionVersion.SD1x,
            attn_sharing=True,
            frames=2,
            cond_type=LayerType.BG,
        ),
    )

    @classmethod
    def _select_model(cls, config: str):
        """Return the MODELS entry whose ``config_string`` equals ``config``."""
        for candidate in cls.MODELS:
            if candidate.config_string == config:
                return candidate
        raise ValueError(f"Unknown layer diffusion config: {config!r}.")

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        image,
        config: str,
        cond: Optional[List[List[torch.Tensor]]] = None,
        blended_cond: Optional[List[List[torch.Tensor]]] = None,
    ):
        """Return a patched clone of ``model`` that uses ``image`` as control.

        ``image`` is passed to the patcher as ``image.movedim(-1, 1)``.
        ``transformer_options["cond_overwrite"]`` becomes
        ``[cond, blended_cond]`` reduced to ``x[0][0]`` (or ``None``).

        Raises:
            ValueError: Unknown ``config`` or mismatching SD version.
        """
        ld_model = self._select_model(config)
        model_version = get_model_sd_version(model)
        # Explicit raise instead of assert so the check survives `python -O`.
        if model_version != ld_model.sd_version:
            raise ValueError(
                f"Config {config!r} requires {ld_model.sd_version}, "
                f"but the model is {model_version}."
            )
        assert ld_model.attn_sharing
        work_model = ld_model.apply_layered_diffusion_attn_sharing(
            model, control_img=image.movedim(-1, 1)
        )[0]
        transformer_options = work_model.model_options.setdefault(
            "transformer_options", {}
        )
        # Distinct loop name: the original shadowed the `cond` parameter.
        transformer_options["cond_overwrite"] = [
            c[0][0] if c is not None else None for c in (cond, blended_cond)
        ]
        return (work_model,)
class LayeredDiffusionDiff:
    """Extract FG/BG from blended image.
    - Blended + FG => BG
    - Blended + BG => FG
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "cond": ("CONDITIONING",),
                "uncond": ("CONDITIONING",),
                "blended_latent": ("LATENT",),
                "latent": ("LATENT",),
                "config": ([c.config_string for c in s.MODELS],),
                "weight": (
                    "FLOAT",
                    {"default": 1.0, "min": -1, "max": 3, "step": 0.05},
                ),
            },
        }

    RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING")
    FUNCTION = "apply_layered_diffusion"
    CATEGORY = "layer_diffuse"
    MODELS = (
        LayeredDiffusionBase(
            model_file_name="layer_xl_fgble2bg.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_fgble2bg.safetensors",
            sd_version=StableDiffusionVersion.SDXL,
            cond_type=LayerType.FG,
        ),
        LayeredDiffusionBase(
            model_file_name="layer_xl_bgble2fg.safetensors",
            model_url="https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_bgble2fg.safetensors",
            sd_version=StableDiffusionVersion.SDXL,
            cond_type=LayerType.BG,
        ),
    )

    @classmethod
    def _select_model(cls, config: str):
        """Return the MODELS entry whose ``config_string`` equals ``config``."""
        for candidate in cls.MODELS:
            if candidate.config_string == config:
                return candidate
        raise ValueError(f"Unknown layer diffusion config: {config!r}.")

    def apply_layered_diffusion(
        self,
        model: ModelPatcher,
        cond,
        uncond,
        blended_latent,
        latent,
        config: str,
        weight: float,
    ):
        """Return ``(patched_model, cond, uncond)`` with ``c_concat`` attached.

        ``c_concat`` is ``latent`` followed by ``blended_latent`` concatenated
        on dim 1, passed through the model's ``latent_format.process_in``.

        Raises:
            ValueError: Unknown ``config`` or mismatching SD version.
        """
        ld_model = self._select_model(config)
        model_version = get_model_sd_version(model)
        # Explicit raise instead of assert so the check survives `python -O`.
        if model_version != ld_model.sd_version:
            raise ValueError(
                f"Config {config!r} requires {ld_model.sd_version}, "
                f"but the model is {model_version}."
            )
        stacked = torch.cat([latent["samples"], blended_latent["samples"]], dim=1)
        c_concat = model.model.latent_format.process_in(stacked)
        patched = ld_model.apply_layered_diffusion(model, weight)
        return patched + ld_model.apply_c_concat(cond, uncond, c_concat)
+ "LayeredDiffusionDecodeSplit": LayeredDiffusionDecodeSplit, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "LayeredDiffusionApply": "Layer Diffuse Apply", + "LayeredDiffusionJointApply": "Layer Diffuse Joint Apply", + "LayeredDiffusionCondApply": "Layer Diffuse Cond Apply", + "LayeredDiffusionCondJointApply": "Layer Diffuse Cond Joint Apply", + "LayeredDiffusionDiffApply": "Layer Diffuse Diff Apply", + "LayeredDiffusionDecode": "Layer Diffuse Decode", + "LayeredDiffusionDecodeRGBA": "Layer Diffuse Decode (RGBA)", + "LayeredDiffusionDecodeSplit": "Layer Diffuse Decode (Split)", +} diff --git a/src/inference_core_nodes/layer_diffuse/lib_layerdiffusion/__init__.py b/src/inference_core_nodes/layer_diffuse/lib_layerdiffusion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/inference_core_nodes/layer_diffuse/lib_layerdiffusion/attention_sharing.py b/src/inference_core_nodes/layer_diffuse/lib_layerdiffusion/attention_sharing.py new file mode 100644 index 0000000..ae424fc --- /dev/null +++ b/src/inference_core_nodes/layer_diffuse/lib_layerdiffusion/attention_sharing.py @@ -0,0 +1,360 @@ +# Currently only sd15 + +import functools +import torch +import einops + +from comfy import model_management, utils +from comfy.ldm.modules.attention import optimized_attention + + +module_mapping_sd15 = { + 0: "input_blocks.1.1.transformer_blocks.0.attn1", + 1: "input_blocks.1.1.transformer_blocks.0.attn2", + 2: "input_blocks.2.1.transformer_blocks.0.attn1", + 3: "input_blocks.2.1.transformer_blocks.0.attn2", + 4: "input_blocks.4.1.transformer_blocks.0.attn1", + 5: "input_blocks.4.1.transformer_blocks.0.attn2", + 6: "input_blocks.5.1.transformer_blocks.0.attn1", + 7: "input_blocks.5.1.transformer_blocks.0.attn2", + 8: "input_blocks.7.1.transformer_blocks.0.attn1", + 9: "input_blocks.7.1.transformer_blocks.0.attn2", + 10: "input_blocks.8.1.transformer_blocks.0.attn1", + 11: "input_blocks.8.1.transformer_blocks.0.attn2", + 12: 
def compute_cond_mark(cond_or_uncond, sigmas):
    """Expand per-chunk cond/uncond flags into one flag per batch element.

    Every entry of ``cond_or_uncond`` is repeated ``sigmas.shape[0]`` times,
    keeping the original order, and the flat list is turned into a tensor
    that is converted with ``.to(sigmas)`` (so it takes the dtype and device
    of ``sigmas``).
    """
    repeats = int(sigmas.shape[0])
    flags = [flag for flag in cond_or_uncond for _ in range(repeats)]
    return torch.Tensor(flags).to(sigmas)


class LoRALinearLayer(torch.nn.Module):
    """Low-rank (``up @ down``) weight delta applied on top of a linear layer.

    The wrapped layer ``org`` supplies the base weight and optional bias; this
    module only owns the two bias-free projection matrices.
    """

    def __init__(self, in_features: int, out_features: int, rank: int = 256, org=None):
        super().__init__()
        self.down = torch.nn.Linear(in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, out_features, bias=False)
        # Stored inside a list, so ``org`` is not registered as a submodule
        # of this layer (its parameters stay out of this state dict).
        self.org = [org]

    def forward(self, h):
        """Apply ``linear(h, org.weight + up.weight @ down.weight, org.bias)``."""
        base = self.org[0]
        # The base weight/bias are converted to match ``h``.
        bias = None if base.bias is None else base.bias.to(h)
        merged_weight = base.weight.to(h) + torch.mm(self.up.weight, self.down.weight)
        return torch.nn.functional.linear(h, merged_weight, bias)
AttentionSharingUnit(torch.nn.Module): + # `transformer_options` passed to the most recent BasicTransformerBlock.forward + # call. + transformer_options: dict = {} + + def __init__(self, module, frames=2, use_control=True, rank=256): + super().__init__() + + self.heads = module.heads + self.frames = frames + self.original_module = [module] + q_in_channels, q_out_channels = ( + module.to_q.in_features, + module.to_q.out_features, + ) + k_in_channels, k_out_channels = ( + module.to_k.in_features, + module.to_k.out_features, + ) + v_in_channels, v_out_channels = ( + module.to_v.in_features, + module.to_v.out_features, + ) + o_in_channels, o_out_channels = ( + module.to_out[0].in_features, + module.to_out[0].out_features, + ) + + hidden_size = k_out_channels + + self.to_q_lora = [ + LoRALinearLayer(q_in_channels, q_out_channels, rank, module.to_q) + for _ in range(self.frames) + ] + self.to_k_lora = [ + LoRALinearLayer(k_in_channels, k_out_channels, rank, module.to_k) + for _ in range(self.frames) + ] + self.to_v_lora = [ + LoRALinearLayer(v_in_channels, v_out_channels, rank, module.to_v) + for _ in range(self.frames) + ] + self.to_out_lora = [ + LoRALinearLayer(o_in_channels, o_out_channels, rank, module.to_out[0]) + for _ in range(self.frames) + ] + + self.to_q_lora = torch.nn.ModuleList(self.to_q_lora) + self.to_k_lora = torch.nn.ModuleList(self.to_k_lora) + self.to_v_lora = torch.nn.ModuleList(self.to_v_lora) + self.to_out_lora = torch.nn.ModuleList(self.to_out_lora) + + self.temporal_i = torch.nn.Linear( + in_features=hidden_size, out_features=hidden_size + ) + self.temporal_n = torch.nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6 + ) + self.temporal_q = torch.nn.Linear( + in_features=hidden_size, out_features=hidden_size + ) + self.temporal_k = torch.nn.Linear( + in_features=hidden_size, out_features=hidden_size + ) + self.temporal_v = torch.nn.Linear( + in_features=hidden_size, out_features=hidden_size + ) + self.temporal_o = torch.nn.Linear( + 
in_features=hidden_size, out_features=hidden_size + ) + + self.control_convs = None + + if use_control: + self.control_convs = [ + torch.nn.Sequential( + torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1), + torch.nn.SiLU(), + torch.nn.Conv2d(256, hidden_size, kernel_size=1), + ) + for _ in range(self.frames) + ] + self.control_convs = torch.nn.ModuleList(self.control_convs) + + self.control_signals = None + + def forward(self, h, context=None, value=None): + transformer_options = self.transformer_options + + modified_hidden_states = einops.rearrange( + h, "(b f) d c -> f b d c", f=self.frames + ) + + if self.control_convs is not None: + context_dim = int(modified_hidden_states.shape[2]) + control_outs = [] + for f in range(self.frames): + control_signal = self.control_signals[context_dim].to( + modified_hidden_states + ) + control = self.control_convs[f](control_signal) + control = einops.rearrange(control, "b c h w -> b (h w) c") + control_outs.append(control) + control_outs = torch.stack(control_outs, dim=0) + modified_hidden_states = modified_hidden_states + control_outs.to( + modified_hidden_states + ) + + if context is None: + framed_context = modified_hidden_states + else: + framed_context = einops.rearrange( + context, "(b f) d c -> f b d c", f=self.frames + ) + + framed_cond_mark = einops.rearrange( + compute_cond_mark( + transformer_options["cond_or_uncond"], + transformer_options["sigmas"], + ), + "(b f) -> f b", + f=self.frames, + ).to(modified_hidden_states) + + attn_outs = [] + for f in range(self.frames): + fcf = framed_context[f] + + if context is not None: + cond_overwrite = transformer_options.get("cond_overwrite", []) + if len(cond_overwrite) > f: + cond_overwrite = cond_overwrite[f] + else: + cond_overwrite = None + if cond_overwrite is not None: + cond_mark = framed_cond_mark[f][:, None, None] + fcf = cond_overwrite.to(fcf) * (1.0 - cond_mark) + fcf * cond_mark + + q = self.to_q_lora[f](modified_hidden_states[f]) + k = 
class AdditionalAttentionCondsEncoder(torch.nn.Module):
    """Convolutional encoder that turns an RGB image into control features.

    The image passes through four stages in sequence. ``blocks_0`` reduces
    the spatial size by a factor of 8 (three stride-2 convolutions) and each
    following stage halves it again; every stage outputs 256 channels. The
    output of each stage is returned, keyed by its number of spatial
    positions (``height * width``).
    """

    def __init__(self):
        super().__init__()

        self.blocks_0 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )  # 64*64*256 for a 512*512 input

        self.blocks_1 = self._downsample_stage()  # 32*32*256
        self.blocks_2 = self._downsample_stage()  # 16*16*256
        self.blocks_3 = self._downsample_stage()  # 8*8*256

        # Plain list used only for iteration order; the stages themselves are
        # already registered through the attributes above.
        self.blks = [self.blocks_0, self.blocks_1, self.blocks_2, self.blocks_3]

    @staticmethod
    def _downsample_stage():
        """Build one 256-channel stage that halves the spatial resolution."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            torch.nn.SiLU(),
        )

    def forward(self, h):
        """Run all stages and return ``{height * width: feature_map}``.

        Implemented as ``forward`` (rather than overriding ``__call__``) so
        that calling the module goes through ``torch.nn.Module.__call__`` and
        registered hooks are honoured; ``encoder(img)`` keeps working as
        before.
        """
        results = {}
        for stage in self.blks:
            h = stage(h)
            results[int(h.shape[2]) * int(h.shape[3])] = h
        return results


class HookerLayers(torch.nn.Module):
    """Container exposing ``layer_list`` as a registered ``ModuleList``."""

    def __init__(self, layer_list):
        super().__init__()
        self.layers = torch.nn.ModuleList(layer_list)
class ResizeMode(Enum):
    """How an image is fitted to a target size."""

    RESIZE = "Just Resize"
    CROP_AND_RESIZE = "Crop and Resize"
    RESIZE_AND_FILL = "Resize and Fill"

    def int_value(self):
        """Return the integer code of this mode.

        ``RESIZE`` -> 0, ``CROP_AND_RESIZE`` -> 1, ``RESIZE_AND_FILL`` -> 2;
        anything else falls back to 0.
        """
        ordered = (
            ResizeMode.RESIZE,
            ResizeMode.CROP_AND_RESIZE,
            ResizeMode.RESIZE_AND_FILL,
        )
        return ordered.index(self) if self in ordered else 0


class StableDiffusionVersion(Enum):
    """The version family of stable diffusion model."""

    SD1x = "SD15"
    SDXL = "SDXL"
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        # ``detach()`` shares storage with ``p``, so this zeroes the parameter
        # in place without recording the write in autograd.
        p.detach().zero_()
    return module


class LatentTransparencyOffsetEncoder(torch.nn.Module):
    """Convolutional encoder mapping a 4-channel image to a 4-channel map.

    Three stride-2 convolutions reduce the spatial size by a factor of 8.
    The final convolution is created through ``zero_module``, so a freshly
    constructed encoder outputs all zeros until weights are loaded.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.blocks = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            zero_module(nn.Conv2d(256, 4, kernel_size=3, padding=1, stride=1)),
        )

    def forward(self, x):
        """Return ``self.blocks(x)``.

        Defined as ``forward`` instead of overriding ``__call__`` so that
        ``torch.nn.Module.__call__`` (and therefore any registered hook) is
        used when the encoder is invoked; ``encoder(x)`` is unchanged for
        callers.
        """
        return self.blocks(x)
downsample_type: str = "conv", + upsample_type: str = "conv", + dropout: float = 0.0, + act_fn: str = "silu", + attention_head_dim: Optional[int] = 8, + norm_num_groups: int = 4, + norm_eps: float = 1e-5, + ): + super().__init__() + + # input + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1) + ) + self.latent_conv_in = zero_module( + nn.Conv2d(4, block_out_channels[2], kernel_size=1) + ) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=None, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=( + attention_head_dim + if attention_head_dim is not None + else output_channel + ), + downsample_padding=downsample_padding, + resnet_time_scale_shift="default", + downsample_type=downsample_type, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=None, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift="default", + attention_head_dim=( + attention_head_dim + if attention_head_dim is not None + else block_out_channels[-1] + ), + resnet_groups=norm_num_groups, + attn_groups=None, + add_attention=True, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = 
def checkerboard(shape):
    """Return an integer array of ``shape`` with alternating 0/1 cells.

    A cell is 1 when the sum of its indices is odd, 0 otherwise.
    """
    return np.indices(shape).sum(axis=0) % 2


def fill_checkerboard_bg(y: torch.Tensor) -> torch.Tensor:
    """Composite the colour channels of ``y`` over a soft checkerboard.

    ``y`` is unpacked as ``(B, H, W, 1 + C)``: the first channel of the last
    axis is used as alpha and the remaining channels as the foreground. The
    background alternates between 0.45 and 0.55 in cells of roughly 64
    pixels.
    """
    alpha = y[..., :1]
    fg = y[..., 1:]
    B, H, W, C = fg.shape
    # At least one cell per axis: for images smaller than 64 px the integer
    # division yields 0 and cv2.resize cannot resize an empty array.
    cells = (max(1, H // 64), max(1, W // 64))
    # uint8 is accepted by cv2.resize everywhere, whereas the platform
    # default integer dtype produced by np.indices may not be; the 0/1
    # values are unchanged by the cast.
    cb = checkerboard(shape=cells).astype(np.uint8)
    cb = cv2.resize(cb, (W, H), interpolation=cv2.INTER_NEAREST)
    # Map {0, 1} -> {0.45, 0.55} and add batch / channel axes for broadcast.
    cb = (0.5 + (cb - 0.5) * 0.1)[None, ..., None]
    cb = torch.from_numpy(cb).to(fg)
    vis = fg * alpha + cb * (1 - alpha)
    return vis
def rgba2rgbfp32(x):
    """Blend an 8-bit RGBA array towards mid-grey according to its alpha.

    Channels 0-2 and channel 3 of the last axis are scaled to float32 in
    ``[0, 1]``; the result is ``0.5 + (rgb - 0.5) * alpha``.
    """
    colour = x[..., :3].astype(np.float32) / 255.0
    alpha = x[..., 3:4].astype(np.float32) / 255.0
    centred = colour - 0.5
    return 0.5 + centred * alpha


def to255unit8(x):
    """Scale ``x`` by 255, clip to ``[0, 255]`` and convert to ``uint8``."""
    scaled = x * 255.0
    return scaled.clip(0, 255).astype(np.uint8)


def safe_numpy(x):
    """Return a freshly allocated, C-contiguous copy of ``x``."""
    # A very safe method to make sure that Apple/Mac works.
    # The copy -> ascontiguousarray -> copy sequence is deliberate and is
    # kept exactly as the original author required.
    return np.ascontiguousarray(x.copy()).copy()


def high_quality_resize(x, size):
    """Resize image ``x`` to ``size`` given as ``(width, height)``.

    Returns ``x`` itself when it already has the requested size. Otherwise
    uses area interpolation when the pixel count shrinks and Lanczos
    interpolation when it does not.
    """
    target_w, target_h = size[0], size[1]
    src_h, src_w = x.shape[0], x.shape[1]
    if src_h == target_h and src_w == target_w:
        return x

    shrinking = (target_w * target_h) < (src_h * src_w)
    interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_LANCZOS4
    return cv2.resize(x, size, interpolation=interpolation)
def pytorch_to_numpy(x):
    """Convert each tensor in ``x`` to a ``uint8`` NumPy array.

    Every element is moved to the CPU, multiplied by 255, clipped to
    ``[0, 255]`` and cast to ``uint8``; the arrays are returned as a list.
    """
    images = []
    for tensor in x:
        scaled = 255. * tensor.cpu().numpy()
        images.append(np.clip(scaled, 0, 255).astype(np.uint8))
    return images


def numpy_to_pytorch(x):
    """Convert a NumPy array to a float tensor with a leading batch axis.

    Values are cast to float32 and divided by 255 before the new axis is
    prepended.
    """
    batched = (x.astype(np.float32) / 255.0)[None]
    batched = np.ascontiguousarray(batched.copy())
    return torch.from_numpy(batched).float()
def to_lora_patch_dict(state_dict: dict) -> dict:
    """Convert a raw lora state_dict into a patch dict for a model patcher.

    Keys of ``state_dict`` have the form ``model_key::patch_type::index``.
    The weights are collected per ``model_key`` and ``patch_type`` into a
    16-slot list (unused slots stay ``None``), and the result maps each
    ``model_key`` to a ``(patch_type, weight_list)`` tuple. If one model key
    carries several patch types, the last one encountered is kept.
    """
    grouped = {}
    for key, weight in state_dict.items():
        model_key, patch_type, weight_index = key.split('::')
        slots = grouped.setdefault(model_key, {}).setdefault(
            patch_type, [None] * 16
        )
        slots[int(weight_index)] = weight

    return {
        model_key: (patch_type, slots)
        for model_key, by_type in grouped.items()
        for patch_type, slots in by_type.items()
    }
b/pyproject.toml index ee93c6c..216253d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,17 +35,17 @@ cuda = [ ] cuda-12 = [ # CUDA 12, Python 3.10, Windows - "onnxruntime-gpu @ https://aiinfra.pkgs.visualstudio.com/PublicPackages/_apis/packaging/feeds/9387c3aa-d9ad-4513-968c-383f6f7f53b8/pypi/packages/onnxruntime-gpu/versions/1.17.1/onnxruntime_gpu-1.17.1-cp310-cp310-win_amd64.whl/content ; platform_system == 'Windows' and python_version == '3.10'", + "onnxruntime-gpu @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/9387c3aa-d9ad-4513-968c-383f6f7f53b8/pypi/download/onnxruntime-gpu/1.17.1/onnxruntime_gpu-1.17.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", # CUDA 12, Python 3.10, Linux - "onnxruntime-gpu @ https://aiinfra.pkgs.visualstudio.com/PublicPackages/_apis/packaging/feeds/9387c3aa-d9ad-4513-968c-383f6f7f53b8/pypi/packages/onnxruntime-gpu/versions/1.17.1/onnxruntime_gpu-1.17.1-cp311-cp311-manylinux_2_28_x86_64.whl/content ; platform_system == 'Linux' and python_version == '3.10'", + "onnxruntime-gpu @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/9387c3aa-d9ad-4513-968c-383f6f7f53b8/pypi/download/onnxruntime-gpu/1.17.1/onnxruntime_gpu-1.17.1-cp311-cp311-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10'", # CUDA 12, Python 3.11, Windows - "onnxruntime-gpu @ https://aiinfra.pkgs.visualstudio.com/PublicPackages/_apis/packaging/feeds/9387c3aa-d9ad-4513-968c-383f6f7f53b8/pypi/packages/onnxruntime-gpu/versions/1.17.1/onnxruntime_gpu-1.17.1-cp311-cp311-win_amd64.whl/content ; platform_system == 'Windows' and python_version == '3.11'", + "onnxruntime-gpu @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/9387c3aa-d9ad-4513-968c-383f6f7f53b8/pypi/download/onnxruntime-gpu/1.17.1/onnxruntime_gpu-1.17.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and 
python_version == '3.11'", # CUDA 12, Python 3.11, Linux - "onnxruntime-gpu @ https://aiinfra.pkgs.visualstudio.com/PublicPackages/_apis/packaging/feeds/9387c3aa-d9ad-4513-968c-383f6f7f53b8/pypi/packages/onnxruntime-gpu/versions/1.17.1/onnxruntime_gpu-1.17.1-cp311-cp311-manylinux_2_28_x86_64.whl/content ; platform_system == 'Linux' and python_version == '3.11'", + "onnxruntime-gpu @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/9387c3aa-d9ad-4513-968c-383f6f7f53b8/pypi/download/onnxruntime-gpu/1.17.1/onnxruntime_gpu-1.17.1-cp311-cp311-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11'", # CUDA 12, Python 3.12, Windows - "onnxruntime-gpu @ https://aiinfra.pkgs.visualstudio.com/PublicPackages/_apis/packaging/feeds/9387c3aa-d9ad-4513-968c-383f6f7f53b8/pypi/packages/onnxruntime-gpu/versions/1.17.1/onnxruntime_gpu-1.17.1-cp312-cp312-win_amd64.whl/content ; platform_system == 'Windows' and python_version == '3.12'", + "onnxruntime-gpu @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/9387c3aa-d9ad-4513-968c-383f6f7f53b8/pypi/download/onnxruntime-gpu/1.17.1/onnxruntime_gpu-1.17.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", # CUDA 12, Python 3.12, Linux - "onnxruntime-gpu @ https://aiinfra.pkgs.visualstudio.com/PublicPackages/_apis/packaging/feeds/9387c3aa-d9ad-4513-968c-383f6f7f53b8/pypi/packages/onnxruntime-gpu/versions/1.17.1/onnxruntime_gpu-1.17.1-cp312-cp312-manylinux_2_28_x86_64.whl/content ; platform_system == 'Linux' and python_version == '3.12'" + "onnxruntime-gpu @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/9387c3aa-d9ad-4513-968c-383f6f7f53b8/pypi/download/onnxruntime-gpu/1.17.1/onnxruntime_gpu-1.17.1-cp312-cp312-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12'" ] directml = [ "onnxruntime-directml" From 
46cba06fc9754f07baa7f46372013540d9a2c6a9 Mon Sep 17 00:00:00 2001 From: Ionite Date: Sun, 24 Mar 2024 20:46:58 -0400 Subject: [PATCH 4/4] Fix url --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 216253d..6535d11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ cuda-12 = [ # CUDA 12, Python 3.10, Windows "onnxruntime-gpu @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/9387c3aa-d9ad-4513-968c-383f6f7f53b8/pypi/download/onnxruntime-gpu/1.17.1/onnxruntime_gpu-1.17.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", # CUDA 12, Python 3.10, Linux - "onnxruntime-gpu @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/9387c3aa-d9ad-4513-968c-383f6f7f53b8/pypi/download/onnxruntime-gpu/1.17.1/onnxruntime_gpu-1.17.1-cp311-cp311-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10'", + "onnxruntime-gpu @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/9387c3aa-d9ad-4513-968c-383f6f7f53b8/pypi/download/onnxruntime-gpu/1.17.1/onnxruntime_gpu-1.17.1-cp310-cp310-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10'", # CUDA 12, Python 3.11, Windows "onnxruntime-gpu @ https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/9387c3aa-d9ad-4513-968c-383f6f7f53b8/pypi/download/onnxruntime-gpu/1.17.1/onnxruntime_gpu-1.17.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", # CUDA 12, Python 3.11, Linux