-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
87 lines (77 loc) · 2.89 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import shutil
from typing import List
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from pytorch_lightning import seed_everything
from cog import BasePredictor, Input, Path
# Filesystem path to the DreamBooth-fine-tuned Stable Diffusion weights
# loaded once in Predictor.setup(); assumed to be baked into the image.
DREAMBOOTH_MODEL_PATH="weights/funko-diffusion"
class Predictor(BasePredictor):
    """Cog predictor serving a DreamBooth-fine-tuned Stable Diffusion model.

    Loads the pipeline once in ``setup`` and generates one or more PNG
    images per ``predict`` call, returning their paths.
    """

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        # DDIM scheduler configured with the standard Stable Diffusion v1
        # beta schedule (same values as the model's training config).
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        # safety_checker=None deliberately disables NSFW filtering;
        # float16 halves GPU memory so larger resolutions fit.
        self.pipe = StableDiffusionPipeline.from_pretrained(
            DREAMBOOTH_MODEL_PATH,
            scheduler=scheduler,
            safety_checker=None,
            torch_dtype=torch.float16,
        ).to("cuda")

    # NOTE(review): torch.cuda.amp.autocast is deprecated in newer torch
    # in favor of torch.amp.autocast("cuda"); kept as-is for compatibility
    # with the torch version this image was built against.
    @torch.inference_mode()
    @torch.cuda.amp.autocast()
    def predict(
        self,
        prompt: str = Input(
            description="Input prompt",
            default="photography of an Italian car in Tuscany, poolsuite style",
        ),
        width: int = Input(
            description="Width of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
            choices=[128, 256, 512, 768, 1024],
            default=512,
        ),
        height: int = Input(
            description="Height of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
            choices=[128, 256, 512, 768, 1024],
            default=512,
        ),
        num_outputs: int = Input(
            description="Number of images to output", choices=[1, 4], default=1
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=6
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> List[Path]:
        """Run a single prediction on the model.

        Returns:
            List of paths to the generated PNG files (one per output).

        Raises:
            ValueError: if a 1024x1024 image is requested (exceeds the
                memory limit; only 1024x768 / 768x1024 are allowed at 1024).
        """
        if seed is None:
            # 4 bytes of entropy (was 2): 2 bytes gave only 65,536 distinct
            # seeds, making repeat seeds across runs likely.
            seed = int.from_bytes(os.urandom(4), "big")
        print(f"Using seed: {seed}")
        # choices= already restricts each dimension; only the combined
        # 1024x1024 case must be rejected here.
        if width == height == 1024:
            raise ValueError(
                "Maximum size is 1024x768 or 768x1024 pixels, because of memory limits. Please select a lower width or height."
            )
        # Seeds python, numpy and torch RNGs so generation is reproducible.
        seed_everything(seed)
        output = self.pipe(
            prompt=[prompt] * num_outputs,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        output_paths = []
        for i, sample in enumerate(output["images"]):
            output_path = f"/tmp/out-{i}.png"
            sample.save(output_path)
            output_paths.append(Path(output_path))
        return output_paths