
--split_attn increases training speed, contrary to docs #65

Open
xmodo99 opened this issue Jan 26, 2025 · 3 comments

@xmodo99

xmodo99 commented Jan 26, 2025

I'm running python 3.10.3 on Linux Mint, with an RTX 3090.

Contrary to the docs, --split_attn increases training speed, cutting the time per iteration roughly in half.

Without split attention, I'm getting 27s/it, but with split attention it's 13s/it.

Does it make logical sense that this would happen? If it does, we should update the docs.

@FurkanGozukara

Good question

@kohya-ss
Owner

That's very interesting. --split_attn eliminates the need for the attention mask, so maybe that helps. Could you tell me which one you're using: sdpa, xformers, or flash_attn?
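For intuition, here is a minimal sketch of why dropping the attention mask can matter. This is not the repository's actual code; all names, shapes, and lengths below are illustrative assumptions. With batched variable-length samples, padding forces an explicit attention mask, and PyTorch's scaled_dot_product_attention generally cannot dispatch its fastest fused kernels when a mask is supplied. Running attention per sample on the unpadded length needs no mask at all.

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32

# Illustrative shapes only: two samples of different sequence lengths,
# padded to the longer one for the batched case.
B, H, L_short, L_max, D = 2, 8, 256, 512, 64
q = torch.randn(B, H, L_max, D, device=device, dtype=dtype)
k, v = torch.randn_like(q), torch.randn_like(q)

# Batched path: padding requires an explicit attention mask. With a mask,
# SDPA typically falls back to a slower, more general kernel.
mask = torch.zeros(B, 1, L_max, L_max, dtype=torch.bool, device=device)
mask[0, :, :, :L_short] = True  # sample 0 attends only to its real tokens
mask[0, :, L_short:, :] = True  # padded query rows attend freely; their outputs are discarded
mask[1] = True                  # sample 1 has no padding
out_batched = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# "Split" path: run attention per sample on its true length.
# No mask is needed, so SDPA is free to pick a fused kernel.
lengths = [L_short, L_max]
out_split = [
    F.scaled_dot_product_attention(q[i:i + 1, :, :n], k[i:i + 1, :, :n], v[i:i + 1, :, :n])
    for i, n in enumerate(lengths)
]
```

If the batched, masked path is falling back to a slower kernel while the split path hits a fused one, a roughly 2x difference in s/it would be plausible, but that's speculation until someone profiles it.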

@xmodo99
Author

xmodo99 commented Feb 1, 2025

I was using sdpa like the example in the docs. Here's my full command:

accelerate launch --num_cpu_threads_per_process 2 --mixed_precision bf16 hv_train_network.py \
  --dit ./ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.safetensors \
  --dataset_config ./training_config.toml \
  --sdpa --mixed_precision bf16 --fp8_base \
  --optimizer_type adamw8bit --learning_rate 2e-4 \
  --gradient_checkpointing \
  --max_data_loader_n_workers 2 --persistent_data_loader_workers \
  --network_module networks.lora --network_dim 32 \
  --timestep_sampling shift --discrete_flow_shift 7.0 \
  --max_train_epochs 100 \
  --save_every_n_steps {sample_every_n_steps} \
  --seed 42 \
  --output_dir ./train_output/{lora_name}/ --output_name {lora_name} \
  --split_attn \
  --logging_dir ./tensorLogs --log_prefix {lora_name}_ \
  --vae ./ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt \
  --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 \
  --sample_prompts ./sample_prompts.txt \
  --text_encoder1 ./ckpts/text_encoder/llava_llama3_fp16.safetensors \
  --text_encoder2 ./ckpts/text_encoder_2/clip_l.safetensors \
  --sample_every_n_steps {sample_every_n_steps}

The sampling settings understandably don't affect training speed, so they can probably be ignored.
