Support SDXL latents directly
Bypassing the need for the SUPIR VAE.
kijai committed Jun 27, 2024
1 parent c257cce commit 0067546
Showing 2 changed files with 41 additions and 16 deletions.
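In short: latents produced inside the SUPIR pipeline carry an "original_size" entry and are already multiplied by the SDXL VAE scaling factor 0.13025, while latents from ComfyUI's stock VAEEncode are neither scaled nor tagged. The commit uses the presence of that key to tell the two apart, scaling stock latents into SUPIR's domain on input and back out on output. (The __init__.py change is only a node display-name rename; the substance is in nodes_v2.py.) A minimal sketch of the convention, with illustrative helper names that are not part of the commit:

    SDXL_SCALE = 0.13025  # SDXL VAE scale factor, as used throughout the diff below

    def to_supir_latent(latent: dict) -> dict:
        # Stock ComfyUI SDXL latents lack the "original_size" tag that SUPIR adds.
        if "original_size" in latent:
            return latent  # already in SUPIR's scaled domain
        return {"samples": latent["samples"] * SDXL_SCALE, "original_size": None}

    def from_supir_latent(latent: dict) -> dict:
        # Undo the scaling so stock nodes downstream see a normal SDXL latent.
        if latent.get("original_size") is None:
            return {"samples": latent["samples"] / SDXL_SCALE}
        return latent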
__init__.py: 2 changes (1 addition, 1 deletion)

@@ -21,7 +21,7 @@
     "SUPIR_encode": "SUPIR Encode",
     "SUPIR_decode": "SUPIR Decode",
     "SUPIR_conditioner": "SUPIR Conditioner",
-    "SUPIR_tiles": "SUPIR Tiles",
+    "SUPIR_tiles": "SUPIR Tiles Preview",
     "SUPIR_model_loader_v2": "SUPIR Model Loader (v2)",
     "SUPIR_model_loader_v2_clip": "SUPIR Model Loader (v2) (Clip)"
 }
nodes_v2.py: 55 changes (40 additions, 15 deletions)

@@ -192,14 +192,20 @@ def decode(self, SUPIR_VAE, latents, use_tiled_vae, decoder_tile_size):
         device = mm.get_torch_device()
         mm.unload_all_models()
         samples = latents["samples"]
-        dtype = SUPIR_VAE.dtype
-        orig_H, orig_W = latents["original_size"]

         B, H, W, C = samples.shape

         pbar = comfy.utils.ProgressBar(B)

-        SUPIR_VAE.to(device)
+        if mm.should_use_bf16():
+            print("Decoder using bf16")
+            dtype = torch.bfloat16
+        else:
+            print("Decoder using fp32")
+            dtype = torch.float32
+        print("SUPIR decoder using", dtype)

+        SUPIR_VAE.to(dtype).to(device)
         samples = samples.to(device)

         if use_tiled_vae:
@@ -220,14 +226,17 @@ def decode(self, SUPIR_VAE, latents, use_tiled_vae, decoder_tile_size):
             autocast_condition = (dtype != torch.float32) and not comfy.model_management.is_device_mps(device)
             with torch.autocast(comfy.model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
                 sample = 1.0 / 0.13025 * sample
-                decoded_image = SUPIR_VAE.decode(sample.unsqueeze(0)).float()
+                decoded_image = SUPIR_VAE.decode(sample.unsqueeze(0))
                 out.append(decoded_image)
                 pbar.update(1)

-        decoded_out= torch.cat(out, dim=0)
-        if decoded_out.shape[2] != orig_H or decoded_out.shape[3] != orig_W:
-            print("Restoring original dimensions: ", orig_W,"x",orig_H)
-            decoded_out = F.interpolate(decoded_out, size=(orig_H, orig_W), mode="bicubic")
+        decoded_out= torch.cat(out, dim=0).float()

+        if "original_size" in latents and latents["original_size"] is not None:
+            orig_H, orig_W = latents["original_size"]
+            if decoded_out.shape[2] != orig_H or decoded_out.shape[3] != orig_W:
+                print("Restoring original dimensions: ", orig_W,"x",orig_H)
+                decoded_out = F.interpolate(decoded_out, size=(orig_H, orig_W), mode="bicubic")

         decoded_out = torch.clip(decoded_out, 0, 1)
         decoded_out = decoded_out.cpu().to(torch.float32).permute(0, 2, 3, 1)
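Two things change in decode: the working dtype is now chosen via mm.should_use_bf16() and the VAE is cast to it, and the resize back to the caller's original dimensions only happens when the latent actually carries that metadata, which a stock SDXL latent does not. A condensed, illustrative sketch of the resulting flow, assuming a vae object with a decode method and substituting a plain flag for the ComfyUI model-management query:

    import torch
    import torch.nn.functional as F

    def decode_sketch(vae, latent: dict, use_bf16: bool) -> torch.Tensor:
        # Condensed sketch of the post-commit decode path; 'vae' stands in
        # for SUPIR_VAE and 'use_bf16' for mm.should_use_bf16().
        dtype = torch.bfloat16 if use_bf16 else torch.float32
        vae = vae.to(dtype)
        sample = latent["samples"].to(dtype) * (1.0 / 0.13025)  # undo SUPIR's latent scaling
        decoded = vae.decode(sample).float()                    # cast to fp32 after decoding
        original_size = latent.get("original_size")             # absent/None for stock SDXL latents
        if original_size is not None:
            orig_H, orig_W = original_size
            if decoded.shape[2:] != (orig_H, orig_W):
                decoded = F.interpolate(decoded, size=(orig_H, orig_W), mode="bicubic")
        return decoded.clamp(0, 1)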
@@ -360,9 +369,9 @@ def INPUT_TYPES(s):
                 "EDM_s_churn": ("INT", {"default": 5, "min": 0, "max": 40, "step": 1}),
                 "s_noise": ("FLOAT", {"default": 1.003, "min": 1.0, "max": 1.1, "step": 0.001}),
                 "DPMPP_eta": ("FLOAT", {"default": 1.0, "min": 0, "max": 10.0, "step": 0.01}),
-                "control_scale_start": ("FLOAT", {"default": 1.0, "min": 0, "max": 10.0, "step": 0.05}),
-                "control_scale_end": ("FLOAT", {"default": 1.0, "min": 0, "max": 10.0, "step": 0.05}),
-                "restore_cfg": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20.0, "step": 0.05}),
+                "control_scale_start": ("FLOAT", {"default": 1.0, "min": 0, "max": 10.0, "step": 0.01}),
+                "control_scale_end": ("FLOAT", {"default": 1.0, "min": 0, "max": 10.0, "step": 0.01}),
+                "restore_cfg": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20.0, "step": 0.01}),
                 "keep_model_loaded": ("BOOLEAN", {"default": False}),
                 "sampler": (
                     [
@@ -483,7 +492,8 @@ def sample(self, SUPIR_model, latents, steps, seed, cfg_scale_end, EDM_s_churn,
                 noised_z = torch.randn_like(sample.unsqueeze(0), device=samples.device)
             else:
                 print("Using latent from input")
-                noised_z = sample.unsqueeze(0) * 0.13025
+                noised_z = torch.randn_like(sample.unsqueeze(0), device=samples.device)
+                noised_z += sample.unsqueeze(0)
             if len(positive) != len(samples):
                 print("Tiled sampling")
                 _samples = self.sampler(denoiser, noised_z, cond=positive, uc=negative, x_center=sample.unsqueeze(0), control_scale=control_scale_end,
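The starting latent for the sampler changes meaning here. Before, a supplied latent was only rescaled by 0.13025 and used as-is; now the conditioner (further down in this diff) applies that scaling up front, and the sampler instead seeds the process with the latent plus unit Gaussian noise. A side-by-side sketch, with illustrative names:

    import torch

    def init_noised_z(sample: torch.Tensor, post_commit: bool = True) -> torch.Tensor:
        # 'sample' is one latent; after this commit it is assumed to be
        # already in SUPIR's scaled domain when it reaches the sampler.
        if post_commit:
            noised_z = torch.randn_like(sample.unsqueeze(0))  # unit Gaussian noise
            noised_z += sample.unsqueeze(0)                   # centered on the input latent
        else:
            noised_z = sample.unsqueeze(0) * 0.13025          # old: rescale only, no noise
        return noised_z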
@@ -518,6 +528,9 @@ def sample(self, SUPIR_model, latents, steps, seed, cfg_scale_end, EDM_s_churn,
         else:
             samples_out_stacked = torch.stack(out, dim=0)

+        if original_size is None:
+            samples_out_stacked = samples_out_stacked / 0.13025
+
         return ({"samples":samples_out_stacked, "original_size": original_size},)

 class SUPIR_conditioner:
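The output side mirrors the input side: a latent that arrived untagged (original_size is None) is divided by 0.13025 before being returned, so it leaves at the scale stock ComfyUI nodes expect. The two rescales cancel up to floating-point rounding:

    import torch

    # Round-trip check: scaling into SUPIR's domain and back is lossless.
    latent = torch.randn(1, 4, 128, 128)   # stands in for a stock SDXL latent
    supir_domain = latent * 0.13025        # what the conditioner now does on input
    restored = supir_domain / 0.13025      # what the sampler now does on output
    assert torch.allclose(latent, restored)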
@@ -555,7 +568,14 @@ def condition(self, SUPIR_model, latents, positive_prompt, negative_prompt, capt

         device = mm.get_torch_device()
         mm.soft_empty_cache()
-        samples = latents["samples"]
+
+        if "original_size" in latents:
+            original_size = latents["original_size"]
+            samples = latents["samples"]
+        else:
+            original_size = None
+            samples = latents["samples"] * 0.13025
+
         N, H, W, C = samples.shape
         import copy
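This is the detection point: the "original_size" key doubles as a type tag for where the latent came from. Factored into a hypothetical helper (not in the commit), the logic reads:

    def normalize_incoming_latent(latents: dict):
        # SUPIR-encoded latents carry "original_size" and are already scaled;
        # stock SDXL latents get scaled here and are tagged with original_size=None.
        if "original_size" in latents:
            return latents["samples"], latents["original_size"]
        return latents["samples"] * 0.13025, None

    # usage inside condition():
    #   samples, original_size = normalize_incoming_latent(latents)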

@@ -622,8 +642,13 @@ def condition(self, SUPIR_model, latents, positive_prompt, negative_prompt, capt

         SUPIR_model.conditioner.to('cpu')

+        if "original_size" in latents:
+            original_size = latents["original_size"]
+        else:
+            original_size = None

-        return ({"cond": c, "original_size":latents["original_size"]}, {"uncond": uc},)
+        return ({"cond": c, "original_size":original_size}, {"uncond": uc},)

 class SUPIR_model_loader:
     @classmethod
