Latte opt #1023

Draft
wants to merge 16 commits into base: main
189 changes: 135 additions & 54 deletions benchmarks/text_to_video_latte.py
@@ -51,8 +51,6 @@ def parse_args():
parser.add_argument("--model", type=str, default=MODEL)
parser.add_argument("--ckpt", type=str, default=CKPT)
parser.add_argument("--prompt", type=str, default=PROMPT)
parser.add_argument("--save_graph", action="store_true")
parser.add_argument("--load_graph", action="store_true")
parser.add_argument("--variant", type=str, default=VARIANT)
parser.add_argument("--custom-pipeline", type=str, default=CUSTOM_PIPELINE)
# parser.add_argument("--sample-method", type=str, default=SAMPLE_METHOD)
@@ -97,6 +95,9 @@ def parse_args():
type=int,
default=ATTENTION_FP16_SCORE_ACCUM_MAX_M,
)
parser.add_argument("--profile_warmup", action="store_true")
parser.add_argument("--profile_run", action="store_true")
parser.add_argument("--from-hf", action="store_true")
return parser.parse_args()


@@ -126,19 +127,54 @@ def callback_on_step_end(self, pipe, i, t, callback_kwargs={}):
return callback_kwargs


def main():
args = parse_args()
from contextlib import contextmanager

if os.path.exists(args.model):
model_path = args.model

@contextmanager
def conditional_context(enabled, context_manager):
if enabled:
with context_manager as cm:
yield cm
else:
from huggingface_hub import snapshot_download
yield None

model_path = snapshot_download(repo_id=args.model)

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
_is_form_hf = False


def get_pipeline(args, model_path, device):
global _is_form_hf
if args.from_hf:
# Currently fails with the following error:
# File "python3.10/site-packages/diffusers/schedulers/scheduling_ddim.py", line 413, in step
# pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# RuntimeError: The size of tensor a (4) must match the size of tensor b (8) at non-singleton dimension 1
print("get pipeline from diffusers")
_is_form_hf = True
return get_pipeline_from_hf(args, model_path, device)
else:
print("get pipeline from source")
_is_form_hf = False
return get_pipeline_from_source(args, model_path, device)


def get_pipeline_from_hf(args, model_path, device):
# Get pipeline from diffusers
# diffusers version >= 0.30
from diffusers import LattePipeline

pipe = LattePipeline.from_pretrained(args.model, torch_dtype=torch.float16).to(
device
)

# Convert to channels_last memory format
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)
return pipe


def get_pipeline_from_source(args, model_path, device):
# Get pipeline from https://github.com/siliconflow/dit_latte/
from models.latte_t2v import LatteT2V
from sample.pipeline_latte import LattePipeline

@@ -182,6 +218,24 @@ def main():
transformer=transformer_model,
).to(device)

return pipe


def main():
args = parse_args()

if os.path.exists(args.model):
model_path = args.model
else:
from huggingface_hub import snapshot_download

model_path = snapshot_download(repo_id=args.model)

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = get_pipeline(args, model_path, device)

if args.compiler == "none":
pass
elif args.compiler == "nexfort":
@@ -191,7 +245,7 @@ def main():
options = json.loads(args.compiler_config)
else:
# config with string
options = '{"mode": "max-optimize:max-autotune:freezing:benchmark:low-precision", \
options = '{"mode": "O2", \
"memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": false, \
"triton.fuse_attention_allow_fp16_reduction": false}}'
pipe = compile_pipe(
@@ -219,60 +273,87 @@ def get_kwarg_inputs():
enable_temporal_attentions=args.enable_temporal_attentions,
num_images_per_prompt=1,
mask_feature=True,
enable_vae_temporal_decoder=args.enable_vae_temporal_decoder,
**(
dict()
if args.extra_call_kwargs is None
else json.loads(args.extra_call_kwargs)
),
)
if not _is_form_hf:
kwarg_inputs[
"enable_vae_temporal_decoder"
] = args.enable_vae_temporal_decoder
return kwarg_inputs

if args.warmups > 0:
print("=======================================")
print("Begin warmup")
begin = time.time()
for _ in range(args.warmups):
pipe(**get_kwarg_inputs()).video
end = time.time()
print("End warmup")
print(f"Warmup time: {end - begin:.3f}s")
kwarg_inputs = get_kwarg_inputs()
with conditional_context(
args.profile_warmup, torch.profiler.profile()
) as prof_warmup:
with conditional_context(
args.profile_warmup, torch.profiler.record_function("latte warmup")
):
if args.warmups > 0:
print("=======================================")
print("Begin warmup")
begin = time.time()
for _ in range(args.warmups):
out = pipe(**kwarg_inputs)
if _is_form_hf:
videos = out.frames[0]
else:
videos = out.video
end = time.time()
print("End warmup")
print(f"Warmup time: {end - begin:.3f}s")
print("=======================================")
if prof_warmup:
prof_warmup.export_chrome_trace("latte_prof_warmup.json")

with conditional_context(args.profile_run, torch.profiler.profile()) as prof_run:
iter_profiler = IterationProfiler()
if "callback_on_step_end" in inspect.signature(pipe).parameters:
kwarg_inputs["callback_on_step_end"] = iter_profiler.callback_on_step_end
elif "callback" in inspect.signature(pipe).parameters:
kwarg_inputs["callback"] = iter_profiler.callback_on_step_end
with conditional_context(
args.profile_run, torch.profiler.record_function("latte run")
):
torch.manual_seed(args.seed)
begin = time.time()
out = pipe(**kwarg_inputs)
if _is_form_hf:
videos = out.frames[0]
else:
videos = out.video
end = time.time()

print("=======================================")
print(f"Inference time: {end - begin:.3f}s")
iter_per_sec = iter_profiler.get_iter_per_sec()
if iter_per_sec is not None:
print(f"Iterations per second: {iter_per_sec:.3f}")
cuda_mem_max_used = torch.cuda.max_memory_allocated() / (1024**3)
cuda_mem_max_reserved = torch.cuda.max_memory_reserved() / (1024**3)
print(f"Max used CUDA memory : {cuda_mem_max_used:.3f}GiB")
print(f"Max reserved CUDA memory : {cuda_mem_max_reserved:.3f}GiB")
print("=======================================")

kwarg_inputs = get_kwarg_inputs()
iter_profiler = IterationProfiler()
if "callback_on_step_end" in inspect.signature(pipe).parameters:
kwarg_inputs["callback_on_step_end"] = iter_profiler.callback_on_step_end
elif "callback" in inspect.signature(pipe).parameters:
kwarg_inputs["callback"] = iter_profiler.callback_on_step_end
torch.manual_seed(args.seed)
begin = time.time()
videos = pipe(**kwarg_inputs).video
end = time.time()

print("=======================================")
print(f"Inference time: {end - begin:.3f}s")
iter_per_sec = iter_profiler.get_iter_per_sec()
if iter_per_sec is not None:
print(f"Iterations per second: {iter_per_sec:.3f}")
cuda_mem_max_used = torch.cuda.max_memory_allocated() / (1024**3)
cuda_mem_max_reserved = torch.cuda.max_memory_reserved() / (1024**3)
print(f"Max used CUDA memory : {cuda_mem_max_used:.3f}GiB")
print(f"Max reserved CUDA memory : {cuda_mem_max_reserved:.3f}GiB")
print("=======================================")

if args.output_video is not None:
# export_to_video(output_frames[0], args.output_video, fps=args.fps)
try:
imageio.mimwrite(
args.output_video, videos[0], fps=8, quality=9
) # highest quality is 10, lowest is 0
except:
print("Error when saving {}".format(args.prompt))

else:
print("Please set `--output-video` to save the output video")
with conditional_context(
args.profile_run, torch.profiler.record_function("latte export")
):
if args.output_video is not None:
# export_to_video(output_frames[0], args.output_video, fps=args.fps)
try:
imageio.mimwrite(
args.output_video, videos[0], fps=8, quality=9
) # highest quality is 10, lowest is 0
except:
print("Error when saving {}".format(args.prompt))
else:
print("Please set `--output-video` to save the output video")
if prof_run:
print(prof_run.key_averages().table(sort_by="cuda_time_total", row_limit=100))
prof_run.export_chrome_trace("latte_prof_run.json")


if __name__ == "__main__":
34 changes: 22 additions & 12 deletions onediff_diffusers_extensions/examples/latte/README.md
@@ -13,6 +13,16 @@

## Environment setup
### Set up Latte

#### From HF diffusers
Note: as of 2024-07-23, HF diffusers has a bug in LattePipeline.
Reference: https://huggingface.co/docs/diffusers/main/en/api/pipelines/latte
```bash
# make sure LattePipeline is available in HF diffusers (diffusers version >= 0.30)
pip install git+https://github.com/huggingface/diffusers.git@main
```
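
With diffusers installed, the `--from-hf` flag added by this PR loads the pipeline through `diffusers.LattePipeline` instead of the Latte source repo. A minimal invocation sketch (the model id follows the HF link in this README; other benchmark arguments are left at their defaults and may need adjusting), keeping in mind the bug noted above:
```bash
# Run the benchmark against the HF diffusers LattePipeline (path added in this PR).
# maxin-cn/Latte-1 is the HF model referenced in this README; adjust as needed.
python3 ./benchmarks/text_to_video_latte.py \
  --model maxin-cn/Latte-1 \
  --from-hf
```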

#### (Optional) From the Latte project
HF model: https://huggingface.co/maxin-cn/Latte-1
```bash
git clone -b run https://github.com/siliconflow/dit_latte/
@@ -59,18 +69,18 @@ python3 ./benchmarks/text_to_video_latte.py \
### Metric

#### On A100
| Metric | NVIDIA A100-PCIE-40GB (512 * 512) |
| ------------------------------------------------ | --------------------------------- |
| Data update date(yyyy-mm-dd) | 2024-06-19 |
| PyTorch iteration speed | 1.60 it/s |
| OneDiff iteration speed | 2.27 it/s(+41.9%) |
| PyTorch E2E time | 32.618 s |
| OneDiff E2E time | 22.601 s(-30.7%) |
| PyTorch Max Mem Used | 19.9 GiB |
| OneDiff Max Mem Used | 19.9 GiB |
| PyTorch Warmup with Run time | 33.291 s |
| OneDiff Warmup with Compilation time<sup>1</sup> | 572.877 s |
| OneDiff Warmup with Cache time | 148.068 s |
| Metric | NVIDIA A100-PCIE-40GB (512 * 512) | NVIDIA A100-PCIE-40GB (512 * 512), by strint on Ubuntu 22 |
| ------------------------------------------------ | --------------------------------- | ------------------------------------------------------- |
| Data update date(yyyy-mm-dd) | 2024-06-19 | 2024-07-23 |
| PyTorch iteration speed | 1.60 it/s | 1.6 it/s |
| OneDiff iteration speed | 2.27 it/s(+41.9%) | 1.723 it/s |
| PyTorch E2E time | 32.618 s | 32.497 s |
| OneDiff E2E time | 22.601 s(-30.7%) | 29.64 s |
| PyTorch Max Mem Used | 19.9 GiB | 19.92 GiB |
| OneDiff Max Mem Used | 19.9 GiB | 19.9 GiB |
| PyTorch Warmup with Run time | 33.291 s | 33.129 s |
| OneDiff Warmup with Compilation time<sup>1</sup> | 572.877 s | 737.6 s |
| OneDiff Warmup with Cache time | 148.068 s | 159 s |

<sup>1</sup> OneDiff Warmup with Compilation time is tested on an Intel(R) Xeon(R) Gold 6348 CPU @ 2.60GHz. Note that this is only for reference; it varies a lot across different CPUs.
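
The benchmark script in this PR also gains `--profile_warmup` and `--profile_run` flags, which wrap the warmup and the timed run in `torch.profiler` and export Chrome traces (`latte_prof_warmup.json`, `latte_prof_run.json`). A minimal sketch for collecting a trace of the timed run (other arguments left at their defaults; adjust to your setup):
```bash
# Profile the timed run; the script prints a key_averages() table sorted by
# CUDA time and writes latte_prof_run.json, which can be opened in
# chrome://tracing or Perfetto.
python3 ./benchmarks/text_to_video_latte.py --profile_run
```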
