-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a58161f
commit 94be14f
Showing
29 changed files
with
74,222 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,91 @@ | ||
# EnCLAP | ||
Official Implementation of EnCLAP | ||
# EnCLAP: Combining Neural Audio Codec and Audio-Text Joint Embedding for Automated Audio Captioning [ICASSP 2024] | ||
|
||
Jaeyeon Kim, Jaeyoon Jung, Jinjoo Lee, Sang Hoon Woo @ [MAUM AI Inc.](https://github.com/maum-ai), SNU, SSU | ||
|
||
[](https://paperswithcode.com/sota/audio-captioning-on-audiocaps?p=enclap-combining-neural-audio-codec-and-audio) | ||
[](https://arxiv.org/abs/2401.17690) [](https://huggingface.co/spaces/enclap-team/enclap) | ||
|
||
**Abstract** : We propose EnCLAP, a novel framework for automated audio captioning. EnCLAP employs two acoustic representation models, EnCodec and CLAP, along with a pretrained language model, BART. We also introduce a new training objective called masked codec modeling that improves acoustic awareness of the pretrained language model. Experimental results on AudioCaps and Clotho demonstrate that our model surpasses the performance of baseline models. Source code will be available at https://github.com/jaeyeonkim99/EnCLAP. An online demo is available at https://huggingface.co/spaces/enclap-team/enclap. | ||
|
||
 | ||
|
||
|
||
# Requirements | ||
|
||
## Environment | ||
- We used `torch==1.13.0` and `python==3.9.12` for our experiment. | ||
- Please install required depedencies through `pip install -r requirements.txt`. | ||
- `wget`, `java`, `unzip` are required to use `aac_metrics` library. Please run `aac-metrics-download` after installing the library. | ||
- `libsndfile1` is required to use `soundfile`. | ||
- `torchaudio` version should match the version of `torch` and `cuda`. Check [here](https://download.pytorch.org/whl/torchaudio/) for appropriate version. | ||
|
||
## CLAP and EnCodec | ||
- In our experiments, we used CLAP checkpoint trained on `LAION-audio-630k` and `AudioSet`([630k-audioset-fusion-best.pt](https://huggingface.co/lukewys/laion_clap/blob/main/630k-audioset-fusion-best.pt)). See details of pretrained CLAP checkpoints in [LAION-AI/CLAP](https://github.com/LAION-AI/CLAP?tab=readme-ov-file). | ||
|
||
|
||
# How to Use | ||
|
||
## Training | ||
1. **Preprocess the dataset**: For training, you must first convert the audio files into their respective CLAP embeddings and EnCodec sequences. Once you have the converted data, you must write CSV files mapping each example to its CLAP embedding and EnCodec sequence files. The example CSV files for AudioCaps and Clotho are provided at [csv](csv/). See [data/README.md](data/README.md) for further details. | ||
2. **Setup the training config file**: We included a number of training configuration files we used for our experiments at [cfg](cfg/). Feel free to modify them to fit your needs. | ||
``` | ||
# Paths to save the experiment results | ||
output_dir: /output | ||
# Paths of the train csv file | ||
train_file: csv/audiocaps/train.csv | ||
# Paths of the validation csv file | ||
validation_file: csv/audiocaps/valid.csv | ||
# Root directory of the preprocessed EnCodec entries | ||
encodec_base_path: /data/audiocaps/encodec | ||
# Root directory of the preprocessed CLAP embeddings | ||
clap_base_path: /data/audiocaps/clap | ||
``` | ||
|
||
3. **Run the training script**: Our training script is based on [Accelerate](https://huggingface.co/docs/accelerate/index), so you may need to setup Accelerate before running the training script. You can designate the path to the training config by modifying `CFG_PATH` in the training script. | ||
|
||
``` | ||
accelerate config # Set up Accelerate | ||
sh train.sh # Run the training script | ||
``` | ||
|
||
## Evaluate | ||
- The evaluation dataset may be either preprocessed CLAP embeddings and EnCodec sequences or raw wav files. Similarly to training, you must provide a CSV file containing the dataset. See [csv](csv/) for examples. | ||
|
||
``` | ||
# From preprocessed data | ||
python evaluate.py --ckpt checkpoint_directory_with_config --clap_ckpt clap_ckpt_path --test_csv test.csv --from_preprocessed --encodec_path data/encodec_embeddings --clap_path data/clap_embeddings | ||
# From wav files | ||
python evaluate.py --ckpt checkpoint_directory_with_config --clap_ckpt clap_ckpt_path --test_csv test.csv --audio_path data/wav_files | ||
``` | ||
|
||
- You can save the predictions by adding CSV path with `--save_path`. You can evaluate on other AAC metrics by modifying `metric_list` on the top of `evaluate.py`. | ||
- To reproduce our results, we recommend you to evaluate several best checkpoints based on `valid/spider` score (we evalauted 5). The score reported during the training is not 100% accurate due to DDP and dataloader padding. | ||
|
||
## Inference | ||
- You can infer from wav files using the following code. | ||
``` | ||
python inference.py --ckpt checkpoint_directory_with_config --clap_ckpt clap_ckpt_path --input input.wav | ||
``` | ||
|
||
- You can also run the gradio demo with your own checkpoint. | ||
``` | ||
python gradio_app.py --ckpt checkpoint_directory_with_config --clap_ckpt clap_ckpt_path --device cuda | ||
``` | ||
# Pretrained Checkpoints | ||
Pretrained checkpoints will be available soon. | ||
|
||
# Citation | ||
``` | ||
@article{kim2024enclap, | ||
title={EnCLAP: Combining Neural Audio Codec and Audio-Text Joint Embedding for Automated Audio Captioning}, | ||
author={Kim, Jaeyeon and Jung, Jaeyoon and Lee, Jinjoo and Woo, Sang Hoon}, | ||
journal={arXiv preprint arXiv:2401.17690}, | ||
year={2024} | ||
} | ||
``` |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Experiment Config for each experiment | ||
output_dir: /output | ||
logging_dir: runs/tb_log | ||
logging_steps: 10 | ||
seed: 1115 | ||
train_file: csv/audiocaps/train.csv | ||
validation_file: csv/audiocaps/valid.csv | ||
encodec_base_path: /data/audiocaps/encodec | ||
clap_base_path: /data/audiocaps/clap | ||
tokenizer_name: facebook/bart-base | ||
config_name_or_path: facebook/bart-base | ||
model_name_or_path: facebook/bart-base | ||
eval_num_captions: 5 | ||
overwrite_output_dir: False | ||
|
||
# Basic Config | ||
encodec_masking_prob: 0.15 | ||
encodec_masking_span: 10 | ||
num_train_epochs: 15 | ||
max_train_steps: null | ||
gradient_accumulation_steps: 1 | ||
per_device_train_batch_size: 16 | ||
per_device_eval_batch_size: 16 | ||
split_batches: true | ||
checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps | ||
resume_from_checkpoint: null | ||
|
||
# Generation Config | ||
max_target_length: 128 | ||
val_max_target_length: 50 | ||
|
||
# Training Hyperparameters | ||
# "lr_schedulre_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial", | ||
# "constant", "constant_with_warmpup", "inverse_sqrt", "reduce_lr_on_plateau" | ||
lr_scheduler_type: inverse_sqrt | ||
learning_rate: 6.5e-5 # peak lr | ||
num_warmup_steps: 2000 | ||
weight_decay: 0.01 | ||
max_grad_norm: 1.0 | ||
|
||
# Others | ||
with_tracking: true | ||
report_to: tensorboard | ||
ignore_pad_token_for_loss: true | ||
preprocessing_num_workers: 32 | ||
use_slow_tokenizer: false | ||
overwrite_cache: false | ||
pad_to_max_length: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Experiment Config for each experiment | ||
output_dir: /output | ||
logging_dir: runs/tb_log | ||
logging_steps: 10 | ||
seed: 1115 | ||
train_file: csv/audiocaps/train.csv | ||
validation_file: csv/audiocaps/valid.csv | ||
encodec_base_path: /data/audiocaps/encodec | ||
clap_base_path: /data/audiocaps/clap | ||
tokenizer_name: facebook/bart-large | ||
config_name_or_path: facebook/bart-large | ||
model_name_or_path: facebook/bart-large | ||
eval_num_captions: 5 | ||
overwrite_output_dir: False | ||
|
||
# Basic Config | ||
encodec_masking_prob: 0.15 | ||
encodec_masking_span: 10 | ||
num_train_epochs: 15 | ||
max_train_steps: null | ||
gradient_accumulation_steps: 1 | ||
per_device_train_batch_size: 64 | ||
per_device_eval_batch_size: 64 | ||
split_batches: true | ||
checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps | ||
resume_from_checkpoint: null | ||
|
||
# Generation Config | ||
max_target_length: 128 | ||
val_max_target_length: 50 | ||
|
||
# Training Hyperparameters | ||
# "lr_schedulre_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial", | ||
# "constant", "constant_with_warmpup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt" | ||
lr_scheduler_type: inverse_sqrt | ||
learning_rate: 3e-5 # peak lr | ||
num_warmup_steps: 2000 | ||
weight_decay: 0.01 | ||
max_grad_norm: 1.0 | ||
|
||
# Others | ||
with_tracking: true | ||
report_to: tensorboard | ||
ignore_pad_token_for_loss: true | ||
preprocessing_num_workers: 32 | ||
use_slow_tokenizer: false | ||
overwrite_cache: false | ||
pad_to_max_length: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Experiment Config for each experiment | ||
output_dir: /output | ||
logging_dir: runs/tb_log | ||
logging_steps: 10 | ||
seed: 1115 | ||
train_file: csv/clotho/train.csv | ||
validation_file: csv/clotho/valid.csv | ||
encodec_base_path: /data/clotho/encodec | ||
clap_base_path: /data/clotho/clap | ||
tokenizer_name: facebook/bart-base | ||
config_name_or_path: facebook/bart-base | ||
model_name_or_path: facebook/bart-base | ||
eval_num_captions: 5 | ||
overwrite_output_dir: False | ||
|
||
# Basic Config | ||
encodec_masking_prob: 0.15 | ||
encodec_masking_span: 10 | ||
num_train_epochs: 15 | ||
max_train_steps: null | ||
gradient_accumulation_steps: 1 | ||
per_device_train_batch_size: 64 | ||
per_device_eval_batch_size: 64 | ||
split_batches: true | ||
checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps | ||
resume_from_checkpoint: null | ||
|
||
# Generation Config | ||
max_target_length: 128 | ||
val_max_target_length: 50 | ||
|
||
# Training Hyperparameters | ||
# "lr_schedulre_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial", | ||
# "constant", "constant_with_warmpup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt" | ||
lr_scheduler_type: inverse_sqrt | ||
learning_rate: 4e-5 # peak lr | ||
num_warmup_steps: 1000 | ||
weight_decay: 0.01 | ||
max_grad_norm: 1.0 | ||
|
||
# Others | ||
with_tracking: true | ||
report_to: tensorboard | ||
ignore_pad_token_for_loss: true | ||
preprocessing_num_workers: 32 | ||
use_slow_tokenizer: false | ||
overwrite_cache: false | ||
pad_to_max_length: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Experiment Config for each experiment | ||
output_dir: /output | ||
logging_dir: runs/tb_log | ||
logging_steps: 10 | ||
seed: 1115 | ||
train_file: csv/clotho/train.csv | ||
validation_file: csv/clotho/valid.csv | ||
encodec_base_path: /data/clotho/encodec | ||
clap_base_path: /data/clotho/clap | ||
tokenizer_name: facebook/bart-large | ||
config_name_or_path: facebook/bart-large | ||
model_name_or_path: facebook/bart-large | ||
eval_num_captions: 5 | ||
overwrite_output_dir: False | ||
|
||
# Basic Config | ||
encodec_masking_prob: 0.15 | ||
encodec_masking_span: 10 | ||
num_train_epochs: 15 | ||
max_train_steps: null | ||
gradient_accumulation_steps: 1 | ||
per_device_train_batch_size: 64 | ||
per_device_eval_batch_size: 64 | ||
split_batches: true | ||
checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps | ||
resume_from_checkpoint: null | ||
|
||
# Generation Config | ||
max_target_length: 128 | ||
val_max_target_length: 50 | ||
|
||
# Training Hyperparameters | ||
# "lr_schedulre_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial", | ||
# "constant", "constant_with_warmpup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt" | ||
lr_scheduler_type: inverse_sqrt | ||
learning_rate: 2.5e-5 # peak lr | ||
num_warmup_steps: 1000 | ||
weight_decay: 0.01 | ||
max_grad_norm: 1.0 | ||
|
||
# Others | ||
with_tracking: true | ||
report_to: tensorboard | ||
ignore_pad_token_for_loss: true | ||
preprocessing_num_workers: 32 | ||
use_slow_tokenizer: false | ||
overwrite_cache: false | ||
pad_to_max_length: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Experiment Config for each experiment | ||
output_dir: /output | ||
logging_dir: runs/tb_log | ||
logging_steps: 10 | ||
seed: 1115 | ||
train_file: csv/clotho/train.csv | ||
validation_file: csv/clotho/valid.csv | ||
encodec_base_path: /data/clotho/encodec | ||
clap_base_path: /data/clotho/clap | ||
tokenizer_name: facebook/bart-base | ||
config_name_or_path: facebook/bart-base | ||
model_name_or_path: /data/enclap_audiocaps | ||
eval_num_captions: 5 | ||
overwrite_output_dir: False | ||
|
||
# Basic Config | ||
encodec_masking_prob: 0.15 | ||
encodec_masking_span: 10 | ||
num_train_epochs: 15 | ||
max_train_steps: null | ||
gradient_accumulation_steps: 1 | ||
per_device_train_batch_size: 64 | ||
per_device_eval_batch_size: 64 | ||
split_batches: true | ||
checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps | ||
resume_from_checkpoint: null | ||
|
||
# Generation Config | ||
max_target_length: 128 | ||
val_max_target_length: 50 | ||
|
||
# Training Hyperparameters | ||
# "lr_schedulre_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial", | ||
# "constant", "constant_with_warmpup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt" | ||
lr_scheduler_type: inverse_sqrt | ||
learning_rate: 2e-5 # peak lr | ||
num_warmup_steps: 0 | ||
time_scale: 1000 | ||
weight_decay: 0.01 | ||
max_grad_norm: 1.0 | ||
|
||
# Others | ||
with_tracking: true | ||
report_to: tensorboard | ||
ignore_pad_token_for_loss: true | ||
preprocessing_num_workers: 32 | ||
use_slow_tokenizer: false | ||
overwrite_cache: false | ||
pad_to_max_length: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Experiment Config for each experiment | ||
output_dir: /output | ||
logging_dir: runs/tb_log | ||
logging_steps: 10 | ||
seed: 1115 | ||
train_file: csv/clotho/train.csv | ||
validation_file: csv/clotho/valid.csv | ||
encodec_base_path: /data/clotho/encodec | ||
clap_base_path: /data/clotho/clap | ||
tokenizer_name: facebook/bart-large | ||
config_name_or_path: facebook/bart-large | ||
model_name_or_path: /data/enclap_audiocaps | ||
eval_num_captions: 5 | ||
overwrite_output_dir: False | ||
|
||
# Basic Config | ||
encodec_masking_prob: 0.15 | ||
encodec_masking_span: 10 | ||
num_train_epochs: 15 | ||
max_train_steps: null | ||
gradient_accumulation_steps: 1 | ||
per_device_train_batch_size: 64 | ||
per_device_eval_batch_size: 64 | ||
split_batches: true | ||
checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps | ||
resume_from_checkpoint: null | ||
|
||
# Generation Config | ||
max_target_length: 128 | ||
val_max_target_length: 50 | ||
|
||
# Training Hyperparameters | ||
# "lr_schedulre_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial", | ||
# "constant", "constant_with_warmpup", "inverse_sqrt", "reduce_lr_on_plateau" | ||
lr_scheduler_type: inverse_sqrt | ||
learning_rate: 1.25e-5 # peak lr | ||
num_warmup_steps: 0 | ||
time_scale: 1000 | ||
weight_decay: 0.01 | ||
max_grad_norm: 1.0 | ||
|
||
# Others | ||
with_tracking: true | ||
report_to: tensorboard | ||
ignore_pad_token_for_loss: true | ||
preprocessing_num_workers: 32 | ||
use_slow_tokenizer: false | ||
overwrite_cache: false | ||
pad_to_max_length: false |
Oops, something went wrong.