app_multigpu_engine.py (forked from jy0205/Pyramid-Flow)
import os
import sys
import torch
import argparse
from PIL import Image
from diffusers.utils import export_to_video

# Add the project root directory to sys.path
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(SCRIPT_DIR)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from pyramid_dit import PyramidDiTForVideoGeneration
from trainer_misc import init_distributed_mode, init_sequence_parallel_group


def get_args():
    parser = argparse.ArgumentParser('PyTorch Multi-process Script', add_help=False)
    parser.add_argument('--model_name', default='pyramid_mmdit', type=str, help="The model name", choices=["pyramid_flux", "pyramid_mmdit"])
    parser.add_argument('--model_dtype', default='bf16', type=str, help="The model dtype: bf16, fp16, or fp32")
    parser.add_argument('--model_path', required=True, type=str, help='Path to the downloaded checkpoint directory')
    parser.add_argument('--variant', default='diffusion_transformer_768p', type=str)
    parser.add_argument('--task', default='t2v', type=str, choices=['i2v', 't2v'])
    parser.add_argument('--temp', default=16, type=int, help='The number of generated temporal latents; num_frames = temp * 8 + 1')
    parser.add_argument('--sp_group_size', default=2, type=int, help="The number of GPUs used for inference; should be 2 or 4")
    parser.add_argument('--sp_proc_num', default=-1, type=int, help="The number of processes used for video training; -1 means use all processes")
    parser.add_argument('--prompt', type=str, required=True, help="Text prompt for video generation")
    parser.add_argument('--image_path', type=str, help="Path to the input image for image-to-video")
    parser.add_argument('--video_guidance_scale', type=float, default=5.0, help="Video guidance scale")
    parser.add_argument('--guidance_scale', type=float, default=9.0, help="Guidance scale for text-to-video")
    parser.add_argument('--resolution', type=str, default='768p', choices=['768p', '384p'], help="Model resolution")
    parser.add_argument('--output_path', type=str, required=True, help="Path to save the generated video")
    return parser.parse_args()


def main():
    args = get_args()

    # Set up DDP; init_distributed_mode is expected to populate args.rank and args.world_size
    init_distributed_mode(args)
    assert args.world_size == args.sp_group_size, "The sequence parallel group size should match the DDP world size"

    # Enable sequence parallelism across the participating GPUs
    init_sequence_parallel_group(args)

    device = torch.device('cuda')
    rank = args.rank
    model_dtype = args.model_dtype

    if args.model_name == "pyramid_flux":
        assert args.variant != "diffusion_transformer_768p", (
            "pyramid_flux does not support the 768p resolution yet; it will be released after training finishes. "
            "Set model_name to pyramid_mmdit for 768p generation."
        )

    model = PyramidDiTForVideoGeneration(
        args.model_path,
        model_dtype,
        model_name=args.model_name,
        model_variant=args.variant,
    )

    model.vae.to(device)
    model.dit.to(device)
    model.text_encoder.to(device)
    model.vae.enable_tiling()

    if model_dtype == "bf16":
        torch_dtype = torch.bfloat16
    elif model_dtype == "fp16":
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    # The video generation config
    if args.resolution == '768p':
        width = 1280
        height = 768
    else:
        width = 640
        height = 384

    try:
        if args.task == 't2v':
            prompt = args.prompt
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=(model_dtype != 'fp32'), dtype=torch_dtype):
                frames = model.generate(
                    prompt=prompt,
                    num_inference_steps=[20, 20, 20],
                    video_num_inference_steps=[10, 10, 10],
                    height=height,
                    width=width,
                    temp=args.temp,
                    guidance_scale=args.guidance_scale,
                    video_guidance_scale=args.video_guidance_scale,
                    output_type="pil",
                    save_memory=True,
                    cpu_offloading=False,
                    inference_multigpu=True,
                )
            # Only rank 0 writes the video to disk
            if rank == 0:
                export_to_video(frames, args.output_path, fps=24)
        elif args.task == 'i2v':
            if not args.image_path:
                raise ValueError("Image path is required for the image-to-video task")
            image = Image.open(args.image_path).convert("RGB")
            image = image.resize((width, height))
            prompt = args.prompt
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=(model_dtype != 'fp32'), dtype=torch_dtype):
                frames = model.generate_i2v(
                    prompt=prompt,
                    input_image=image,
                    num_inference_steps=[10, 10, 10],
                    temp=args.temp,
                    video_guidance_scale=args.video_guidance_scale,
                    output_type="pil",
                    save_memory=True,
                    cpu_offloading=False,
                    inference_multigpu=True,
                )
            # Only rank 0 writes the video to disk
            if rank == 0:
                export_to_video(frames, args.output_path, fps=24)
    except Exception as e:
        if rank == 0:
            print(f"[ERROR] Error during video generation: {e}")
        raise
    finally:
        # Make sure all ranks reach this point before exiting
        torch.distributed.barrier()


if __name__ == "__main__":
    main()
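
Launch note (a minimal sketch, not taken from the repository): the script asserts that the DDP world size equals --sp_group_size, so it is presumably started with one process per GPU, for example via torchrun, assuming init_distributed_mode reads the usual RANK/WORLD_SIZE/LOCAL_RANK environment variables that torchrun sets. The checkpoint path, prompt, and script location below are placeholders; all flags come from get_args() above.

    torchrun --nproc_per_node=2 app_multigpu_engine.py \
        --model_path /path/to/pyramid-flow-checkpoint \
        --sp_group_size 2 \
        --task t2v \
        --prompt "A panda surfing on a wave" \
        --output_path ./text_to_video_sample.mp4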