Skip to content

Commit

Permalink
[Fix] Fix bug when vae of sd is not set.
Browse files Browse the repository at this point in the history
  • Loading branch information
mzr1996 committed Dec 14, 2023
1 parent e126786 commit 5198f31
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions agentlego/tools/utils/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,21 @@ def load_sd(model: str = 'runwayml/stable-diffusion-v1-5',
StableDiffusionPipeline)

dtype = torch.float16 if 'cuda' in str(device) else torch.float32
params = {'torch_dtype': dtype}

if variant is not None:
params['variant'] = variant
if vae is not None:
vae = load_or_build_object(
AutoencoderKL.from_pretrained,
vae,
torch_dtype=dtype,
variant=vae_variant,
)
params['vae'] = vae

t2i = load_or_build_object(
StableDiffusionPipeline.from_pretrained,
model,
vae=vae,
torch_dtype=dtype,
variant=variant,
)
t2i = load_or_build_object(StableDiffusionPipeline.from_pretrained, model,
**params)

if controlnet is None:
return t2i.to(device)
Expand All @@ -58,24 +57,22 @@ def load_sdxl(model: str = 'stabilityai/stable-diffusion-xl-base-1.0',
from diffusers import (AutoencoderKL, ControlNetModel,
StableDiffusionXLControlNetPipeline,
StableDiffusionXLPipeline)

dtype = torch.float16 if 'cuda' in str(device) else torch.float32
params = {'torch_dtype': dtype}

if variant is not None:
params['variant'] = variant
if vae is not None:
vae = load_or_build_object(
AutoencoderKL.from_pretrained,
vae,
torch_dtype=dtype,
variant=vae_variant,
)
params['vae'] = vae

t2i = load_or_build_object(
StableDiffusionXLPipeline.from_pretrained,
model,
vae=vae,
torch_dtype=dtype,
variant=variant,
)
t2i = load_or_build_object(StableDiffusionXLPipeline.from_pretrained,
model, **params)

if controlnet is None:
return t2i.to(device)
Expand Down

0 comments on commit 5198f31

Please sign in to comment.