diff --git a/README.md b/README.md new file mode 100644 index 0000000..160ccf9 --- /dev/null +++ b/README.md @@ -0,0 +1,198 @@ +
+

Grad-SVC based on Grad-TTS from HUAWEI Noah's Ark Lab

+ +This project is named Grad-SVC, or GVC for short. Its core technology is diffusion, but it differs considerably from other diffusion-based SVC models. The code is adapted from Grad-TTS and so-vits-svc-5.0, so features from so-vits-svc-5.0 are reused in this project. + +The project will be completed over the coming months ~~~ +
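The conversion pipeline conditions a Grad-TTS style diffusion decoder on HuBERT content vectors, a coarse pitch sequence and a speaker (timbre) embedding (the features prepared by the preprocessing steps below). The following is only a rough sketch of how the modules in `grad/` fit together, built from the interfaces in `grad/model.py`, `grad/encoder.py` and `grad/diffusion.py`; the random inputs, the shapes and the simple noise prior `z = mu + noise` are illustrative assumptions, not the actual `gvc_inference.py` logic.

```
import torch
from grad.model import GradTTS

# hyper-parameters as in configs/base.yaml
model = GradTTS(n_mels=100, n_vecs=256, n_pits=256, n_spks=256, n_embs=64,
                n_enc_channels=192, filter_channels=768,
                dec_dim=64, beta_min=0.05, beta_max=20.0, pe_scale=1000)

frames = 200                                  # number of mel frames (illustrative)
vec = torch.randn(1, 256, frames)             # hubert content vectors
pit = torch.randint(0, 256, (1, frames))      # coarse pitch ids (see f0_to_coarse)
spk = torch.randn(1, 256)                     # averaged timbre code of the singer
lengths = torch.LongTensor([frames])

with torch.no_grad():
    pit_emb = model.pit_emb(pit).transpose(1, 2)   # [1, 64, frames]
    spk_emb = model.spk_emb(spk)                   # [1, 64]
    # content + pitch + speaker -> mel prior mu
    mu, mask = model.encoder(lengths, vec, pit_emb, spk_emb)
    # reverse diffusion from noise around mu towards the converted mel
    z = mu + torch.randn_like(mu)
    mel = model.decoder(spk_emb, z, mask, mu, n_timesteps=50)  # [1, 100, frames]
```

The resulting mel is then rendered to a waveform by the NSF-BigVGAN vocoder (`gvc_inference_wave.py`).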
+ +## Setup Environment +1. Install project dependencies + + ```shell + pip install -r requirements.txt + ``` + +2. Download the Timbre Encoder: [Speaker-Encoder by @mueller91](https://drive.google.com/drive/folders/15oeBYf6Qn1edONkVLXe82MzdIi3O_9m3), put `best_model.pth.tar` into `speaker_pretrain/`. + +3. Download [hubert_soft model](https://github.com/bshall/hubert/releases/tag/v0.1),put `hubert-soft-0d54a1f4.pt` into `hubert_pretrain/`. + +4. Download pretrained [nsf_bigvgan_pretrain_32K.pth](https://github.com/PlayVoice/NSF-BigVGAN/releases/augment), and put it into `bigvgan_pretrain/`. + +5. Download pretrain model [gvc.pretrain.pth](), and put it into `grad_pretrain/`. + ```shell + python gvc_inference.py --config configs/base.yaml --model ./grad_pretrain/gvc.pretrain.pth --spk ./configs/singers/singer0001.npy --wave test.wav + ``` + +## Dataset preparation +Put the dataset into the `data_raw` directory following the structure below. +``` +data_raw +├───speaker0 +│ ├───000001.wav +│ ├───... +│ └───000xxx.wav +└───speaker1 + ├───000001.wav + ├───... + └───000xxx.wav +``` + +## Data preprocessing +After preprocessing you will get an output with following structure. +``` +data_gvc/ +└── waves-16k +│ └── speaker0 +│ │ ├── 000001.wav +│ │ └── 000xxx.wav +│ └── speaker1 +│ ├── 000001.wav +│ └── 000xxx.wav +└── waves-32k +│ └── speaker0 +│ │ ├── 000001.wav +│ │ └── 000xxx.wav +│ └── speaker1 +│ ├── 000001.wav +│ └── 000xxx.wav +└── mel +│ └── speaker0 +│ │ ├── 000001.mel.pt +│ │ └── 000xxx.mel.pt +│ └── speaker1 +│ ├── 000001.mel.pt +│ └── 000xxx.mel.pt +└── pitch +│ └── speaker0 +│ │ ├── 000001.pit.npy +│ │ └── 000xxx.pit.npy +│ └── speaker1 +│ ├── 000001.pit.npy +│ └── 000xxx.pit.npy +└── hubert +│ └── speaker0 +│ │ ├── 000001.vec.npy +│ │ └── 000xxx.vec.npy +│ └── speaker1 +│ ├── 000001.vec.npy +│ └── 000xxx.vec.npy +└── speaker +│ └── speaker0 +│ │ ├── 000001.spk.npy +│ │ └── 000xxx.spk.npy +│ └── speaker1 +│ ├── 000001.spk.npy +│ └── 000xxx.spk.npy +└── singer + ├── speaker0.spk.npy + └── speaker1.spk.npy +``` + +1. Re-sampling + - Generate audio with a sampling rate of 16000Hz in `./data_gvc/waves-16k` + ``` + python prepare/preprocess_a.py -w ./data_raw -o ./data_gvc/waves-16k -s 16000 + ``` + + - Generate audio with a sampling rate of 32000Hz in `./data_gvc/waves-32k` + ``` + python prepare/preprocess_a.py -w ./data_raw -o ./data_gvc/waves-32k -s 32000 + ``` +2. Use 16K audio to extract pitch + ``` + python prepare/preprocess_f0.py -w data_gvc/waves-16k/ -p data_gvc/pitch + ``` +3. use 32k audio to extract mel + ``` + python prepare/preprocess_spec.py -w data_gvc/waves-32k/ -s data_gvc/mel + ``` +4. Use 16K audio to extract hubert + ``` + python prepare/preprocess_hubert.py -w data_gvc/waves-16k/ -v data_gvc/hubert + ``` +5. Use 16k audio to extract timbre code + ``` + python prepare/preprocess_speaker.py data_gvc/waves-16k/ data_gvc/speaker + ``` +6. Extract the average value of the timbre code for inference + ``` + python prepare/preprocess_speaker_ave.py data_gvc/speaker/ data_gvc/singer + ``` +8. Use 32k audio to generate training index + ``` + python prepare/preprocess_train.py + ``` +9. Training file debugging + ``` + python prepare/preprocess_zzz.py + ``` + +## Train +1. Start training + ``` + python gvc_trainer.py + ``` +2. Resume training + ``` + python gvc_trainer.py -p logs/grad_svc/grad_svc_***.pth + ``` +3. 
Log visualization + ``` + tensorboard --logdir logs/ + ``` + +## Loss +![grad_svc_loss](./assets/grad_svc_loss.jpg) + +![grad_svc_mel](./assets/grad_svc_mel.jpg) + +## Inference + +1. Export inference model + ``` + python gvc_export.py --checkpoint_path logs/grad_svc/grad_svc_***.pt + ``` + +2. Inference + - if there is no need to adjust `f0`, just run the following command. + ``` + python gvc_inference.py --model gvc.pth --spk ./data_gvc/singer/your_singer.spk.npy --wave test.wav --shift 0 + ``` + - if `f0` will be adjusted manually, follow the steps: + + 1. use hubert to extract content vector + ``` + python hubert/inference.py -w test.wav -v test.vec.npy + ``` + 2. extract the F0 parameter to the csv text format + ``` + python pitch/inference.py -w test.wav -p test.csv + ``` + 3. final inference + ``` + python gvc_inference.py --model gvc.pth --spk ./data_gvc/singer/your_singer.spk.npy --wave test.wav --vec test.vec.npy --pit test.csv --shift 0 + ``` + +3. Convert mel to wave + ``` + python gvc_inference_wave.py --mel gvc_out.mel.pt --pit gvc_tmp.pit.csv + ``` + +## Code sources and references + +https://github.com/huawei-noah/Speech-Backbones/blob/main/Grad-TTS + +https://github.com/facebookresearch/speech-resynthesis [paper](https://arxiv.org/abs/2104.00355) + +https://github.com/jaywalnut310/vits [paper](https://arxiv.org/abs/2106.06103) + +https://github.com/NVIDIA/BigVGAN [paper](https://arxiv.org/abs/2206.04658) + +https://github.com/mindslab-ai/univnet [paper](https://arxiv.org/abs/2106.07889) + +https://github.com/mozilla/TTS + +https://github.com/bshall/soft-vc + +https://github.com/maxrmorrison/torchcrepe diff --git a/assets/grad_svc_loss.jpg b/assets/grad_svc_loss.jpg new file mode 100644 index 0000000..ceea76c Binary files /dev/null and b/assets/grad_svc_loss.jpg differ diff --git a/assets/grad_svc_mel.jpg b/assets/grad_svc_mel.jpg new file mode 100644 index 0000000..476a9a6 Binary files /dev/null and b/assets/grad_svc_mel.jpg differ diff --git a/bigvgan/LICENSE b/bigvgan/LICENSE new file mode 100644 index 0000000..328ed6e --- /dev/null +++ b/bigvgan/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 PlayVoice + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/bigvgan/README.md b/bigvgan/README.md new file mode 100644 index 0000000..7816c1e --- /dev/null +++ b/bigvgan/README.md @@ -0,0 +1,138 @@ +
+

Neural Source-Filter BigVGAN

+ Just For Fun +
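This vocoder combines BigVGAN with a neural source-filter (NSF) excitation: besides the mel spectrogram, the generator takes a frame-level F0 contour, builds a harmonic source signal from it (`SourceModuleHnNSF`), and injects that source at every upsampling stage. Below is a minimal inference sketch based on the `Generator` interface in `bigvgan/model/generator.py`; the OmegaConf loader and the random inputs are illustrative assumptions, and the released `nsf_bigvgan_inference.py` script remains the supported entry point.

```
import torch
from omegaconf import OmegaConf
from bigvgan.model.generator import Generator

hp = OmegaConf.load("configs/nsf_bigvgan.yaml")   # hp.gen / hp.audio as below
model = Generator(hp)
model.eval(inference=True)                        # strips weight norm for inference

mel = torch.randn(1, 100, 200)                    # [B, n_mel_channels, frames]
f0 = torch.full((1, 200), 220.0)                  # frame-level F0 in Hz, 0 = unvoiced
audio = model.inference(mel, f0)                  # int16 waveform, 320 samples per frame
```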
+ +![nsf_bigvgan_mel](https://github.com/PlayVoice/NSF-BigVGAN/assets/16432329/eebb8dca-a8d3-4e69-b02c-632a3a1cdd6a) + +## Dataset preparation + +Put the dataset into the data_raw directory according to the following file structure +```shell +data_raw +├───speaker0 +│ ├───000001.wav +│ ├───... +│ └───000xxx.wav +└───speaker1 + ├───000001.wav + ├───... + └───000xxx.wav +``` + +## Install dependencies + +- 1 software dependency + + > pip install -r requirements.txt + +- 2 download [release](https://github.com/PlayVoice/NSF-BigVGAN/releases/tag/debug) model, and test + + > python nsf_bigvgan_inference.py --config configs/nsf_bigvgan.yaml --model nsf_bigvgan_g.pth --wave test.wav + +## Data preprocessing + +- 1, re-sampling: 32kHz + + > python prepare/preprocess_a.py -w ./data_raw -o ./data_bigvgan/waves-32k + +- 3, extract pitch + + > python prepare/preprocess_f0.py -w data_bigvgan/waves-32k/ -p data_bigvgan/pitch + +- 4, extract mel: [100, length] + + > python prepare/preprocess_spec.py -w data_bigvgan/waves-32k/ -s data_bigvgan/mel + +- 5, generate training index + + > python prepare/preprocess_train.py + +```shell +data_bigvgan/ +│ +└── waves-32k +│ └── speaker0 +│ │ ├── 000001.wav +│ │ └── 000xxx.wav +│ └── speaker1 +│ ├── 000001.wav +│ └── 000xxx.wav +└── pitch +│ └── speaker0 +│ │ ├── 000001.pit.npy +│ │ └── 000xxx.pit.npy +│ └── speaker1 +│ ├── 000001.pit.npy +│ └── 000xxx.pit.npy +└── mel + └── speaker0 + │ ├── 000001.mel.pt + │ └── 000xxx.mel.pt + └── speaker1 + ├── 000001.mel.pt + └── 000xxx.mel.pt + +``` + +## Train + +- 1, start training + + > python nsf_bigvgan_trainer.py -c configs/nsf_bigvgan.yaml -n nsf_bigvgan + +- 2, resume training + + > python nsf_bigvgan_trainer.py -c configs/nsf_bigvgan.yaml -n nsf_bigvgan -p chkpt/nsf_bigvgan/***.pth + +- 3, view log + + > tensorboard --logdir logs/ + + +## Inference + +- 1, export inference model + + > python nsf_bigvgan_export.py --config configs/maxgan.yaml --checkpoint_path chkpt/nsf_bigvgan/***.pt + +- 2, extract mel + + > python spec/inference.py -w test.wav -m test.mel.pt + +- 3, extract F0 + + > python pitch/inference.py -w test.wav -p test.csv + +- 4, infer + + > python nsf_bigvgan_inference.py --config configs/nsf_bigvgan.yaml --model nsf_bigvgan_g.pth --wave test.wav + + or + + > python nsf_bigvgan_inference.py --config configs/nsf_bigvgan.yaml --model nsf_bigvgan_g.pth --mel test.mel.pt --pit test.csv + +## Augmentation of mel +For the over smooth output of acoustic model, we use gaussian blur for mel when train vocoder +``` +# gaussian blur +model_b = get_gaussian_kernel(kernel_size=5, sigma=2, channels=1).to(device) +# mel blur +mel_b = mel[:, None, :, :] +mel_b = model_b(mel_b) +mel_b = torch.squeeze(mel_b, 1) +mel_r = torch.rand(1).to(device) * 0.5 +mel_b = (1 - mel_r) * mel_b + mel_r * mel +# generator +optim_g.zero_grad() +fake_audio = model_g(mel_b, pit) +``` +![mel_gaussian_blur](https://github.com/PlayVoice/NSF-BigVGAN/assets/16432329/7fa96ef7-5e3b-4ae6-bc61-9b6da3b9d0b9) + +## Source of code and References + +https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts/tree/master/project/01-nsf + +https://github.com/mindslab-ai/univnet [[paper]](https://arxiv.org/abs/2106.07889) + +https://github.com/NVIDIA/BigVGAN [[paper]](https://arxiv.org/abs/2206.04658) \ No newline at end of file diff --git a/bigvgan/configs/nsf_bigvgan.yaml b/bigvgan/configs/nsf_bigvgan.yaml new file mode 100644 index 0000000..809b129 --- /dev/null +++ b/bigvgan/configs/nsf_bigvgan.yaml @@ -0,0 +1,60 @@ +data: + train_file: 
'files/train.txt' + val_file: 'files/valid.txt' +############################# +train: + num_workers: 4 + batch_size: 8 + optimizer: 'adam' + seed: 1234 + adam: + lr: 0.0002 + beta1: 0.8 + beta2: 0.99 + mel_lamb: 5 + stft_lamb: 2.5 + pretrain: '' + lora: False +############################# +audio: + n_mel_channels: 100 + segment_length: 12800 # Should be multiple of 320 + filter_length: 1024 + hop_length: 320 # WARNING: this can't be changed. + win_length: 1024 + sampling_rate: 32000 + mel_fmin: 40.0 + mel_fmax: 16000.0 +############################# +gen: + mel_channels: 100 + upsample_rates: [5,4,2,2,2,2] + upsample_kernel_sizes: [15,8,4,4,4,4] + upsample_initial_channel: 320 + resblock_kernel_sizes: [3,7,11] + resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] +############################# +mpd: + periods: [2,3,5,7,11] + kernel_size: 5 + stride: 3 + use_spectral_norm: False + lReLU_slope: 0.2 +############################# +mrd: + resolutions: "[(1024, 120, 600), (2048, 240, 1200), (4096, 480, 2400), (512, 50, 240)]" # (filter_length, hop_length, win_length) + use_spectral_norm: False + lReLU_slope: 0.2 +############################# +dist_config: + dist_backend: "nccl" + dist_url: "tcp://localhost:54321" + world_size: 1 +############################# +log: + info_interval: 100 + eval_interval: 1000 + save_interval: 10000 + num_audio: 6 + pth_dir: 'chkpt' + log_dir: 'logs' diff --git a/bigvgan/model/__init__.py b/bigvgan/model/__init__.py new file mode 100644 index 0000000..986a0cf --- /dev/null +++ b/bigvgan/model/__init__.py @@ -0,0 +1 @@ +from .alias.act import SnakeAlias \ No newline at end of file diff --git a/bigvgan/model/alias/__init__.py b/bigvgan/model/alias/__init__.py new file mode 100644 index 0000000..a2318b6 --- /dev/null +++ b/bigvgan/model/alias/__init__.py @@ -0,0 +1,6 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +from .filter import * +from .resample import * +from .act import * \ No newline at end of file diff --git a/bigvgan/model/alias/act.py b/bigvgan/model/alias/act.py new file mode 100644 index 0000000..308344f --- /dev/null +++ b/bigvgan/model/alias/act.py @@ -0,0 +1,129 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. 
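# This module collects the anti-aliased activations used by the BigVGAN generator:
# Activation1d wraps an arbitrary activation between band-limited up/down-sampling,
# SnakeBeta is the periodic Snake activation with learnable frequency (alpha) and
# magnitude (beta), Mish is a standard smooth activation, and SnakeAlias applies
# SnakeBeta between the same up/down-sampling as Activation1d.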
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch import sin, pow +from torch.nn import Parameter +from .resample import UpSample1d, DownSample1d + + +class Activation1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta = x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze( + 0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + return x + + +class Mish(nn.Module): + """ + Mish activation function is proposed in "Mish: A Self + Regularized Non-Monotonic Neural Activation Function" + paper, https://arxiv.org/abs/1908.08681. 
+ """ + + def __init__(self): + super().__init__() + + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + +class SnakeAlias(nn.Module): + def __init__(self, + channels, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = SnakeBeta(channels, alpha_logscale=True) + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x \ No newline at end of file diff --git a/bigvgan/model/alias/filter.py b/bigvgan/model/alias/filter.py new file mode 100644 index 0000000..7ad6ea8 --- /dev/null +++ b/bigvgan/model/alias/filter.py @@ -0,0 +1,95 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +if 'sinc' in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. 
+ super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + #input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), + stride=self.stride, groups=C) + + return out \ No newline at end of file diff --git a/bigvgan/model/alias/resample.py b/bigvgan/model/alias/resample.py new file mode 100644 index 0000000..750e6c3 --- /dev/null +++ b/bigvgan/model/alias/resample.py @@ -0,0 +1,49 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn +from torch.nn import functional as F +from .filter import LowPassFilter1d +from .filter import kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + x = x[..., self.pad_left:-self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size) + + def forward(self, x): + xx = self.lowpass(x) + + return xx \ No newline at end of file diff --git a/bigvgan/model/bigv.py b/bigvgan/model/bigv.py new file mode 100644 index 0000000..029362c --- /dev/null +++ b/bigvgan/model/bigv.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn + +from torch.nn import Conv1d +from torch.nn.utils import weight_norm, remove_weight_norm +from .alias.act import SnakeAlias + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +class AMPBlock(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(AMPBlock, self).__init__() + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + 
padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + # total number of conv layers + self.num_layers = len(self.convs1) + len(self.convs2) + + # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + SnakeAlias(channels) for _ in range(self.num_layers) + ]) + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) \ No newline at end of file diff --git a/bigvgan/model/generator.py b/bigvgan/model/generator.py new file mode 100644 index 0000000..3406c32 --- /dev/null +++ b/bigvgan/model/generator.py @@ -0,0 +1,143 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from torch.nn import Conv1d +from torch.nn import ConvTranspose1d +from torch.nn.utils import weight_norm +from torch.nn.utils import remove_weight_norm + +from .nsf import SourceModuleHnNSF +from .bigv import init_weights, AMPBlock, SnakeAlias + + +class Generator(torch.nn.Module): + # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks. + def __init__(self, hp): + super(Generator, self).__init__() + self.hp = hp + self.num_kernels = len(hp.gen.resblock_kernel_sizes) + self.num_upsamples = len(hp.gen.upsample_rates) + # pre conv + self.conv_pre = nn.utils.weight_norm( + Conv1d(hp.gen.mel_channels, hp.gen.upsample_initial_channel, 7, 1, padding=3)) + # nsf + self.f0_upsamp = torch.nn.Upsample( + scale_factor=np.prod(hp.gen.upsample_rates)) + self.m_source = SourceModuleHnNSF(sampling_rate=hp.audio.sampling_rate) + self.noise_convs = nn.ModuleList() + # transposed conv-based upsamplers. 
does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(hp.gen.upsample_rates, hp.gen.upsample_kernel_sizes)): + # print(f'ups: {i} {k}, {u}, {(k - u) // 2}') + # base + self.ups.append( + weight_norm( + ConvTranspose1d( + hp.gen.upsample_initial_channel // (2 ** i), + hp.gen.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2) + ) + ) + # nsf + if i + 1 < len(hp.gen.upsample_rates): + stride_f0 = np.prod(hp.gen.upsample_rates[i + 1:]) + stride_f0 = int(stride_f0) + self.noise_convs.append( + Conv1d( + 1, + hp.gen.upsample_initial_channel // (2 ** (i + 1)), + kernel_size=stride_f0 * 2, + stride=stride_f0, + padding=stride_f0 // 2, + ) + ) + else: + self.noise_convs.append( + Conv1d(1, hp.gen.upsample_initial_channel // + (2 ** (i + 1)), kernel_size=1) + ) + + # residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = hp.gen.upsample_initial_channel // (2 ** (i + 1)) + for k, d in zip(hp.gen.resblock_kernel_sizes, hp.gen.resblock_dilation_sizes): + self.resblocks.append(AMPBlock(ch, k, d)) + + # post conv + self.activation_post = SnakeAlias(ch) + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + # weight initialization + self.ups.apply(init_weights) + + def forward(self, x, f0, train=True): + # nsf + f0 = f0[:, None] + f0 = self.f0_upsamp(f0).transpose(1, 2) + har_source = self.m_source(f0) + har_source = har_source.transpose(1, 2) + # pre conv + if train: + x = x + torch.randn_like(x) * 0.1 # Perturbation + x = self.conv_pre(x) + x = x * torch.tanh(F.softplus(x)) + + for i in range(self.num_upsamples): + # upsampling + x = self.ups[i](x) + # nsf + x_source = self.noise_convs[i](har_source) + x = x + x_source + # AMP blocks + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + # post conv + x = self.activation_post(x) + x = self.conv_post(x) + x = torch.tanh(x) + return x + + def remove_weight_norm(self): + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + + def eval(self, inference=False): + super(Generator, self).eval() + # don't remove weight norm while validation in training loop + if inference: + self.remove_weight_norm() + + def inference(self, mel, f0): + MAX_WAV_VALUE = 32768.0 + audio = self.forward(mel, f0, False) + audio = audio.squeeze() # collapse all dimension except time axis + audio = MAX_WAV_VALUE * audio + audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1) + audio = audio.short() + return audio + + def pitch2wav(self, f0): + MAX_WAV_VALUE = 32768.0 + # nsf + f0 = f0[:, None] + f0 = self.f0_upsamp(f0).transpose(1, 2) + har_source = self.m_source(f0) + audio = har_source.transpose(1, 2) + audio = audio.squeeze() # collapse all dimension except time axis + audio = MAX_WAV_VALUE * audio + audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1) + audio = audio.short() + return audio diff --git a/bigvgan/model/nsf.py b/bigvgan/model/nsf.py new file mode 100644 index 0000000..1e9e6c7 --- /dev/null +++ b/bigvgan/model/nsf.py @@ -0,0 +1,394 @@ +import torch +import numpy as np +import sys +import torch.nn.functional as torch_nn_func + + +class PulseGen(torch.nn.Module): + """Definition of Pulse train generator + + There are many ways to implement pulse generator. 
+ Here, PulseGen is based on SinGen. For a perfect + """ + + def __init__(self, samp_rate, pulse_amp=0.1, noise_std=0.003, voiced_threshold=0): + super(PulseGen, self).__init__() + self.pulse_amp = pulse_amp + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.noise_std = noise_std + self.l_sinegen = SineGen( + self.sampling_rate, + harmonic_num=0, + sine_amp=self.pulse_amp, + noise_std=0, + voiced_threshold=self.voiced_threshold, + flag_for_pulse=True, + ) + + def forward(self, f0): + """Pulse train generator + pulse_train, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output pulse_train: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + + Note: self.l_sine doesn't make sure that the initial phase of + a voiced segment is np.pi, the first pulse in a voiced segment + may not be at the first time step within a voiced segment + """ + with torch.no_grad(): + sine_wav, uv, noise = self.l_sinegen(f0) + + # sine without additive noise + pure_sine = sine_wav - noise + + # step t corresponds to a pulse if + # sine[t] > sine[t+1] & sine[t] > sine[t-1] + # & sine[t-1], sine[t+1], and sine[t] are voiced + # or + # sine[t] is voiced, sine[t-1] is unvoiced + # we use torch.roll to simulate sine[t+1] and sine[t-1] + sine_1 = torch.roll(pure_sine, shifts=1, dims=1) + uv_1 = torch.roll(uv, shifts=1, dims=1) + uv_1[:, 0, :] = 0 + sine_2 = torch.roll(pure_sine, shifts=-1, dims=1) + uv_2 = torch.roll(uv, shifts=-1, dims=1) + uv_2[:, -1, :] = 0 + + loc = (pure_sine > sine_1) * (pure_sine > sine_2) \ + * (uv_1 > 0) * (uv_2 > 0) * (uv > 0) \ + + (uv_1 < 1) * (uv > 0) + + # pulse train without noise + pulse_train = pure_sine * loc + + # additive noise to pulse train + # note that noise from sinegen is zero in voiced regions + pulse_noise = torch.randn_like(pure_sine) * self.noise_std + + # with additive noise on pulse, and unvoiced regions + pulse_train += pulse_noise * loc + pulse_noise * (1 - uv) + return pulse_train, sine_wav, uv, pulse_noise + + +class SignalsConv1d(torch.nn.Module): + """Filtering input signal with time invariant filter + Note: FIRFilter conducted filtering given fixed FIR weight + SignalsConv1d convolves two signals + Note: this is based on torch.nn.functional.conv1d + + """ + + def __init__(self): + super(SignalsConv1d, self).__init__() + + def forward(self, signal, system_ir): + """output = forward(signal, system_ir) + + signal: (batchsize, length1, dim) + system_ir: (length2, dim) + + output: (batchsize, length1, dim) + """ + if signal.shape[-1] != system_ir.shape[-1]: + print("Error: SignalsConv1d expects shape:") + print("signal (batchsize, length1, dim)") + print("system_id (batchsize, length2, dim)") + print("But received signal: {:s}".format(str(signal.shape))) + print(" system_ir: {:s}".format(str(system_ir.shape))) + sys.exit(1) + padding_length = system_ir.shape[0] - 1 + groups = signal.shape[-1] + + # pad signal on the left + signal_pad = torch_nn_func.pad(signal.permute(0, 2, 1), (padding_length, 0)) + # prepare system impulse response as (dim, 1, length2) + # also flip the impulse response + ir = torch.flip(system_ir.unsqueeze(1).permute(2, 1, 0), dims=[2]) + # convolute + output = torch_nn_func.conv1d(signal_pad, ir, groups=groups) + return output.permute(0, 2, 1) + + +class CyclicNoiseGen_v1(torch.nn.Module): + """CyclicnoiseGen_v1 + Cyclic noise with a single parameter of beta. 
+ Pytorch v1 implementation assumes f_t is also fixed + """ + + def __init__(self, samp_rate, noise_std=0.003, voiced_threshold=0): + super(CyclicNoiseGen_v1, self).__init__() + self.samp_rate = samp_rate + self.noise_std = noise_std + self.voiced_threshold = voiced_threshold + + self.l_pulse = PulseGen( + samp_rate, + pulse_amp=1.0, + noise_std=noise_std, + voiced_threshold=voiced_threshold, + ) + self.l_conv = SignalsConv1d() + + def noise_decay(self, beta, f0mean): + """decayed_noise = noise_decay(beta, f0mean) + decayed_noise = n[t]exp(-t * f_mean / beta / samp_rate) + + beta: (dim=1) or (batchsize=1, 1, dim=1) + f0mean (batchsize=1, 1, dim=1) + + decayed_noise (batchsize=1, length, dim=1) + """ + with torch.no_grad(): + # exp(-1.0 n / T) < 0.01 => n > -log(0.01)*T = 4.60*T + # truncate the noise when decayed by -40 dB + length = 4.6 * self.samp_rate / f0mean + length = length.int() + time_idx = torch.arange(0, length, device=beta.device) + time_idx = time_idx.unsqueeze(0).unsqueeze(2) + time_idx = time_idx.repeat(beta.shape[0], 1, beta.shape[2]) + + noise = torch.randn(time_idx.shape, device=beta.device) + + # due to Pytorch implementation, use f0_mean as the f0 factor + decay = torch.exp(-time_idx * f0mean / beta / self.samp_rate) + return noise * self.noise_std * decay + + def forward(self, f0s, beta): + """Producde cyclic-noise""" + # pulse train + pulse_train, sine_wav, uv, noise = self.l_pulse(f0s) + pure_pulse = pulse_train - noise + + # decayed_noise (length, dim=1) + if (uv < 1).all(): + # all unvoiced + cyc_noise = torch.zeros_like(sine_wav) + else: + f0mean = f0s[uv > 0].mean() + + decayed_noise = self.noise_decay(beta, f0mean)[0, :, :] + # convolute + cyc_noise = self.l_conv(pure_pulse, decayed_noise) + + # add noise in invoiced segments + cyc_noise = cyc_noise + noise * (1.0 - uv) + return cyc_noise, pulse_train, sine_wav, uv, noise + + +class SineGen(torch.nn.Module): + """Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__( + self, + samp_rate, + harmonic_num=0, + sine_amp=0.1, + noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False, + ): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + + def _f02uv(self, f0): + # generate uv signal + uv = torch.ones_like(f0) + uv = uv * (f0 > self.voiced_threshold) + return uv + + def _f02sine(self, f0_values): + """f0_values: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + # convert to F0 in rad. 
The interger part n can be ignored + # because 2 * np.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand( + f0_values.shape[0], f0_values.shape[2], device=f0_values.device + ) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: + # for normal case + + # To prevent torch.cumsum numerical overflow, + # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1. + # Buffer tmp_over_one_idx indicates the time step to add -1. + # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi + tmp_over_one = torch.cumsum(rad_values, 1) % 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + + sines = torch.sin( + torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi + ) + else: + # If necessary, make sure that the first time step of every + # voiced segments is sin(pi) or cos(0) + # This is used for pulse-train generation + + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + # get the instantanouse phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segments + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment. + i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + + # get the sines + sines = torch.cos(i_phase * 2 * np.pi) + return sines + + def forward(self, f0): + """sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + with torch.no_grad(): + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) + # fundamental component + f0_buf[:, :, 0] = f0[:, :, 0] + for idx in np.arange(self.harmonic_num): + # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic + f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2) + + # generate sine waveforms + sine_waves = self._f02sine(f0_buf) * self.sine_amp + + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . 
for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves + + +class SourceModuleCycNoise_v1(torch.nn.Module): + """SourceModuleCycNoise_v1 + SourceModule(sampling_rate, noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + + noise_std: std of Gaussian noise (default: 0.003) + voiced_threshold: threshold to set U/V given F0 (default: 0) + + cyc, noise, uv = SourceModuleCycNoise_v1(F0_upsampled, beta) + F0_upsampled (batchsize, length, 1) + beta (1) + cyc (batchsize, length, 1) + noise (batchsize, length, 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, noise_std=0.003, voiced_threshod=0): + super(SourceModuleCycNoise_v1, self).__init__() + self.sampling_rate = sampling_rate + self.noise_std = noise_std + self.l_cyc_gen = CyclicNoiseGen_v1(sampling_rate, noise_std, voiced_threshod) + + def forward(self, f0_upsamped, beta): + """ + cyc, noise, uv = SourceModuleCycNoise_v1(F0, beta) + F0_upsampled (batchsize, length, 1) + beta (1) + cyc (batchsize, length, 1) + noise (batchsize, length, 1) + uv (batchsize, length, 1) + """ + # source for harmonic branch + cyc, pulse, sine, uv, add_noi = self.l_cyc_gen(f0_upsamped, beta) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.noise_std / 3 + return cyc, noise, uv + + +class SourceModuleHnNSF(torch.nn.Module): + def __init__( + self, + sampling_rate=32000, + sine_amp=0.1, + add_noise_std=0.003, + voiced_threshod=0, + ): + super(SourceModuleHnNSF, self).__init__() + harmonic_num = 10 + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen( + sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod + ) + + # to merge source harmonics into a single excitation + self.l_tanh = torch.nn.Tanh() + self.register_buffer('merge_w', torch.FloatTensor([[ + 0.2942, -0.2243, 0.0033, -0.0056, -0.0020, -0.0046, + 0.0221, -0.0083, -0.0241, -0.0036, -0.0581]])) + self.register_buffer('merge_b', torch.FloatTensor([0.0008])) + + def forward(self, x): + """ + Sine_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + """ + # source for harmonic branch + sine_wavs = self.l_sin_gen(x) + sine_wavs = torch_nn_func.linear( + sine_wavs, self.merge_w) + self.merge_b + sine_merge = self.l_tanh(sine_wavs) + return sine_merge diff --git a/bigvgan_pretrain/README.md b/bigvgan_pretrain/README.md new file mode 100644 index 0000000..0bf434e --- /dev/null +++ b/bigvgan_pretrain/README.md @@ -0,0 +1,5 @@ +Path for: + + nsf_bigvgan_pretrain_32K.pth + + DownLoad link:https://github.com/PlayVoice/NSF-BigVGAN/releases/tag/augment diff --git a/configs/base.yaml b/configs/base.yaml new file mode 100644 index 0000000..a7e52cc --- /dev/null +++ b/configs/base.yaml @@ -0,0 +1,40 @@ +train: + seed: 37 + train_files: "files/train.txt" + valid_files: "files/valid.txt" + log_dir: 'logs/grad_svc' + n_epochs: 10000 + learning_rate: 1e-4 + batch_size: 16 + test_size: 4 + test_step: 1 + save_step: 1 + pretrain: "" +############################# +data: + segment_size: 16000 # WARNING: base on hop_length + max_wav_value: 32768.0 + sampling_rate: 32000 + filter_length: 1024 + hop_length: 320 + win_length: 1024 + mel_channels: 100 + mel_fmin: 40.0 + mel_fmax: 
16000.0 +############################# +grad: + n_mels: 100 + n_vecs: 256 + n_pits: 256 + n_spks: 256 + n_embs: 64 + + # encoder parameters + n_enc_channels: 192 + filter_channels: 768 + + # decoder parameters + dec_dim: 64 + beta_min: 0.05 + beta_max: 20.0 + pe_scale: 1000 diff --git a/grad/LICENSE b/grad/LICENSE new file mode 100644 index 0000000..e1c1351 --- /dev/null +++ b/grad/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2021 Huawei Technologies Co., Ltd. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/grad/__init__.py b/grad/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/grad/base.py b/grad/base.py new file mode 100644 index 0000000..7294dcb --- /dev/null +++ b/grad/base.py @@ -0,0 +1,29 @@ +import numpy as np +import torch + + +class BaseModule(torch.nn.Module): + def __init__(self): + super(BaseModule, self).__init__() + + @property + def nparams(self): + """ + Returns number of trainable parameters of the module. + """ + num_params = 0 + for name, param in self.named_parameters(): + if param.requires_grad: + num_params += np.prod(param.detach().cpu().numpy().shape) + return num_params + + + def relocate_input(self, x: list): + """ + Relocates provided tensors to the same device set for the module. 
+ """ + device = next(self.parameters()).device + for i in range(len(x)): + if isinstance(x[i], torch.Tensor) and x[i].device != device: + x[i] = x[i].to(device) + return x diff --git a/grad/diffusion.py b/grad/diffusion.py new file mode 100644 index 0000000..3462999 --- /dev/null +++ b/grad/diffusion.py @@ -0,0 +1,273 @@ +import math +import torch +from einops import rearrange +from grad.base import BaseModule + + +class Mish(BaseModule): + def forward(self, x): + return x * torch.tanh(torch.nn.functional.softplus(x)) + + +class Upsample(BaseModule): + def __init__(self, dim): + super(Upsample, self).__init__() + self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Downsample(BaseModule): + def __init__(self, dim): + super(Downsample, self).__init__() + self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Rezero(BaseModule): + def __init__(self, fn): + super(Rezero, self).__init__() + self.fn = fn + self.g = torch.nn.Parameter(torch.zeros(1)) + + def forward(self, x): + return self.fn(x) * self.g + + +class Block(BaseModule): + def __init__(self, dim, dim_out, groups=8): + super(Block, self).__init__() + self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3, + padding=1), torch.nn.GroupNorm( + groups, dim_out), Mish()) + + def forward(self, x, mask): + output = self.block(x * mask) + return output * mask + + +class ResnetBlock(BaseModule): + def __init__(self, dim, dim_out, time_emb_dim, groups=8): + super(ResnetBlock, self).__init__() + self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, + dim_out)) + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + if dim != dim_out: + self.res_conv = torch.nn.Conv2d(dim, dim_out, 1) + else: + self.res_conv = torch.nn.Identity() + + def forward(self, x, mask, time_emb): + h = self.block1(x, mask) + h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1) + h = self.block2(h, mask) + output = h + self.res_conv(x * mask) + return output + + +class LinearAttention(BaseModule): + def __init__(self, dim, heads=4, dim_head=32): + super(LinearAttention, self).__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', + heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', + heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class Residual(BaseModule): + def __init__(self, fn): + super(Residual, self).__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + output = self.fn(x, *args, **kwargs) + x + return output + + +class SinusoidalPosEmb(BaseModule): + def __init__(self, dim): + super(SinusoidalPosEmb, self).__init__() + self.dim = dim + + def forward(self, x, scale=1000): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class GradLogPEstimator2d(BaseModule): + def __init__(self, dim, 
dim_mults=(1, 2, 4), emb_dim=64, n_mels=100, + groups=8, pe_scale=1000): + super(GradLogPEstimator2d, self).__init__() + self.dim = dim + self.dim_mults = dim_mults + self.emb_dim = emb_dim + self.groups = groups + self.pe_scale = pe_scale + + self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, emb_dim * 4), Mish(), + torch.nn.Linear(emb_dim * 4, n_mels)) + self.time_pos_emb = SinusoidalPosEmb(dim) + self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), + torch.nn.Linear(dim * 4, dim)) + + dims = [2 + 1, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + self.downs = torch.nn.ModuleList([]) + self.ups = torch.nn.ModuleList([]) + num_resolutions = len(in_out) + + for ind, (dim_in, dim_out) in enumerate(in_out): # 2 downs + is_last = ind >= (num_resolutions - 1) + self.downs.append(torch.nn.ModuleList([ + ResnetBlock(dim_in, dim_out, time_emb_dim=dim), + ResnetBlock(dim_out, dim_out, time_emb_dim=dim), + Residual(Rezero(LinearAttention(dim_out))), + Downsample(dim_out) if not is_last else torch.nn.Identity()])) + + mid_dim = dims[-1] + self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) + self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) + self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): # 2 ups + self.ups.append(torch.nn.ModuleList([ + ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim), + ResnetBlock(dim_in, dim_in, time_emb_dim=dim), + Residual(Rezero(LinearAttention(dim_in))), + Upsample(dim_in)])) + self.final_block = Block(dim, dim) + self.final_conv = torch.nn.Conv2d(dim, 1, 1) + + def forward(self, spk, x, mask, mu, t): + s = self.spk_mlp(spk) + + t = self.time_pos_emb(t, scale=self.pe_scale) + t = self.mlp(t) + + s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1]) + x = torch.stack([mu, x, s], 1) + mask = mask.unsqueeze(1) + + hiddens = [] + masks = [mask] + for resnet1, resnet2, attn, downsample in self.downs: + mask_down = masks[-1] + x = resnet1(x, mask_down, t) + x = resnet2(x, mask_down, t) + x = attn(x) + hiddens.append(x) + x = downsample(x * mask_down) + masks.append(mask_down[:, :, :, ::2]) + + masks = masks[:-1] + mask_mid = masks[-1] + x = self.mid_block1(x, mask_mid, t) + x = self.mid_attn(x) + x = self.mid_block2(x, mask_mid, t) + + for resnet1, resnet2, attn, upsample in self.ups: + mask_up = masks.pop() + x = torch.cat((x, hiddens.pop()), dim=1) + x = resnet1(x, mask_up, t) + x = resnet2(x, mask_up, t) + x = attn(x) + x = upsample(x * mask_up) + + x = self.final_block(x, mask) + output = self.final_conv(x * mask) + + return (output * mask).squeeze(1) + + +def get_noise(t, beta_init, beta_term, cumulative=False): + if cumulative: + noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2) + else: + noise = beta_init + (beta_term - beta_init)*t + return noise + + +class Diffusion(BaseModule): + def __init__(self, n_mels, dim, emb_dim=64, + beta_min=0.05, beta_max=20, pe_scale=1000): + super(Diffusion, self).__init__() + self.n_mels = n_mels + self.beta_min = beta_min + self.beta_max = beta_max + self.estimator = GradLogPEstimator2d(dim, + n_mels=n_mels, + emb_dim=emb_dim, + pe_scale=pe_scale) + + def forward_diffusion(self, mel, mask, mu, t): + time = t.unsqueeze(-1).unsqueeze(-1) + cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True) + mean = mel*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise)) + variance = 1.0 - torch.exp(-cum_noise) + z = torch.randn(mel.shape, dtype=mel.dtype, 
device=mel.device, + requires_grad=False) + xt = mean + z * torch.sqrt(variance) + return xt * mask, z * mask + + @torch.no_grad() + def reverse_diffusion(self, spk, z, mask, mu, n_timesteps, stoc=False): + h = 1.0 / n_timesteps + xt = z * mask + for i in range(n_timesteps): + t = (1.0 - (i + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype, + device=z.device) + time = t.unsqueeze(-1).unsqueeze(-1) + noise_t = get_noise(time, self.beta_min, self.beta_max, + cumulative=False) + if stoc: # adds stochastic term + dxt_det = 0.5 * (mu - xt) - self.estimator(spk, xt, mask, mu, t) + dxt_det = dxt_det * noise_t * h + dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device, + requires_grad=False) + dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h) + dxt = dxt_det + dxt_stoc + else: + dxt = 0.5 * (mu - xt - self.estimator(spk, xt, mask, mu, t)) + dxt = dxt * noise_t * h + xt = (xt - dxt) * mask + return xt + + @torch.no_grad() + def forward(self, spk, z, mask, mu, n_timesteps, stoc=False): + return self.reverse_diffusion(spk, z, mask, mu, n_timesteps, stoc) + + def loss_t(self, spk, mel, mask, mu, t): + xt, z = self.forward_diffusion(mel, mask, mu, t) + time = t.unsqueeze(-1).unsqueeze(-1) + cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True) + noise_estimation = self.estimator(spk, xt, mask, mu, t) + noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise)) + loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.n_mels) + return loss, xt + + def compute_loss(self, spk, mel, mask, mu, offset=1e-5): + t = torch.rand(mel.shape[0], dtype=mel.dtype, device=mel.device, requires_grad=False) + t = torch.clamp(t, offset, 1.0 - offset) + return self.loss_t(spk, mel, mask, mu, t) diff --git a/grad/encoder.py b/grad/encoder.py new file mode 100644 index 0000000..cb1e118 --- /dev/null +++ b/grad/encoder.py @@ -0,0 +1,300 @@ +import math +import torch + +from grad.base import BaseModule +from grad.utils import sequence_mask, convert_pad_shape + + +class LayerNorm(BaseModule): + def __init__(self, channels, eps=1e-4): + super(LayerNorm, self).__init__() + self.channels = channels + self.eps = eps + + self.gamma = torch.nn.Parameter(torch.ones(channels)) + self.beta = torch.nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + n_dims = len(x.shape) + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean)**2, 1, keepdim=True) + + x = (x - mean) * torch.rsqrt(variance + self.eps) + + shape = [1, -1] + [1] * (n_dims - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ConvReluNorm(BaseModule): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, + n_layers, p_dropout): + super(ConvReluNorm, self).__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.conv_pre = torch.nn.Conv1d(in_channels, hidden_channels, + kernel_size, padding=kernel_size//2) + self.conv_layers = torch.nn.ModuleList() + self.norm_layers = torch.nn.ModuleList() + self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels, + kernel_size, padding=kernel_size//2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels, + kernel_size, padding=kernel_size//2)) + 
self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x = self.conv_pre(x) + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class MultiHeadAttention(BaseModule): + def __init__(self, channels, out_channels, n_heads, window_size=None, + heads_share=True, p_dropout=0.0, proximal_bias=False, + proximal_init=False): + super(MultiHeadAttention, self).__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.window_size = window_size + self.heads_share = heads_share + self.proximal_bias = proximal_bias + self.p_dropout = p_dropout + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = torch.nn.Conv1d(channels, channels, 1) + self.conv_k = torch.nn.Conv1d(channels, channels, 1) + self.conv_v = torch.nn.Conv1d(channels, channels, 1) + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel, + window_size * 2 + 1, self.k_channels) * rel_stddev) + self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel, + window_size * 2 + 1, self.k_channels) * rel_stddev) + self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) + self.drop = torch.nn.Dropout(p_dropout) + + torch.nn.init.xavier_uniform_(self.conv_q.weight) + torch.nn.init.xavier_uniform_(self.conv_k.weight) + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + torch.nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + if self.window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) + rel_logits = self._relative_position_to_absolute_position(rel_logits) + scores_local = rel_logits / math.sqrt(self.k_channels) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." 
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, + dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + p_attn = torch.nn.functional.softmax(scores, dim=-1) + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, + value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = torch.nn.functional.pad( + relative_embeddings, convert_pad_shape([[0, 0], + [pad_length, pad_length], [0, 0]])) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:, + slice_start_position:slice_end_position] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + batch, heads, length, _ = x.size() + x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]])) + x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] + return x_final + + def _absolute_position_to_relative_position(self, x): + batch, heads, length, _ = x.size() + x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) + x_flat = x.view([batch, heads, length**2 + length*(length - 1)]) + x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] + return x_final + + def _attention_bias_proximal(self, length): + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(BaseModule): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, + p_dropout=0.0): + super(FFN, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, + padding=kernel_size//2) + self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, + padding=kernel_size//2) + self.drop = torch.nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + return x * x_mask + + +class Encoder(BaseModule): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, + kernel_size=1, p_dropout=0.0, window_size=None, **kwargs): + super(Encoder, self).__init__() + self.hidden_channels = 
hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = torch.nn.Dropout(p_dropout) + self.attn_layers = torch.nn.ModuleList() + self.norm_layers_1 = torch.nn.ModuleList() + self.ffn_layers = torch.nn.ModuleList() + self.norm_layers_2 = torch.nn.ModuleList() + for _ in range(self.n_layers): + self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, + n_heads, window_size=window_size, p_dropout=p_dropout)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append(FFN(hidden_channels, hidden_channels, + filter_channels, kernel_size, p_dropout=p_dropout)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.n_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class TextEncoder(BaseModule): + def __init__(self, n_vecs, n_mels, n_embs, + n_channels, + filter_channels, + n_heads=2, + n_layers=6, + kernel_size=3, + p_dropout=0.1, + window_size=4): + super(TextEncoder, self).__init__() + self.n_vecs = n_vecs + self.n_mels = n_mels + self.n_embs = n_embs + self.n_channels = n_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.prenet = ConvReluNorm(n_vecs, + n_channels, + n_channels, + kernel_size=5, + n_layers=3, + p_dropout=0.5) + + self.encoder = Encoder(n_channels + n_embs + n_embs, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + window_size=window_size) + + self.proj_m = torch.nn.Conv1d(n_channels + n_embs + n_embs, n_mels, 1) + + def forward(self, x_lengths, x, pit, spk): + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + # despeaker + x = self.prenet(x, x_mask) + # pitch + speaker + spk = spk.unsqueeze(-1).repeat(1, 1, x.shape[-1]) + x = torch.cat([x, pit], dim=1) + x = torch.cat([x, spk], dim=1) + x = self.encoder(x, x_mask) + mu = self.proj_m(x) * x_mask + return mu, x_mask diff --git a/grad/model.py b/grad/model.py new file mode 100644 index 0000000..9ce962e --- /dev/null +++ b/grad/model.py @@ -0,0 +1,125 @@ +import math +import torch + +from grad.base import BaseModule +from grad.encoder import TextEncoder +from grad.diffusion import Diffusion +from grad.utils import f0_to_coarse, rand_ids_segments, slice_segments + + +class GradTTS(BaseModule): + def __init__(self, n_mels, n_vecs, n_pits, n_spks, n_embs, + n_enc_channels, filter_channels, + dec_dim, beta_min, beta_max, pe_scale): + super(GradTTS, self).__init__() + # common + self.n_mels = n_mels + self.n_vecs = n_vecs + self.n_spks = n_spks + self.n_embs = n_embs + # encoder + self.n_enc_channels = n_enc_channels + self.filter_channels = filter_channels + # decoder + self.dec_dim = dec_dim + self.beta_min = beta_min + self.beta_max = beta_max + self.pe_scale = pe_scale + + self.pit_emb = torch.nn.Embedding(n_pits, n_embs) + self.spk_emb = torch.nn.Linear(n_spks, n_embs) + self.encoder = TextEncoder(n_vecs, + n_mels, + n_embs, + n_enc_channels, + filter_channels) + self.decoder = Diffusion(n_mels, dec_dim, n_embs, beta_min, 
beta_max, pe_scale) + + @torch.no_grad() + def forward(self, lengths, vec, pit, spk, n_timesteps, temperature=1.0, stoc=False): + """ + Generates mel-spectrogram from vec. Returns: + 1. encoder outputs + 2. decoder outputs + + Args: + lengths (torch.Tensor): lengths of texts in batch. + vec (torch.Tensor): batch of speech vec + pit (torch.Tensor): batch of speech pit + spk (torch.Tensor): batch of speaker + + n_timesteps (int): number of steps to use for reverse diffusion in decoder. + temperature (float, optional): controls variance of terminal distribution. + stoc (bool, optional): flag that adds stochastic term to the decoder sampler. + Usually, does not provide synthesis improvements. + """ + lengths, vec, pit, spk = self.relocate_input([lengths, vec, pit, spk]) + + # Get pitch embedding + pit = self.pit_emb(f0_to_coarse(pit)) + + # Get speaker embedding + spk = self.spk_emb(spk) + + # Transpose + vec = torch.transpose(vec, 1, -1) + pit = torch.transpose(pit, 1, -1) + + # Get encoder_outputs `mu_x` + mu_x, mask_x = self.encoder(lengths, vec, pit, spk) + encoder_outputs = mu_x + + # Sample latent representation from terminal distribution N(mu_y, I) + z = mu_x + torch.randn_like(mu_x, device=mu_x.device) / temperature + # Generate sample by performing reverse dynamics + decoder_outputs = self.decoder(spk, z, mask_x, mu_x, n_timesteps, stoc) + + return encoder_outputs, decoder_outputs + + def compute_loss(self, lengths, vec, pit, spk, mel, out_size): + """ + Computes 2 losses: + 1. prior loss: loss between mel-spectrogram and encoder outputs. + 2. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder. + + Args: + lengths (torch.Tensor): lengths of texts in batch. + vec (torch.Tensor): batch of speech vec + pit (torch.Tensor): batch of speech pit + spk (torch.Tensor): batch of speaker + mel (torch.Tensor): batch of corresponding mel-spectrogram + + out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained. + Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size. 
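+                In this repo, grad_extend/train.py passes out_size = fix_len_compatibility(200),
+                i.e. a 200-frame segment (about 2 s at the 10 ms mel hop).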
+ """ + lengths, vec, pit, spk, mel = self.relocate_input([lengths, vec, pit, spk, mel]) + + # Get pitch embedding + pit = self.pit_emb(f0_to_coarse(pit)) + + # Get speaker embedding + spk = self.spk_emb(spk) + + # Transpose + vec = torch.transpose(vec, 1, -1) + pit = torch.transpose(pit, 1, -1) + + # Get encoder_outputs `mu_x` + mu_x, mask_x = self.encoder(lengths, vec, pit, spk) + + # Cut a small segment of mel-spectrogram in order to increase batch size + if not isinstance(out_size, type(None)): + ids = rand_ids_segments(lengths, out_size) + mel = slice_segments(mel, ids, out_size) + + mask_y = slice_segments(mask_x, ids, out_size) + mu_y = slice_segments(mu_x, ids, out_size) + + # Compute loss of score-based decoder + diff_loss, xt = self.decoder.compute_loss(spk, mel, mask_y, mu_y) + + # Compute loss between aligned encoder outputs and mel-spectrogram + prior_loss = torch.sum(0.5 * ((mel - mu_y) ** 2 + math.log(2 * math.pi)) * mask_y) + prior_loss = prior_loss / (torch.sum(mask_y) * self.n_mels) + + return prior_loss, diff_loss diff --git a/grad/utils.py b/grad/utils.py new file mode 100644 index 0000000..8432525 --- /dev/null +++ b/grad/utils.py @@ -0,0 +1,99 @@ +import torch +import numpy as np +import inspect + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(int(max_length), dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def fix_len_compatibility(length, num_downsamplings_in_unet=2): + while True: + if length % (2**num_downsamplings_in_unet) == 0: + return length + length += 1 + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def generate_path(duration, mask): + device = duration.device + + b, t_x, t_y = mask.shape + cum_duration = torch.cumsum(duration, 1) + path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], + [1, 0], [0, 0]]))[:, :-1] + path = path * mask + return path + + +def duration_loss(logw, logw_, lengths): + loss = torch.sum((logw - logw_)**2) / torch.sum(lengths) + return loss + + +f0_bin = 256 +f0_max = 1100.0 +f0_min = 50.0 +f0_mel_min = 1127 * np.log(1 + f0_min / 700) +f0_mel_max = 1127 * np.log(1 + f0_max / 700) + + +def f0_to_coarse(f0): + is_torch = isinstance(f0, torch.Tensor) + f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * \ + np.log(1 + f0 / 700) + f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * \ + (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 + + f0_mel[f0_mel <= 1] = 1 + f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 + f0_coarse = ( + f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int) + assert f0_coarse.max() <= 255 and f0_coarse.min( + ) >= 1, (f0_coarse.max(), f0_coarse.min()) + return f0_coarse + + +def rand_ids_segments(lengths, segment_size=200): + b = lengths.shape[0] + ids_str_max = lengths - segment_size + ids_str = (torch.rand([b]).to(device=lengths.device) * ids_str_max).to(dtype=torch.long) + return ids_str + + +def slice_segments(x, ids_str, segment_size=200): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def retrieve_name(var): + for 
fi in reversed(inspect.stack()): + names = [var_name for var_name, + var_val in fi.frame.f_locals.items() if var_val is var] + if len(names) > 0: + return names[0] + + +Debug_Enable = True + + +def debug_shapes(var): + if Debug_Enable: + print(retrieve_name(var), var.shape) diff --git a/grad_extend/data.py b/grad_extend/data.py new file mode 100644 index 0000000..b75625a --- /dev/null +++ b/grad_extend/data.py @@ -0,0 +1,133 @@ +import os +import random +import numpy as np + +import torch + +from grad.utils import fix_len_compatibility +from grad_extend.utils import parse_filelist + + +class TextMelSpeakerDataset(torch.utils.data.Dataset): + def __init__(self, filelist_path): + super().__init__() + self.filelist = parse_filelist(filelist_path, split_char='|') + self._filter() + print(f'----------{len(self.filelist)}----------') + + def _filter(self): + items_new = [] + # segment = 200 + items_min = 250 # 10ms * 250 = 2.5 S + items_max = 500 # 10ms * 400 = 5.0 S + for mel, vec, pit, spk in self.filelist: + if not os.path.isfile(mel): + continue + if not os.path.isfile(vec): + continue + if not os.path.isfile(pit): + continue + if not os.path.isfile(spk): + continue + temp = np.load(pit) + usel = int(temp.shape[0] - 1) # useful length + if (usel < items_min): + continue + if (usel >= items_max): + usel = items_max + items_new.append([mel, vec, pit, spk, usel]) + self.filelist = items_new + + def get_triplet(self, item): + # print(item) + mel = item[0] + vec = item[1] + pit = item[2] + spk = item[3] + use = item[4] + + mel = torch.load(mel) + vec = np.load(vec) + vec = np.repeat(vec, 2, 0) # 320 VEC -> 160 * 2 + pit = np.load(pit) + spk = np.load(spk) + + vec = torch.FloatTensor(vec) + pit = torch.FloatTensor(pit) + spk = torch.FloatTensor(spk) + + len_vec = vec.size()[0] - 2 # for safe + len_pit = pit.size()[0] + len_min = min(len_pit, len_vec) + + mel = mel[:, :len_min] + vec = vec[:len_min, :] + pit = pit[:len_min] + + if len_min > use: + max_frame_start = vec.size(0) - use - 1 + frame_start = random.randint(0, max_frame_start) + frame_end = frame_start + use + + mel = mel[:, frame_start:frame_end] + vec = vec[frame_start:frame_end, :] + pit = pit[frame_start:frame_end] + # print(mel.shape) + # print(vec.shape) + # print(pit.shape) + # print(spk.shape) + return (mel, vec, pit, spk) + + def __getitem__(self, index): + mel, vec, pit, spk = self.get_triplet(self.filelist[index]) + item = {'mel': mel, 'vec': vec, 'pit': pit, 'spk': spk} + return item + + def __len__(self): + return len(self.filelist) + + def sample_test_batch(self, size): + idx = np.random.choice(range(len(self)), size=size, replace=False) + test_batch = [] + for index in idx: + test_batch.append(self.__getitem__(index)) + return test_batch + + +class TextMelSpeakerBatchCollate(object): + # mel: [freq, length] + # vec: [len, 256] + # pit: [len] + # spk: [256] + def __call__(self, batch): + B = len(batch) + mel_max_length = max([item['mel'].shape[-1] for item in batch]) + max_length = fix_len_compatibility(mel_max_length) + + d_mel = batch[0]['mel'].shape[0] + d_vec = batch[0]['vec'].shape[1] + d_spk = batch[0]['spk'].shape[0] + # print("d_mel", d_mel) + # print("d_vec", d_vec) + # print("d_spk", d_spk) + mel = torch.zeros((B, d_mel, max_length), dtype=torch.float32) + vec = torch.zeros((B, max_length, d_vec), dtype=torch.float32) + pit = torch.zeros((B, max_length), dtype=torch.float32) + spk = torch.zeros((B, d_spk), dtype=torch.float32) + lengths = torch.LongTensor(B) + + for i, item in enumerate(batch): + y_, x_, p_, s_ = 
item['mel'], item['vec'], item['pit'], item['spk'] + + mel[i, :, :y_.shape[1]] = y_ + vec[i, :x_.shape[0], :] = x_ + pit[i, :p_.shape[0]] = p_ + spk[i] = s_ + + lengths[i] = y_.shape[1] + # print("lengths", lengths.shape) + # print("vec", vec.shape) + # print("pit", pit.shape) + # print("spk", spk.shape) + # print("mel", mel.shape) + return {'lengths': lengths, 'vec': vec, 'pit': pit, 'spk': spk, 'mel': mel} diff --git a/grad_extend/train.py b/grad_extend/train.py new file mode 100644 index 0000000..32301ed --- /dev/null +++ b/grad_extend/train.py @@ -0,0 +1,163 @@ +import os +import torch +import numpy as np + +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +from tqdm import tqdm +from grad_extend.data import TextMelSpeakerDataset, TextMelSpeakerBatchCollate +from grad_extend.utils import plot_tensor, save_plot, load_model +from grad.utils import fix_len_compatibility +from grad.model import GradTTS + + +# 200 frames +out_size = fix_len_compatibility(200) + + +def train(hps, chkpt_path=None): + + print('Initializing logger...') + logger = SummaryWriter(log_dir=hps.train.log_dir) + + print('Initializing data loaders...') + train_dataset = TextMelSpeakerDataset(hps.train.train_files) + batch_collate = TextMelSpeakerBatchCollate() + loader = DataLoader(dataset=train_dataset, + batch_size=hps.train.batch_size, + collate_fn=batch_collate, + drop_last=True, + num_workers=8, + shuffle=True) + test_dataset = TextMelSpeakerDataset(hps.train.valid_files) + + print('Initializing model...') + model = GradTTS(hps.grad.n_mels, hps.grad.n_vecs, hps.grad.n_pits, hps.grad.n_spks, hps.grad.n_embs, + hps.grad.n_enc_channels, hps.grad.filter_channels, + hps.grad.dec_dim, hps.grad.beta_min, hps.grad.beta_max, hps.grad.pe_scale).cuda() + print('Number of encoder parameters = %.2fm' % (model.encoder.nparams/1e6)) + print('Number of decoder parameters = %.2fm' % (model.decoder.nparams/1e6)) + + # Load Pretrain + if os.path.isfile(hps.train.pretrain): + logger.info("Start from Grad_SVC pretrain model: %s" % hps.train.pretrain) + checkpoint = torch.load(hps.train.pretrain, map_location='cpu') + load_model(model, checkpoint['model']) + + print('Initializing optimizer...') + optim = torch.optim.Adam(params=model.parameters(), lr=hps.train.learning_rate) + + initepoch = 1 + iteration = 0 + + # Load Continue + if chkpt_path is not None: + logger.info("Resuming from checkpoint: %s" % chkpt_path) + checkpoint = torch.load(chkpt_path, map_location='cpu') + model.load_state_dict(checkpoint['model']) + optim.load_state_dict(checkpoint['optim']) + initepoch = checkpoint['epoch'] + iteration = checkpoint['steps'] + + print('Logging test batch...') + test_batch = test_dataset.sample_test_batch(size=hps.train.test_size) + for i, item in enumerate(test_batch): + mel = item['mel'] + logger.add_image(f'image_{i}/ground_truth', plot_tensor(mel.squeeze()), + global_step=0, dataformats='HWC') + save_plot(mel.squeeze(), f'{hps.train.log_dir}/original_{i}.png') + + print('Start training...') + + for epoch in range(initepoch, hps.train.n_epochs + 1): + model.eval() + print('Synthesis...') + + if epoch % hps.train.test_step == 0: + with torch.no_grad(): + for i, item in enumerate(test_batch): + l_vec = item['vec'].shape[0] + d_vec = item['vec'].shape[1] + + lengths_fix = fix_len_compatibility(l_vec) + lengths = torch.LongTensor([l_vec]).cuda() + + vec = torch.zeros((1, lengths_fix, d_vec), dtype=torch.float32).cuda() + pit = torch.zeros((1, lengths_fix), dtype=torch.float32).cuda() + spk = 
item['spk'].to(torch.float32).unsqueeze(0).cuda() + vec[0, :l_vec, :] = item['vec'] + pit[0, :l_vec] = item['pit'] + + y_enc, y_dec = model(lengths, vec, pit, spk, n_timesteps=50) + + logger.add_image(f'image_{i}/generated_enc', + plot_tensor(y_enc.squeeze().cpu()), + global_step=iteration, dataformats='HWC') + logger.add_image(f'image_{i}/generated_dec', + plot_tensor(y_dec.squeeze().cpu()), + global_step=iteration, dataformats='HWC') + save_plot(y_enc.squeeze().cpu(), + f'{hps.train.log_dir}/generated_enc_{i}.png') + save_plot(y_dec.squeeze().cpu(), + f'{hps.train.log_dir}/generated_dec_{i}.png') + + model.train() + + prior_losses = [] + diff_losses = [] + with tqdm(loader, total=len(train_dataset)//hps.train.batch_size) as progress_bar: + for batch in progress_bar: + model.zero_grad() + + lengths = batch['lengths'].cuda() + vec = batch['vec'].cuda() + pit = batch['pit'].cuda() + spk = batch['spk'].cuda() + mel = batch['mel'].cuda() + + prior_loss, diff_loss = model.compute_loss(lengths, vec, pit, spk, + mel, out_size=out_size) + loss = sum([prior_loss, diff_loss]) + loss.backward() + + enc_grad_norm = torch.nn.utils.clip_grad_norm_(model.encoder.parameters(), + max_norm=1) + dec_grad_norm = torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), + max_norm=1) + optim.step() + + logger.add_scalar('training/prior_loss', prior_loss, + global_step=iteration) + logger.add_scalar('training/diffusion_loss', diff_loss, + global_step=iteration) + logger.add_scalar('training/encoder_grad_norm', enc_grad_norm, + global_step=iteration) + logger.add_scalar('training/decoder_grad_norm', dec_grad_norm, + global_step=iteration) + + msg = f'Epoch: {epoch}, iteration: {iteration} | prior_loss: {prior_loss.item():.3f}, diff_loss: {diff_loss.item():.3f}' + progress_bar.set_description(msg) + + prior_losses.append(prior_loss.item()) + diff_losses.append(diff_loss.item()) + iteration += 1 + + msg = 'Epoch %d: ' % (epoch) + msg += '| prior loss = %.3f ' % np.mean(prior_losses) + msg += '| diffusion loss = %.3f\n' % np.mean(diff_losses) + with open(f'{hps.train.log_dir}/train.log', 'a') as f: + f.write(msg) + + if epoch % hps.train.save_step > 0: + continue + + save_path = f"{hps.train.log_dir}/grad_svc_{epoch}.pt" + torch.save({ + 'model': model.state_dict(), + 'optim': optim.state_dict(), + 'epoch': epoch, + 'steps': iteration, + + }, save_path) + logger.info("Saved checkpoint to: %s" % save_path) diff --git a/grad_extend/utils.py b/grad_extend/utils.py new file mode 100644 index 0000000..d7cf538 --- /dev/null +++ b/grad_extend/utils.py @@ -0,0 +1,73 @@ +import os +import glob +import numpy as np +import matplotlib.pyplot as plt + +import torch + + +def parse_filelist(filelist_path, split_char="|"): + with open(filelist_path, encoding='utf-8') as f: + filepaths_and_text = [line.strip().split(split_char) for line in f] + return filepaths_and_text + + +def load_model(model, saved_state_dict): + state_dict = model.state_dict() + new_state_dict = {} + for k, v in state_dict.items(): + try: + new_state_dict[k] = saved_state_dict[k] + except: + print("%s is not in the checkpoint" % k) + new_state_dict[k] = v + model.load_state_dict(new_state_dict) + return model + + +def latest_checkpoint_path(dir_path, regex="grad_svc_*.pt"): + f_list = glob.glob(os.path.join(dir_path, regex)) + f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) + x = f_list[-1] + return x + + +def load_checkpoint(logdir, model, num=None): + if num is None: + model_path = latest_checkpoint_path(logdir, regex="grad_svc_*.pt") + 
else: + model_path = os.path.join(logdir, f"grad_svc_{num}.pt") + print(f'Loading checkpoint {model_path}...') + model_dict = torch.load(model_path, map_location=lambda loc, storage: loc) + model.load_state_dict(model_dict, strict=False) + return model + + +def save_figure_to_numpy(fig): + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + return data + + +def plot_tensor(tensor): + plt.style.use('default') + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none') + plt.colorbar(im, ax=ax) + plt.tight_layout() + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + +def save_plot(tensor, savepath): + plt.style.use('default') + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none') + plt.colorbar(im, ax=ax) + plt.tight_layout() + fig.canvas.draw() + plt.savefig(savepath) + plt.close() + return diff --git a/grad_pretrain/README.md b/grad_pretrain/README.md new file mode 100644 index 0000000..8811790 --- /dev/null +++ b/grad_pretrain/README.md @@ -0,0 +1,3 @@ +Path for: + + gvc.pretrain.pth \ No newline at end of file diff --git a/gvc_export.py b/gvc_export.py new file mode 100644 index 0000000..7c58e75 --- /dev/null +++ b/gvc_export.py @@ -0,0 +1,48 @@ +import sys,os +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +import torch +import argparse +from omegaconf import OmegaConf +from grad.model import GradTTS + + +def load_model(checkpoint_path, model): + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + saved_state_dict = checkpoint_dict["model"] + + state_dict = model.state_dict() + new_state_dict = {} + for k, v in state_dict.items(): + try: + new_state_dict[k] = saved_state_dict[k] + except: + print("%s is not in the checkpoint" % k) + new_state_dict[k] = v + model.load_state_dict(new_state_dict) + + +def main(args): + hps = OmegaConf.load(args.config) + + print('Initializing Grad-TTS...') + model = GradTTS(hps.grad.n_mels, hps.grad.n_vecs, hps.grad.n_pits, hps.grad.n_spks, hps.grad.n_embs, + hps.grad.n_enc_channels, hps.grad.filter_channels, + hps.grad.dec_dim, hps.grad.beta_min, hps.grad.beta_max, hps.grad.pe_scale) + print('Number of encoder parameters = %.2fm' % (model.encoder.nparams/1e6)) + print('Number of decoder parameters = %.2fm' % (model.decoder.nparams/1e6)) + + load_model(args.checkpoint_path, model) + torch.save({'model': model.state_dict()}, "gvc.pth") + torch.save({'model': model.state_dict()}, "gvc.pretrain.pth") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--config', type=str, default='./configs/base.yaml', + help="yaml file for config.") + parser.add_argument('-p', '--checkpoint_path', type=str, required=True, + help="path of checkpoint pt file for evaluation") + args = parser.parse_args() + + main(args) diff --git a/gvc_inference.py b/gvc_inference.py new file mode 100644 index 0000000..332b984 --- /dev/null +++ b/gvc_inference.py @@ -0,0 +1,153 @@ +import sys,os +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +import torch +import argparse +import numpy as np + +from omegaconf import OmegaConf +from pitch import load_csv_pitch +from spec.inference import print_mel + +from grad.utils import fix_len_compatibility +from grad.model import GradTTS + + +def load_gvc_model(checkpoint_path, model): 
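+    # Note: the checkpoint is read as a bare state_dict here, whereas gvc_export.py
+    # wraps the weights in a 'model' key. If every parameter is reported as
+    # "is not in the checkpoint", unwrap it first, e.g.
+    #     saved_state_dict = checkpoint_dict.get('model', checkpoint_dict)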
+ assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + saved_state_dict = checkpoint_dict + state_dict = model.state_dict() + new_state_dict = {} + for k, v in state_dict.items(): + try: + new_state_dict[k] = saved_state_dict[k] + except: + print("%s is not in the checkpoint" % k) + new_state_dict[k] = v + model.load_state_dict(new_state_dict) + return model + + +@torch.no_grad() +def gvc_main(device, model, _vec, _pit, spk): + l_vec = _vec.shape[0] + d_vec = _vec.shape[1] + lengths_fix = fix_len_compatibility(l_vec) + lengths = torch.LongTensor([l_vec]).to(device) + vec = torch.zeros((1, lengths_fix, d_vec), dtype=torch.float32).to(device) + pit = torch.zeros((1, lengths_fix), dtype=torch.float32).to(device) + vec[0, :l_vec, :] = _vec + pit[0, :l_vec] = _pit + y_enc, y_dec = model(lengths, vec, pit, spk, n_timesteps=50) + y_dec = y_dec.squeeze(0) + y_dec = y_dec[:, :l_vec] + return y_dec + + +def main(args): + + if (args.vec == None): + args.vec = "gvc_tmp.vec.npy" + print( + f"Auto run : python hubert/inference.py -w {args.wave} -v {args.vec}") + os.system(f"python hubert/inference.py -w {args.wave} -v {args.vec}") + + if (args.pit == None): + args.pit = "gvc_tmp.pit.csv" + print( + f"Auto run : python pitch/inference.py -w {args.wave} -p {args.pit}") + os.system(f"python pitch/inference.py -w {args.wave} -p {args.pit}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + hps = OmegaConf.load(args.config) + + print('Initializing Grad-TTS...') + model = GradTTS(hps.grad.n_mels, hps.grad.n_vecs, hps.grad.n_pits, hps.grad.n_spks, hps.grad.n_embs, + hps.grad.n_enc_channels, hps.grad.filter_channels, + hps.grad.dec_dim, hps.grad.beta_min, hps.grad.beta_max, hps.grad.pe_scale) + print('Number of encoder parameters = %.2fm' % (model.encoder.nparams/1e6)) + print('Number of decoder parameters = %.2fm' % (model.decoder.nparams/1e6)) + + load_gvc_model(args.model, model) + model.eval() + model.to(device) + + spk = np.load(args.spk) + spk = torch.FloatTensor(spk) + + vec = np.load(args.vec) + vec = np.repeat(vec, 2, 0) + vec = torch.FloatTensor(vec) + + pit = load_csv_pitch(args.pit) + pit = np.array(pit) + pit = pit * 2 ** (args.shift / 12) + pit = torch.FloatTensor(pit) + + len_pit = pit.size()[0] + len_vec = vec.size()[0] + len_min = min(len_pit, len_vec) + pit = pit[:len_min] + vec = vec[:len_min, :] + + with torch.no_grad(): + spk = spk.unsqueeze(0).to(device) + + all_frame = len_min + hop_frame = 8 + out_chunk = 2400 # 24 S + out_index = 0 + out_mel = None + + while (out_index < all_frame): + if (out_index == 0): # start frame + cut_s = 0 + cut_s_out = 0 + else: + cut_s = out_index - hop_frame + cut_s_out = hop_frame + + if (out_index + out_chunk + hop_frame > all_frame): # end frame + cut_e = all_frame + cut_e_out = -1 + else: + cut_e = out_index + out_chunk + hop_frame + cut_e_out = -1 * hop_frame + + sub_vec = vec[cut_s:cut_e, :].to(device) + sub_pit = pit[cut_s:cut_e].to(device) + + sub_out = gvc_main(device, model, sub_vec, sub_pit, spk) + sub_out = sub_out[:, cut_s_out:cut_e_out] + + out_index = out_index + out_chunk + if out_mel == None: + out_mel = sub_out + else: + out_mel = torch.cat((out_mel, sub_out), -1) + if cut_e == all_frame: + break + + torch.save(out_mel, "gvc_out.mel.pt") + print_mel(out_mel, "gvc_out.mel.png") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, default='./configs/base.yaml', + help="yaml file for config.") + 
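+    # --vec and --pit are optional: when omitted, main() generates them on the fly
+    # by running hubert/inference.py and pitch/inference.py on the input wave.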
parser.add_argument('--model', type=str, required=True, + help="path of model for evaluation") + parser.add_argument('--wave', type=str, required=True, + help="Path of raw audio.") + parser.add_argument('--spk', type=str, required=True, + help="Path of speaker.") + parser.add_argument('--vec', type=str, + help="Path of hubert vector.") + parser.add_argument('--pit', type=str, + help="Path of pitch csv file.") + parser.add_argument('--shift', type=int, default=0, + help="Pitch shift key.") + args = parser.parse_args() + + main(args) diff --git a/gvc_inference_wave.py b/gvc_inference_wave.py new file mode 100644 index 0000000..9ed17b8 --- /dev/null +++ b/gvc_inference_wave.py @@ -0,0 +1,71 @@ +import sys,os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import torch +import argparse + +from omegaconf import OmegaConf +from scipy.io.wavfile import write +from bigvgan.model.generator import Generator +from pitch import load_csv_pitch + + +def load_bigv_model(checkpoint_path, model): + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + saved_state_dict = checkpoint_dict["model_g"] + state_dict = model.state_dict() + new_state_dict = {} + for k, v in state_dict.items(): + try: + new_state_dict[k] = saved_state_dict[k] + except: + print("%s is not in the checkpoint" % k) + new_state_dict[k] = v + model.load_state_dict(new_state_dict) + return model + + +def main(args): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + hp = OmegaConf.load(args.config) + model = Generator(hp) + load_bigv_model(args.model, model) + model.eval() + model.to(device) + + mel = torch.load(args.mel) + + pit = load_csv_pitch(args.pit) + pit = torch.FloatTensor(pit) + + len_pit = pit.size()[0] + len_mel = mel.size()[1] + len_min = min(len_pit, len_mel) + pit = pit[:len_min] + mel = mel[:, :len_min] + + with torch.no_grad(): + mel = mel.unsqueeze(0).to(device) + pit = pit.unsqueeze(0).to(device) + audio = model.inference(mel, pit) + audio = audio.cpu().detach().numpy() + + pitwav = model.pitch2wav(pit) + pitwav = pitwav.cpu().detach().numpy() + + write("gvc_out.wav", hp.audio.sampling_rate, audio) + write("gvc_pitch.wav", hp.audio.sampling_rate, pitwav) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--mel', type=str, + help="Path of content vector.") + parser.add_argument('--pit', type=str, + help="Path of pitch csv file.") + args = parser.parse_args() + + args.config = "./bigvgan/configs/nsf_bigvgan.yaml" + args.model = "./bigvgan_pretrain/nsf_bigvgan_pretrain_32K.pth" + + main(args) diff --git a/gvc_trainer.py b/gvc_trainer.py new file mode 100644 index 0000000..4802ada --- /dev/null +++ b/gvc_trainer.py @@ -0,0 +1,30 @@ +import sys,os +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +import argparse +import torch +import numpy as np + +from omegaconf import OmegaConf +from grad_extend.train import train + +torch.backends.cudnn.benchmark = True + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--config', type=str, default='./configs/base.yaml', + help="yaml file for configuration") + parser.add_argument('-p', '--checkpoint_path', type=str, default=None, + help="path of checkpoint pt file to resume training") + args = parser.parse_args() + + assert torch.cuda.is_available() + print('Numbers of GPU :', torch.cuda.device_count()) + + hps = OmegaConf.load(args.config) + + np.random.seed(hps.train.seed) + 
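+    # Seed numpy and torch (CPU and CUDA) from the config so runs are reproducible.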
torch.manual_seed(hps.train.seed) + torch.cuda.manual_seed(hps.train.seed) + + train(hps, args.checkpoint_path) diff --git a/hubert/__init__.py b/hubert/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hubert/hubert_model.py b/hubert/hubert_model.py new file mode 100644 index 0000000..09df44c --- /dev/null +++ b/hubert/hubert_model.py @@ -0,0 +1,229 @@ +import copy +import random +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as t_func + + +class Hubert(nn.Module): + def __init__(self, num_label_embeddings: int = 100, mask: bool = True): + super().__init__() + self._mask = mask + self.feature_extractor = FeatureExtractor() + self.feature_projection = FeatureProjection() + self.positional_embedding = PositionalConvEmbedding() + self.norm = nn.LayerNorm(768) + self.dropout = nn.Dropout(0.1) + self.encoder = TransformerEncoder( + nn.TransformerEncoderLayer( + 768, 12, 3072, activation="gelu", batch_first=True + ), + 12, + ) + self.proj = nn.Linear(768, 256) + + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_()) + self.label_embedding = nn.Embedding(num_label_embeddings, 256) + + def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + mask = None + if self.training and self._mask: + mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2) + x[mask] = self.masked_spec_embed.to(x.dtype) + return x, mask + + def encode( + self, x: torch.Tensor, layer: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = self.feature_extractor(x) + x = self.feature_projection(x.transpose(1, 2)) + x, mask = self.mask(x) + x = x + self.positional_embedding(x) + x = self.dropout(self.norm(x)) + x = self.encoder(x, output_layer=layer) + return x, mask + + def logits(self, x: torch.Tensor) -> torch.Tensor: + logits = torch.cosine_similarity( + x.unsqueeze(2), + self.label_embedding.weight.unsqueeze(0).unsqueeze(0), + dim=-1, + ) + return logits / 0.1 + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x, mask = self.encode(x) + x = self.proj(x) + logits = self.logits(x) + return logits, mask + + +class HubertSoft(Hubert): + def __init__(self): + super().__init__() + + @torch.inference_mode() + def units(self, wav: torch.Tensor) -> torch.Tensor: + wav = t_func.pad(wav, ((400 - 320) // 2, (400 - 320) // 2)) + x, _ = self.encode(wav) + return self.proj(x) + + +class FeatureExtractor(nn.Module): + def __init__(self): + super().__init__() + self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False) + self.norm0 = nn.GroupNorm(512, 512) + self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False) + self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = t_func.gelu(self.norm0(self.conv0(x))) + x = t_func.gelu(self.conv1(x)) + x = t_func.gelu(self.conv2(x)) + x = t_func.gelu(self.conv3(x)) + x = t_func.gelu(self.conv4(x)) + x = t_func.gelu(self.conv5(x)) + x = t_func.gelu(self.conv6(x)) + return x + + +class FeatureProjection(nn.Module): + def __init__(self): + super().__init__() + self.norm = nn.LayerNorm(512) + self.projection = nn.Linear(512, 768) + self.dropout = nn.Dropout(0.1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + x = self.projection(x) + x = self.dropout(x) + return x + 
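+
+# Editor's sketch (not part of the original model): the FeatureExtractor above uses
+# strides 5,2,2,2,2,2,2, i.e. a total hop of 5 * 2**6 = 320 samples (one unit every
+# 20 ms at 16 kHz) and a 400-sample receptive field, which is why HubertSoft.units
+# pads the waveform by (400 - 320) // 2 on each side and why other scripts in this
+# repo note "hop=320". The frame count can therefore be estimated as:
+def approx_num_units(num_samples: int, hop: int = 320) -> int:
+    # Approximate number of HuBERT-soft frames for a 16 kHz waveform of num_samples.
+    return num_samples // hop
+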
+ +class PositionalConvEmbedding(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv1d( + 768, + 768, + kernel_size=128, + padding=128 // 2, + groups=16, + ) + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x.transpose(1, 2)) + x = t_func.gelu(x[:, :, :-1]) + return x.transpose(1, 2) + + +class TransformerEncoder(nn.Module): + def __init__( + self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int + ) -> None: + super(TransformerEncoder, self).__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for _ in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: torch.Tensor, + mask: torch.Tensor = None, + src_key_padding_mask: torch.Tensor = None, + output_layer: Optional[int] = None, + ) -> torch.Tensor: + output = src + for layer in self.layers[:output_layer]: + output = layer( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask + ) + return output + + +def _compute_mask( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + device: torch.device, + min_masks: int = 0, +) -> torch.Tensor: + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`" + ) + + # compute number of masked spans in batch + num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random()) + num_masked_spans = max(num_masked_spans, min_masks) + + # make sure num masked indices <= sequence_length + if num_masked_spans * mask_length > sequence_length: + num_masked_spans = sequence_length // mask_length + + # SpecAugment mask to fill + mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool) + + # uniform distribution to sample from, make sure that offset samples are < sequence_length + uniform_dist = torch.ones( + (batch_size, sequence_length - (mask_length - 1)), device=device + ) + + # get random indices to mask + mask_indices = torch.multinomial(uniform_dist, num_masked_spans) + + # expand masked indices to masked spans + mask_indices = ( + mask_indices.unsqueeze(dim=-1) + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + offsets = ( + torch.arange(mask_length, device=device)[None, None, :] + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + mask_idxs = mask_indices + offsets + + # scatter indices to mask + mask = mask.scatter(1, mask_idxs, True) + + return mask + + +def consume_prefix(state_dict, prefix: str) -> None: + keys = sorted(state_dict.keys()) + for key in keys: + if key.startswith(prefix): + newkey = key[len(prefix):] + state_dict[newkey] = state_dict.pop(key) + + +def hubert_soft( + path: str, +) -> HubertSoft: + r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. 
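+    The weights are loaded from a local checkpoint; any leftover "module." prefix
+    from DataParallel training is stripped by consume_prefix before load_state_dict.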
+ Args: + path (str): path of a pretrained model + """ + hubert = HubertSoft() + checkpoint = torch.load(path) + consume_prefix(checkpoint, "module.") + hubert.load_state_dict(checkpoint) + hubert.eval() + return hubert diff --git a/hubert/inference.py b/hubert/inference.py new file mode 100644 index 0000000..a40bdeb --- /dev/null +++ b/hubert/inference.py @@ -0,0 +1,67 @@ +import sys,os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import numpy as np +import argparse +import torch +import librosa + +from hubert import hubert_model + + +def load_audio(file: str, sr: int = 16000): + x, sr = librosa.load(file, sr=sr) + return x + + +def load_model(path, device): + model = hubert_model.hubert_soft(path) + model.eval() + if not (device == "cpu"): + model.half() + model.to(device) + return model + + +def pred_vec(model, wavPath, vecPath, device): + audio = load_audio(wavPath) + audln = audio.shape[0] + vec_a = [] + idx_s = 0 + while (idx_s + 20 * 16000 < audln): + feats = audio[idx_s:idx_s + 20 * 16000] + feats = torch.from_numpy(feats).to(device) + feats = feats[None, None, :] + if not (device == "cpu"): + feats = feats.half() + with torch.no_grad(): + vec = model.units(feats).squeeze().data.cpu().float().numpy() + vec_a.extend(vec) + idx_s = idx_s + 20 * 16000 + if (idx_s < audln): + feats = audio[idx_s:audln] + feats = torch.from_numpy(feats).to(device) + feats = feats[None, None, :] + if not (device == "cpu"): + feats = feats.half() + with torch.no_grad(): + vec = model.units(feats).squeeze().data.cpu().float().numpy() + # print(vec.shape) # [length, dim=256] hop=320 + vec_a.extend(vec) + np.save(vecPath, vec_a, allow_pickle=False) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-w", "--wav", help="wav", dest="wav") + parser.add_argument("-v", "--vec", help="vec", dest="vec") + args = parser.parse_args() + print(args.wav) + print(args.vec) + + wavPath = args.wav + vecPath = args.vec + + device = "cuda" if torch.cuda.is_available() else "cpu" + hubert = load_model(os.path.join( + "hubert_pretrain", "hubert-soft-0d54a1f4.pt"), device) + pred_vec(hubert, wavPath, vecPath, device) diff --git a/hubert_pretrain/README.md b/hubert_pretrain/README.md new file mode 100644 index 0000000..dbecfeb --- /dev/null +++ b/hubert_pretrain/README.md @@ -0,0 +1,3 @@ +Path for: + + hubert-soft-0d54a1f4.pt \ No newline at end of file diff --git a/pitch/__init__.py b/pitch/__init__.py new file mode 100644 index 0000000..bc41814 --- /dev/null +++ b/pitch/__init__.py @@ -0,0 +1 @@ +from .inference import load_csv_pitch \ No newline at end of file diff --git a/pitch/inference.py b/pitch/inference.py new file mode 100644 index 0000000..f5c6309 --- /dev/null +++ b/pitch/inference.py @@ -0,0 +1,54 @@ +import sys,os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import librosa +import argparse +import numpy as np +import parselmouth +# pip install praat-parselmouth + +def compute_f0_mouth(path): + x, sr = librosa.load(path, sr=16000) + assert sr == 16000 + lpad = 1024 // 160 + rpad = lpad + f0 = parselmouth.Sound(x, sr).to_pitch_ac( + time_step=160 / sr, + voicing_threshold=0.5, + pitch_floor=30, + pitch_ceiling=1000).selected_array['frequency'] + f0 = np.pad(f0, [[lpad, rpad]], mode='constant') + return f0 + + +def save_csv_pitch(pitch, path): + with open(path, "w", encoding='utf-8') as pitch_file: + for i in range(len(pitch)): + t = i * 10 + minute = t // 60000 + seconds = (t - minute * 60000) // 1000 + 
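+            # t is the frame time in milliseconds (one pitch value every 10 ms); it is
+            # split into minutes / seconds / milliseconds only for the human-readable
+            # timestamp column of the CSV.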
millisecond = t % 1000 + print( + f"{minute}m {seconds}s {millisecond:3d},{int(pitch[i])}", file=pitch_file) + + +def load_csv_pitch(path): + pitch = [] + with open(path, "r", encoding='utf-8') as pitch_file: + for line in pitch_file.readlines(): + pit = line.strip().split(",")[-1] + pitch.append(int(pit)) + return pitch + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-w", "--wav", help="wav", dest="wav") + parser.add_argument("-p", "--pit", help="pit", dest="pit") # csv for excel + args = parser.parse_args() + print(args.wav) + print(args.pit) + + pitch = compute_f0_mouth(args.wav) + save_csv_pitch(pitch, args.pit) + #tmp = load_csv_pitch(args.pit) + #save_csv_pitch(tmp, "tmp.csv") diff --git a/prepare/preprocess_a.py b/prepare/preprocess_a.py new file mode 100644 index 0000000..87d03b5 --- /dev/null +++ b/prepare/preprocess_a.py @@ -0,0 +1,58 @@ +import os +import librosa +import argparse +import numpy as np +from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor, as_completed +from scipy.io import wavfile + + +def resample_wave(wav_in, wav_out, sample_rate): + wav, _ = librosa.load(wav_in, sr=sample_rate) + wav = wav / np.abs(wav).max() * 0.6 + wav = wav / max(0.01, np.max(np.abs(wav))) * 32767 * 0.6 + wavfile.write(wav_out, sample_rate, wav.astype(np.int16)) + + +def process_file(file, wavPath, spks, outPath, sr): + if file.endswith(".wav"): + file = file[:-4] + resample_wave(f"{wavPath}/{spks}/{file}.wav", f"{outPath}/{spks}/{file}.wav", sr) + + +def process_files_with_thread_pool(wavPath, spks, outPath, sr, thread_num=None): + files = [f for f in os.listdir(f"./{wavPath}/{spks}") if f.endswith(".wav")] + + with ThreadPoolExecutor(max_workers=thread_num) as executor: + futures = {executor.submit(process_file, file, wavPath, spks, outPath, sr): file for file in files} + + for future in tqdm(as_completed(futures), total=len(futures), desc=f'Processing {sr} {spks}'): + future.result() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True) + parser.add_argument("-o", "--out", help="out", dest="out", required=True) + parser.add_argument("-s", "--sr", help="sample rate", dest="sr", type=int, required=True) + parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1) + + args = parser.parse_args() + print(args.wav) + print(args.out) + print(args.sr) + + os.makedirs(args.out, exist_ok=True) + wavPath = args.wav + outPath = args.out + + assert args.sr == 16000 or args.sr == 32000 + + for spks in os.listdir(wavPath): + if os.path.isdir(f"./{wavPath}/{spks}"): + os.makedirs(f"./{outPath}/{spks}", exist_ok=True) + if args.thread_count == 0: + process_num = os.cpu_count() // 2 + 1 + else: + process_num = args.thread_count + process_files_with_thread_pool(wavPath, spks, outPath, args.sr, process_num) diff --git a/prepare/preprocess_f0.py b/prepare/preprocess_f0.py new file mode 100644 index 0000000..bdae0a0 --- /dev/null +++ b/prepare/preprocess_f0.py @@ -0,0 +1,64 @@ +import os +import numpy as np +import librosa +import argparse +import parselmouth +# pip install praat-parselmouth +from tqdm import tqdm +from concurrent.futures import ProcessPoolExecutor, as_completed + + +def compute_f0(path, save): + x, sr = librosa.load(path, sr=16000) + assert sr == 16000 + lpad = 1024 // 160 + rpad = lpad + f0 = parselmouth.Sound(x, sr).to_pitch_ac( + time_step=160 / sr, + 
voicing_threshold=0.5, + pitch_floor=30, + pitch_ceiling=1000).selected_array['frequency'] + f0 = np.pad(f0, [[lpad, rpad]], mode='constant') + for index, pitch in enumerate(f0): + f0[index] = round(pitch, 1) + np.save(save, f0, allow_pickle=False) + + +def process_file(file, wavPath, spks, pitPath): + if file.endswith(".wav"): + file = file[:-4] + compute_f0(f"{wavPath}/{spks}/{file}.wav", f"{pitPath}/{spks}/{file}.pit") + + +def process_files_with_process_pool(wavPath, spks, pitPath, process_num=None): + files = [f for f in os.listdir(f"./{wavPath}/{spks}") if f.endswith(".wav")] + + with ProcessPoolExecutor(max_workers=process_num) as executor: + futures = {executor.submit(process_file, file, wavPath, spks, pitPath): file for file in files} + + for future in tqdm(as_completed(futures), total=len(futures), desc=f'Processing f0 {spks}'): + future.result() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True) + parser.add_argument("-p", "--pit", help="pit", dest="pit", required=True) + parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1) + + args = parser.parse_args() + print(args.wav) + print(args.pit) + + os.makedirs(args.pit, exist_ok=True) + wavPath = args.wav + pitPath = args.pit + + for spks in os.listdir(wavPath): + if os.path.isdir(f"./{wavPath}/{spks}"): + os.makedirs(f"./{pitPath}/{spks}", exist_ok=True) + if args.thread_count == 0: + process_num = os.cpu_count() // 2 + 1 + else: + process_num = args.thread_count + process_files_with_process_pool(wavPath, spks, pitPath, process_num) diff --git a/prepare/preprocess_hubert.py b/prepare/preprocess_hubert.py new file mode 100644 index 0000000..dd4265b --- /dev/null +++ b/prepare/preprocess_hubert.py @@ -0,0 +1,58 @@ +import sys,os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import numpy as np +import argparse +import torch +import librosa + +from tqdm import tqdm +from hubert import hubert_model + + +def load_audio(file: str, sr: int = 16000): + x, sr = librosa.load(file, sr=sr) + return x + + +def load_model(path, device): + model = hubert_model.hubert_soft(path) + model.eval() + model.half() + model.to(device) + return model + + +def pred_vec(model, wavPath, vecPath, device): + feats = load_audio(wavPath) + feats = torch.from_numpy(feats).to(device) + feats = feats[None, None, :].half() + with torch.no_grad(): + vec = model.units(feats).squeeze().data.cpu().float().numpy() + # print(vec.shape) # [length, dim=256] hop=320 + np.save(vecPath, vec, allow_pickle=False) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True) + parser.add_argument("-v", "--vec", help="vec", dest="vec", required=True) + + args = parser.parse_args() + print(args.wav) + print(args.vec) + os.makedirs(args.vec, exist_ok=True) + + wavPath = args.wav + vecPath = args.vec + + device = "cuda" if torch.cuda.is_available() else "cpu" + hubert = load_model(os.path.join("hubert_pretrain", "hubert-soft-0d54a1f4.pt"), device) + + for spks in os.listdir(wavPath): + if os.path.isdir(f"./{wavPath}/{spks}"): + os.makedirs(f"./{vecPath}/{spks}", exist_ok=True) + + files = [f for f in os.listdir(f"./{wavPath}/{spks}") if f.endswith(".wav")] + for file in tqdm(files, desc=f'Processing vec {spks}'): + file = file[:-4] + pred_vec(hubert, f"{wavPath}/{spks}/{file}.wav", 
f"{vecPath}/{spks}/{file}.vec", device) diff --git a/prepare/preprocess_random.py b/prepare/preprocess_random.py new file mode 100644 index 0000000..f84977b --- /dev/null +++ b/prepare/preprocess_random.py @@ -0,0 +1,23 @@ +import random + + +if __name__ == "__main__": + all_items = [] + fo = open("./files/train_all.txt", "r+", encoding='utf-8') + while (True): + try: + item = fo.readline().strip() + except Exception as e: + print('nothing of except:', e) + break + if (item == None or item == ""): + break + all_items.append(item) + fo.close() + + random.shuffle(all_items) + + fw = open("./files/train_all.txt", "w", encoding="utf-8") + for strs in all_items: + print(strs, file=fw) + fw.close() diff --git a/prepare/preprocess_speaker.py b/prepare/preprocess_speaker.py new file mode 100644 index 0000000..797b60e --- /dev/null +++ b/prepare/preprocess_speaker.py @@ -0,0 +1,103 @@ +import sys,os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import torch +import numpy as np +import argparse + +from tqdm import tqdm +from functools import partial +from argparse import RawTextHelpFormatter +from multiprocessing.pool import ThreadPool + +from speaker.models.lstm import LSTMSpeakerEncoder +from speaker.config import SpeakerEncoderConfig +from speaker.utils.audio import AudioProcessor +from speaker.infer import read_json + + +def get_spk_wavs(dataset_path, output_path): + wav_files = [] + os.makedirs(f"./{output_path}", exist_ok=True) + for spks in os.listdir(dataset_path): + if os.path.isdir(f"./{dataset_path}/{spks}"): + os.makedirs(f"./{output_path}/{spks}", exist_ok=True) + for file in os.listdir(f"./{dataset_path}/{spks}"): + if file.endswith(".wav"): + wav_files.append(f"./{dataset_path}/{spks}/{file}") + elif spks.endswith(".wav"): + wav_files.append(f"./{dataset_path}/{spks}") + return wav_files + + +def process_wav(wav_file, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder): + waveform = speaker_encoder_ap.load_wav( + wav_file, sr=speaker_encoder_ap.sample_rate + ) + spec = speaker_encoder_ap.melspectrogram(waveform) + spec = torch.from_numpy(spec.T) + if args.use_cuda: + spec = spec.cuda() + spec = spec.unsqueeze(0) + embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy() + embed = embed.squeeze() + embed_path = wav_file.replace(dataset_path, output_path) + embed_path = embed_path.replace(".wav", ".spk") + np.save(embed_path, embed, allow_pickle=False) + + +def extract_speaker_embeddings(wav_files, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder, concurrency): + bound_process_wav = partial(process_wav, dataset_path=dataset_path, output_path=output_path, args=args, speaker_encoder_ap=speaker_encoder_ap, speaker_encoder=speaker_encoder) + + with ThreadPool(concurrency) as pool: + list(tqdm(pool.imap(bound_process_wav, wav_files), total=len(wav_files))) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="""Compute embedding vectors for each wav file in a dataset.""", + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("dataset_path", type=str, help="Path to dataset waves.") + parser.add_argument( + "output_path", type=str, help="path for output speaker/speaker_wavs.npy." 
+ ) + parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True) + parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1) + args = parser.parse_args() + dataset_path = args.dataset_path + output_path = args.output_path + thread_count = args.thread_count + # model + args.model_path = os.path.join("speaker_pretrain", "best_model.pth.tar") + args.config_path = os.path.join("speaker_pretrain", "config.json") + # config + config_dict = read_json(args.config_path) + + # model + config = SpeakerEncoderConfig(config_dict) + config.from_dict(config_dict) + + speaker_encoder = LSTMSpeakerEncoder( + config.model_params["input_dim"], + config.model_params["proj_dim"], + config.model_params["lstm_dim"], + config.model_params["num_lstm_layers"], + ) + + speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda) + + # preprocess + speaker_encoder_ap = AudioProcessor(**config.audio) + # normalize the input audio level and trim silences + speaker_encoder_ap.do_sound_norm = True + speaker_encoder_ap.do_trim_silence = True + + wav_files = get_spk_wavs(dataset_path, output_path) + + if thread_count == 0: + process_num = os.cpu_count() + else: + process_num = thread_count + + extract_speaker_embeddings(wav_files, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder, process_num) \ No newline at end of file diff --git a/prepare/preprocess_speaker_ave.py b/prepare/preprocess_speaker_ave.py new file mode 100644 index 0000000..9423f61 --- /dev/null +++ b/prepare/preprocess_speaker_ave.py @@ -0,0 +1,54 @@ +import os +import torch +import argparse +import numpy as np +from tqdm import tqdm + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("dataset_speaker", type=str) + parser.add_argument("dataset_singer", type=str) + + data_speaker = parser.parse_args().dataset_speaker + data_singer = parser.parse_args().dataset_singer + + os.makedirs(data_singer, exist_ok=True) + + for speaker in os.listdir(data_speaker): + subfile_num = 0 + speaker_ave = 0 + + for file in tqdm(os.listdir(os.path.join(data_speaker, speaker)), desc=f"average {speaker}"): + if not file.endswith(".npy"): + continue + source_embed = np.load(os.path.join(data_speaker, speaker, file)) + source_embed = source_embed.astype(np.float32) + speaker_ave = speaker_ave + source_embed + subfile_num = subfile_num + 1 + if subfile_num == 0: + continue + speaker_ave = speaker_ave / subfile_num + + np.save(os.path.join(data_singer, f"{speaker}.spk.npy"), + speaker_ave, allow_pickle=False) + + # rewrite timbre code by average, if similarity is larger than cmp_val + rewrite_timbre_code = False + if not rewrite_timbre_code: + continue + cmp_src = torch.FloatTensor(speaker_ave) + cmp_num = 0 + cmp_val = 0.85 + for file in tqdm(os.listdir(os.path.join(data_speaker, speaker)), desc=f"rewrite {speaker}"): + if not file.endswith(".npy"): + continue + cmp_tmp = np.load(os.path.join(data_speaker, speaker, file)) + cmp_tmp = cmp_tmp.astype(np.float32) + cmp_tmp = torch.FloatTensor(cmp_tmp) + cmp_cos = torch.cosine_similarity(cmp_src, cmp_tmp, dim=0) + if (cmp_cos > cmp_val): + cmp_num += 1 + np.save(os.path.join(data_speaker, speaker, file), + speaker_ave, allow_pickle=False) + print(f"rewrite timbre for {speaker} with :", cmp_num) diff --git a/prepare/preprocess_spec.py b/prepare/preprocess_spec.py new file mode 100644 index 0000000..8304f15 --- /dev/null +++ 
b/prepare/preprocess_spec.py @@ -0,0 +1,52 @@ +import sys,os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import torch +import argparse +from concurrent.futures import ThreadPoolExecutor +from spec.inference import mel_spectrogram_file +from tqdm import tqdm +from omegaconf import OmegaConf + + +def compute_spec(hps, filename, specname): + spec = mel_spectrogram_file(filename, hps) + spec = torch.squeeze(spec, 0) + # print(spec.shape) + torch.save(spec, specname) + + +def process_file(file): + if file.endswith(".wav"): + file = file[:-4] + compute_spec(hps, f"{wavPath}/{spks}/{file}.wav", f"{spePath}/{spks}/{file}.mel.pt") + + +def process_files_with_thread_pool(wavPath, spks, thread_num): + files = os.listdir(f"./{wavPath}/{spks}") + with ThreadPoolExecutor(max_workers=thread_num) as executor: + list(tqdm(executor.map(process_file, files), total=len(files), desc=f'Processing spec {spks}')) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True) + parser.add_argument("-s", "--spe", help="spe", dest="spe", required=True) + parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1) + + args = parser.parse_args() + print(args.wav) + print(args.spe) + + os.makedirs(args.spe, exist_ok=True) + wavPath = args.wav + spePath = args.spe + hps = OmegaConf.load("./configs/base.yaml") + + for spks in os.listdir(wavPath): + if os.path.isdir(f"./{wavPath}/{spks}"): + os.makedirs(f"./{spePath}/{spks}", exist_ok=True) + if args.thread_count == 0: + process_num = os.cpu_count() // 2 + 1 + else: + process_num = args.thread_count + process_files_with_thread_pool(wavPath, spks, process_num) diff --git a/prepare/preprocess_train.py b/prepare/preprocess_train.py new file mode 100644 index 0000000..6410a0b --- /dev/null +++ b/prepare/preprocess_train.py @@ -0,0 +1,56 @@ +import os +import random + + +def print_error(info): + print(f"\033[31m File does not exist: {info}\033[0m") + + +if __name__ == "__main__": + os.makedirs("./files/", exist_ok=True) + + rootPath = "./data_gvc/waves-32k/" + all_items = [] + for spks in os.listdir(f"./{rootPath}"): + if not os.path.isdir(f"./{rootPath}/{spks}"): + continue + print(f"./{rootPath}/{spks}") + for file in os.listdir(f"./{rootPath}/{spks}"): + if file.endswith(".wav"): + file = file[:-4] + + path_mel = f"./data_gvc/mel/{spks}/{file}.mel.pt" + path_vec = f"./data_gvc/hubert/{spks}/{file}.vec.npy" + path_pit = f"./data_gvc/pitch/{spks}/{file}.pit.npy" + path_spk = f"./data_gvc/speaker/{spks}/{file}.spk.npy" + + has_error = 0 + if not os.path.isfile(path_mel): + print_error(path_mel) + has_error = 1 + if not os.path.isfile(path_vec): + print_error(path_vec) + has_error = 1 + if not os.path.isfile(path_pit): + print_error(path_pit) + has_error = 1 + if not os.path.isfile(path_spk): + print_error(path_spk) + has_error = 1 + if has_error == 0: + all_items.append( + f"{path_mel}|{path_vec}|{path_pit}|{path_spk}") + + random.shuffle(all_items) + valids = all_items[:10] + valids.sort() + trains = all_items[10:] + # trains.sort() + fw = open("./files/valid.txt", "w", encoding="utf-8") + for strs in valids: + print(strs, file=fw) + fw.close() + fw = open("./files/train.txt", "w", encoding="utf-8") + for strs in trains: + print(strs, file=fw) + fw.close() diff --git a/prepare/preprocess_zzz.py b/prepare/preprocess_zzz.py new file mode 100644 index 0000000..03659e2 --- 
/dev/null +++ b/prepare/preprocess_zzz.py @@ -0,0 +1,30 @@ +import sys,os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from tqdm import tqdm +from torch.utils.data import DataLoader +from grad_extend.data import TextMelSpeakerDataset, TextMelSpeakerBatchCollate + + +filelist_path = "files/valid.txt" + +dataset = TextMelSpeakerDataset(filelist_path) +collate = TextMelSpeakerBatchCollate() +loader = DataLoader(dataset=dataset, + batch_size=2, + collate_fn=collate, + drop_last=True, + num_workers=1, + shuffle=True) + +for batch in tqdm(loader): + lengths = batch['lengths'].cuda() + vec = batch['vec'].cuda() + pit = batch['pit'].cuda() + spk = batch['spk'].cuda() + mel = batch['mel'].cuda() + + print('len', lengths.shape) + print('vec', vec.shape) + print('pit', pit.shape) + print('spk', spk.shape) + print('mel', mel.shape) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..908feb5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +librosa +soundfile +matplotlib +tensorboard +transformers +tqdm +einops +fsspec +omegaconf +praat-parselmouth \ No newline at end of file diff --git a/speaker/__init__.py b/speaker/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/speaker/config.py b/speaker/config.py new file mode 100644 index 0000000..7172ee2 --- /dev/null +++ b/speaker/config.py @@ -0,0 +1,64 @@ +from dataclasses import asdict, dataclass, field +from typing import Dict, List + +from .utils.coqpit import MISSING +from .utils.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig + + +@dataclass +class SpeakerEncoderConfig(BaseTrainingConfig): + """Defines parameters for Speaker Encoder model.""" + + model: str = "speaker_encoder" + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) + # model params + model_params: Dict = field( + default_factory=lambda: { + "model_name": "lstm", + "input_dim": 80, + "proj_dim": 256, + "lstm_dim": 768, + "num_lstm_layers": 3, + "use_lstm_with_projection": True, + } + ) + + audio_augmentation: Dict = field(default_factory=lambda: {}) + + storage: Dict = field( + default_factory=lambda: { + "sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage + "storage_size": 15, # the size of the in-memory storage with respect to a single batch + } + ) + + # training params + max_train_step: int = 1000000 # end training when number of training steps reaches this value. + loss: str = "angleproto" + grad_clip: float = 3.0 + lr: float = 0.0001 + lr_decay: bool = False + warmup_steps: int = 4000 + wd: float = 1e-6 + + # logging params + tb_model_param_stats: bool = False + steps_plot_stats: int = 10 + checkpoint: bool = True + save_step: int = 1000 + print_step: int = 20 + + # data loader + num_speakers_in_batch: int = MISSING + num_utters_per_speaker: int = MISSING + num_loader_workers: int = MISSING + skip_speakers: bool = False + voice_len: float = 1.6 + + def check_values(self): + super().check_values() + c = asdict(self) + assert ( + c["model_params"]["input_dim"] == self.audio.num_mels + ), " [!] model input dimendion must be equal to melspectrogram dimension." 
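The `model_params` defaults above (80 mel bins in, a 256-dimensional speaker vector out, three 768-unit LSTM layers) are what the preprocessing scripts rely on. For orientation, here is a minimal usage sketch, not part of the diff, that drives the LSTM encoder directly; it assumes the `speaker` package is importable from the repo root and uses a randomly initialized encoder plus a random tensor in place of a real mel spectrogram, so only the shapes are meaningful.

```python
import torch

from speaker.models.lstm import LSTMSpeakerEncoder

# mirror the default model_params of SpeakerEncoderConfig
encoder = LSTMSpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)
encoder.eval()

# stand-in for a real mel spectrogram: [batch=1, frames=T, mel_bins=80]
mel = torch.randn(1, 400, 80)

# compute_embedding slices num_eval windows of num_frames frames each and averages their embeddings
embed = encoder.compute_embedding(mel)
print(embed.shape)  # torch.Size([1, 256])
```

In `prepare/preprocess_speaker.py` the same call runs on a real mel spectrogram from `AudioProcessor`, after `load_checkpoint` has loaded `speaker_pretrain/best_model.pth.tar`, and the squeezed 256-d vector is what lands in `data_gvc/speaker/<speaker>/<file>.spk.npy`.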
diff --git a/speaker/infer.py b/speaker/infer.py new file mode 100644 index 0000000..b69b2ee --- /dev/null +++ b/speaker/infer.py @@ -0,0 +1,108 @@ +import re +import json +import fsspec +import torch +import numpy as np +import argparse + +from argparse import RawTextHelpFormatter +from .models.lstm import LSTMSpeakerEncoder +from .config import SpeakerEncoderConfig +from .utils.audio import AudioProcessor + + +def read_json(json_path): + config_dict = {} + try: + with fsspec.open(json_path, "r", encoding="utf-8") as f: + data = json.load(f) + except json.decoder.JSONDecodeError: + # backwards compat. + data = read_json_with_comments(json_path) + config_dict.update(data) + return config_dict + + +def read_json_with_comments(json_path): + """for backward compat.""" + # fallback to json + with fsspec.open(json_path, "r", encoding="utf-8") as f: + input_str = f.read() + # handle comments + input_str = re.sub(r"\\\n", "", input_str) + input_str = re.sub(r"//.*\n", "\n", input_str) + data = json.loads(input_str) + return data + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="""Compute embedding vectors for each wav file in a dataset.""", + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("model_path", type=str, help="Path to model checkpoint file.") + parser.add_argument( + "config_path", + type=str, + help="Path to model config file.", + ) + + parser.add_argument("-s", "--source", help="input wave", dest="source") + parser.add_argument( + "-t", "--target", help="output 256d speaker embeddimg", dest="target" + ) + + parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True) + parser.add_argument("--eval", type=bool, help="compute eval.", default=True) + + args = parser.parse_args() + source_file = args.source + target_file = args.target + + # config + config_dict = read_json(args.config_path) + # print(config_dict) + + # model + config = SpeakerEncoderConfig(config_dict) + config.from_dict(config_dict) + + speaker_encoder = LSTMSpeakerEncoder( + config.model_params["input_dim"], + config.model_params["proj_dim"], + config.model_params["lstm_dim"], + config.model_params["num_lstm_layers"], + ) + + speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda) + + # preprocess + speaker_encoder_ap = AudioProcessor(**config.audio) + # normalize the input audio level and trim silences + speaker_encoder_ap.do_sound_norm = True + speaker_encoder_ap.do_trim_silence = True + + # compute speaker embeddings + + # extract the embedding + waveform = speaker_encoder_ap.load_wav( + source_file, sr=speaker_encoder_ap.sample_rate + ) + spec = speaker_encoder_ap.melspectrogram(waveform) + spec = torch.from_numpy(spec.T) + if args.use_cuda: + spec = spec.cuda() + spec = spec.unsqueeze(0) + embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy() + embed = embed.squeeze() + # print(embed) + # print(embed.size) + np.save(target_file, embed, allow_pickle=False) + + + if hasattr(speaker_encoder, 'module'): + state_dict = speaker_encoder.module.state_dict() + else: + state_dict = speaker_encoder.state_dict() + torch.save({'model': state_dict}, "model_small.pth") diff --git a/speaker/models/__init__.py b/speaker/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/speaker/models/lstm.py b/speaker/models/lstm.py new file mode 100644 index 0000000..45e8cce --- /dev/null +++ b/speaker/models/lstm.py @@ -0,0 +1,131 @@ +import numpy as np +import torch +from torch import nn + +from 
..utils.io import load_fsspec + + +class LSTMWithProjection(nn.Module): + def __init__(self, input_size, hidden_size, proj_size): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.proj_size = proj_size + self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True) + self.linear = nn.Linear(hidden_size, proj_size, bias=False) + + def forward(self, x): + self.lstm.flatten_parameters() + o, (_, _) = self.lstm(x) + return self.linear(o) + + +class LSTMWithoutProjection(nn.Module): + def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers): + super().__init__() + self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True) + self.linear = nn.Linear(lstm_dim, proj_dim, bias=True) + self.relu = nn.ReLU() + + def forward(self, x): + _, (hidden, _) = self.lstm(x) + return self.relu(self.linear(hidden[-1])) + + +class LSTMSpeakerEncoder(nn.Module): + def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True): + super().__init__() + self.use_lstm_with_projection = use_lstm_with_projection + layers = [] + # choise LSTM layer + if use_lstm_with_projection: + layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim)) + for _ in range(num_lstm_layers - 1): + layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim)) + self.layers = nn.Sequential(*layers) + else: + self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers) + + self._init_layers() + + def _init_layers(self): + for name, param in self.layers.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0.0) + elif "weight" in name: + nn.init.xavier_normal_(param) + + def forward(self, x): + # TODO: implement state passing for lstms + d = self.layers(x) + if self.use_lstm_with_projection: + d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) + else: + d = torch.nn.functional.normalize(d, p=2, dim=1) + return d + + @torch.no_grad() + def inference(self, x): + d = self.layers.forward(x) + if self.use_lstm_with_projection: + d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) + else: + d = torch.nn.functional.normalize(d, p=2, dim=1) + return d + + def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True): + """ + Generate embeddings for a batch of utterances + x: 1xTxD + """ + max_len = x.shape[1] + + if max_len < num_frames: + num_frames = max_len + + offsets = np.linspace(0, max_len - num_frames, num=num_eval) + + frames_batch = [] + for offset in offsets: + offset = int(offset) + end_offset = int(offset + num_frames) + frames = x[:, offset:end_offset] + frames_batch.append(frames) + + frames_batch = torch.cat(frames_batch, dim=0) + embeddings = self.inference(frames_batch) + + if return_mean: + embeddings = torch.mean(embeddings, dim=0, keepdim=True) + + return embeddings + + def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5): + """ + Generate embeddings for a batch of utterances + x: BxTxD + """ + num_overlap = num_frames * overlap + max_len = x.shape[1] + embed = None + num_iters = seq_lens / (num_frames - num_overlap) + cur_iter = 0 + for offset in range(0, max_len, num_frames - num_overlap): + cur_iter += 1 + end_offset = min(x.shape[1], offset + num_frames) + frames = x[:, offset:end_offset] + if embed is None: + embed = self.inference(frames) + else: + embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :]) + return embed / num_iters + + # pylint: 
disable=unused-argument, redefined-builtin + def load_checkpoint(self, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if use_cuda: + self.cuda() + if eval: + self.eval() + assert not self.training diff --git a/speaker/models/resnet.py b/speaker/models/resnet.py new file mode 100644 index 0000000..fcc850d --- /dev/null +++ b/speaker/models/resnet.py @@ -0,0 +1,212 @@ +import numpy as np +import torch +from torch import nn + +from TTS.utils.io import load_fsspec + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=8): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +class SEBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): + super(SEBasicBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.se = SELayer(planes, reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.relu(out) + out = self.bn1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + return out + + +class ResNetSpeakerEncoder(nn.Module): + """Implementation of the model H/ASP without batch normalization in speaker embedding. 
This model was proposed in: https://arxiv.org/abs/2009.14153 + Adapted from: https://github.com/clovaai/voxceleb_trainer + """ + + # pylint: disable=W0102 + def __init__( + self, + input_dim=64, + proj_dim=512, + layers=[3, 4, 6, 3], + num_filters=[32, 64, 128, 256], + encoder_type="ASP", + log_input=False, + ): + super(ResNetSpeakerEncoder, self).__init__() + + self.encoder_type = encoder_type + self.input_dim = input_dim + self.log_input = log_input + self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) + self.relu = nn.ReLU(inplace=True) + self.bn1 = nn.BatchNorm2d(num_filters[0]) + + self.inplanes = num_filters[0] + self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0]) + self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2)) + self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2)) + self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2)) + + self.instancenorm = nn.InstanceNorm1d(input_dim) + + outmap_size = int(self.input_dim / 8) + + self.attention = nn.Sequential( + nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), + nn.Softmax(dim=2), + ) + + if self.encoder_type == "SAP": + out_dim = num_filters[3] * outmap_size + elif self.encoder_type == "ASP": + out_dim = num_filters[3] * outmap_size * 2 + else: + raise ValueError("Undefined encoder") + + self.fc = nn.Linear(out_dim, proj_dim) + + self._init_layers() + + def _init_layers(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def create_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + # pylint: disable=R0201 + def new_parameter(self, *size): + out = nn.Parameter(torch.FloatTensor(*size)) + nn.init.xavier_normal_(out) + return out + + def forward(self, x, l2_norm=False): + x = x.transpose(1, 2) + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + if self.log_input: + x = (x + 1e-6).log() + x = self.instancenorm(x).unsqueeze(1) + + x = self.conv1(x) + x = self.relu(x) + x = self.bn1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = x.reshape(x.size()[0], -1, x.size()[-1]) + + w = self.attention(x) + + if self.encoder_type == "SAP": + x = torch.sum(x * w, dim=2) + elif self.encoder_type == "ASP": + mu = torch.sum(x * w, dim=2) + sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5)) + x = torch.cat((mu, sg), 1) + + x = x.view(x.size()[0], -1) + x = self.fc(x) + + if l2_norm: + x = torch.nn.functional.normalize(x, p=2, dim=1) + return x + + @torch.no_grad() + def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True): + """ + Generate embeddings for a batch of utterances + x: 1xTxD + """ + max_len = x.shape[1] + + 
if max_len < num_frames: + num_frames = max_len + + offsets = np.linspace(0, max_len - num_frames, num=num_eval) + + frames_batch = [] + for offset in offsets: + offset = int(offset) + end_offset = int(offset + num_frames) + frames = x[:, offset:end_offset] + frames_batch.append(frames) + + frames_batch = torch.cat(frames_batch, dim=0) + embeddings = self.forward(frames_batch, l2_norm=True) + + if return_mean: + embeddings = torch.mean(embeddings, dim=0, keepdim=True) + + return embeddings + + def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if use_cuda: + self.cuda() + if eval: + self.eval() + assert not self.training diff --git a/speaker/utils/__init__.py b/speaker/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/speaker/utils/audio.py b/speaker/utils/audio.py new file mode 100644 index 0000000..e2c9627 --- /dev/null +++ b/speaker/utils/audio.py @@ -0,0 +1,822 @@ +from typing import Dict, Tuple + +import librosa +import numpy as np +import pyworld as pw +import scipy.io.wavfile +import scipy.signal +import soundfile as sf +import torch +from torch import nn + +class StandardScaler: + """StandardScaler for mean-scale normalization with the given mean and scale values.""" + + def __init__(self, mean: np.ndarray = None, scale: np.ndarray = None) -> None: + self.mean_ = mean + self.scale_ = scale + + def set_stats(self, mean, scale): + self.mean_ = mean + self.scale_ = scale + + def reset_stats(self): + delattr(self, "mean_") + delattr(self, "scale_") + + def transform(self, X): + X = np.asarray(X) + X -= self.mean_ + X /= self.scale_ + return X + + def inverse_transform(self, X): + X = np.asarray(X) + X *= self.scale_ + X += self.mean_ + return X + +class TorchSTFT(nn.Module): # pylint: disable=abstract-method + """Some of the audio processing funtions using Torch for faster batch processing. + + TODO: Merge this with audio.py + """ + + def __init__( + self, + n_fft, + hop_length, + win_length, + pad_wav=False, + window="hann_window", + sample_rate=None, + mel_fmin=0, + mel_fmax=None, + n_mels=80, + use_mel=False, + do_amp_to_db=False, + spec_gain=1.0, + ): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.pad_wav = pad_wav + self.sample_rate = sample_rate + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.n_mels = n_mels + self.use_mel = use_mel + self.do_amp_to_db = do_amp_to_db + self.spec_gain = spec_gain + self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) + self.mel_basis = None + if use_mel: + self._build_mel_basis() + + def __call__(self, x): + """Compute spectrogram frames by torch based stft. + + Args: + x (Tensor): input waveform + + Returns: + Tensor: spectrogram frames. 
+ + Shapes: + x: [B x T] or [:math:`[B, 1, T]`] + """ + if x.ndim == 2: + x = x.unsqueeze(1) + if self.pad_wav: + padding = int((self.n_fft - self.hop_length) / 2) + x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") + # B x D x T x 2 + o = torch.stft( + x.squeeze(1), + self.n_fft, + self.hop_length, + self.win_length, + self.window, + center=True, + pad_mode="reflect", # compatible with audio.py + normalized=False, + onesided=True, + return_complex=False, + ) + M = o[:, :, :, 0] + P = o[:, :, :, 1] + S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8)) + if self.use_mel: + S = torch.matmul(self.mel_basis.to(x), S) + if self.do_amp_to_db: + S = self._amp_to_db(S, spec_gain=self.spec_gain) + return S + + def _build_mel_basis(self): + mel_basis = librosa.filters.mel( + sr=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax + ) + self.mel_basis = torch.from_numpy(mel_basis).float() + + @staticmethod + def _amp_to_db(x, spec_gain=1.0): + return torch.log(torch.clamp(x, min=1e-5) * spec_gain) + + @staticmethod + def _db_to_amp(x, spec_gain=1.0): + return torch.exp(x) / spec_gain + + +# pylint: disable=too-many-public-methods +class AudioProcessor(object): + """Audio Processor for TTS used by all the data pipelines. + + Note: + All the class arguments are set to default values to enable a flexible initialization + of the class with the model config. They are not meaningful for all the arguments. + + Args: + sample_rate (int, optional): + target audio sampling rate. Defaults to None. + + resample (bool, optional): + enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False. + + num_mels (int, optional): + number of melspectrogram dimensions. Defaults to None. + + log_func (int, optional): + log exponent used for converting spectrogram aplitude to DB. + + min_level_db (int, optional): + minimum db threshold for the computed melspectrograms. Defaults to None. + + frame_shift_ms (int, optional): + milliseconds of frames between STFT columns. Defaults to None. + + frame_length_ms (int, optional): + milliseconds of STFT window length. Defaults to None. + + hop_length (int, optional): + number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None. + + win_length (int, optional): + STFT window length. Used if ```frame_length_ms``` is None. Defaults to None. + + ref_level_db (int, optional): + reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None. + + fft_size (int, optional): + FFT window size for STFT. Defaults to 1024. + + power (int, optional): + Exponent value applied to the spectrogram before GriffinLim. Defaults to None. + + preemphasis (float, optional): + Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0. + + signal_norm (bool, optional): + enable/disable signal normalization. Defaults to None. + + symmetric_norm (bool, optional): + enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None. + + max_norm (float, optional): + ```k``` defining the normalization range. Defaults to None. + + mel_fmin (int, optional): + minimum filter frequency for computing melspectrograms. Defaults to None. + + mel_fmax (int, optional): + maximum filter frequency for computing melspectrograms.. Defaults to None. + + spec_gain (int, optional): + gain applied when converting amplitude to DB. 
Defaults to 20. + + stft_pad_mode (str, optional): + Padding mode for STFT. Defaults to 'reflect'. + + clip_norm (bool, optional): + enable/disable clipping the our of range values in the normalized audio signal. Defaults to True. + + griffin_lim_iters (int, optional): + Number of GriffinLim iterations. Defaults to None. + + do_trim_silence (bool, optional): + enable/disable silence trimming when loading the audio signal. Defaults to False. + + trim_db (int, optional): + DB threshold used for silence trimming. Defaults to 60. + + do_sound_norm (bool, optional): + enable/disable signal normalization. Defaults to False. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True. + + do_amp_to_db_mel (bool, optional): + enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + + stats_path (str, optional): + Path to the computed stats file. Defaults to None. + + verbose (bool, optional): + enable/disable logging. Defaults to True. + + """ + + def __init__( + self, + sample_rate=None, + resample=False, + num_mels=None, + log_func="np.log10", + min_level_db=None, + frame_shift_ms=None, + frame_length_ms=None, + hop_length=None, + win_length=None, + ref_level_db=None, + fft_size=1024, + power=None, + preemphasis=0.0, + signal_norm=None, + symmetric_norm=None, + max_norm=None, + mel_fmin=None, + mel_fmax=None, + spec_gain=20, + stft_pad_mode="reflect", + clip_norm=True, + griffin_lim_iters=None, + do_trim_silence=False, + trim_db=60, + do_sound_norm=False, + do_amp_to_db_linear=True, + do_amp_to_db_mel=True, + stats_path=None, + verbose=True, + **_, + ): + + # setup class attributed + self.sample_rate = sample_rate + self.resample = resample + self.num_mels = num_mels + self.log_func = log_func + self.min_level_db = min_level_db or 0 + self.frame_shift_ms = frame_shift_ms + self.frame_length_ms = frame_length_ms + self.ref_level_db = ref_level_db + self.fft_size = fft_size + self.power = power + self.preemphasis = preemphasis + self.griffin_lim_iters = griffin_lim_iters + self.signal_norm = signal_norm + self.symmetric_norm = symmetric_norm + self.mel_fmin = mel_fmin or 0 + self.mel_fmax = mel_fmax + self.spec_gain = float(spec_gain) + self.stft_pad_mode = stft_pad_mode + self.max_norm = 1.0 if max_norm is None else float(max_norm) + self.clip_norm = clip_norm + self.do_trim_silence = do_trim_silence + self.trim_db = trim_db + self.do_sound_norm = do_sound_norm + self.do_amp_to_db_linear = do_amp_to_db_linear + self.do_amp_to_db_mel = do_amp_to_db_mel + self.stats_path = stats_path + # setup exp_func for db to amp conversion + if log_func == "np.log": + self.base = np.e + elif log_func == "np.log10": + self.base = 10 + else: + raise ValueError(" [!] unknown `log_func` value.") + # setup stft parameters + if hop_length is None: + # compute stft parameters from given time values + self.hop_length, self.win_length = self._stft_parameters() + else: + # use stft parameters from config file + self.hop_length = hop_length + self.win_length = win_length + assert min_level_db != 0.0, " [!] min_level_db is 0" + assert self.win_length <= self.fft_size, " [!] 
win_length cannot be larger than fft_size" + members = vars(self) + if verbose: + print(" > Setting up Audio Processor...") + for key, value in members.items(): + print(" | > {}:{}".format(key, value)) + # create spectrogram utils + self.mel_basis = self._build_mel_basis() + self.inv_mel_basis = np.linalg.pinv(self._build_mel_basis()) + # setup scaler + if stats_path and signal_norm: + mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path) + self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std) + self.signal_norm = True + self.max_norm = None + self.clip_norm = None + self.symmetric_norm = None + + ### setting up the parameters ### + def _build_mel_basis( + self, + ) -> np.ndarray: + """Build melspectrogram basis. + + Returns: + np.ndarray: melspectrogram basis. + """ + if self.mel_fmax is not None: + assert self.mel_fmax <= self.sample_rate // 2 + return librosa.filters.mel( + sr=self.sample_rate, n_fft=self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax + ) + + def _stft_parameters( + self, + ) -> Tuple[int, int]: + """Compute the real STFT parameters from the time values. + + Returns: + Tuple[int, int]: hop length and window length for STFT. + """ + factor = self.frame_length_ms / self.frame_shift_ms + assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms" + hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate) + win_length = int(hop_length * factor) + return hop_length, win_length + + ### normalization ### + def normalize(self, S: np.ndarray) -> np.ndarray: + """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]` + + Args: + S (np.ndarray): Spectrogram to normalize. + + Raises: + RuntimeError: Mean and variance is computed from incompatible parameters. + + Returns: + np.ndarray: Normalized spectrogram. + """ + # pylint: disable=no-else-return + S = S.copy() + if self.signal_norm: + # mean-var scaling + if hasattr(self, "mel_scaler"): + if S.shape[0] == self.num_mels: + return self.mel_scaler.transform(S.T).T + elif S.shape[0] == self.fft_size / 2: + return self.linear_scaler.transform(S.T).T + else: + raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") + # range normalization + S -= self.ref_level_db # discard certain range of DB assuming it is air noise + S_norm = (S - self.min_level_db) / (-self.min_level_db) + if self.symmetric_norm: + S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm + if self.clip_norm: + S_norm = np.clip( + S_norm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type + ) + return S_norm + else: + S_norm = self.max_norm * S_norm + if self.clip_norm: + S_norm = np.clip(S_norm, 0, self.max_norm) + return S_norm + else: + return S + + def denormalize(self, S: np.ndarray) -> np.ndarray: + """Denormalize spectrogram values. + + Args: + S (np.ndarray): Spectrogram to denormalize. + + Raises: + RuntimeError: Mean and variance are incompatible. + + Returns: + np.ndarray: Denormalized spectrogram. + """ + # pylint: disable=no-else-return + S_denorm = S.copy() + if self.signal_norm: + # mean-var scaling + if hasattr(self, "mel_scaler"): + if S_denorm.shape[0] == self.num_mels: + return self.mel_scaler.inverse_transform(S_denorm.T).T + elif S_denorm.shape[0] == self.fft_size / 2: + return self.linear_scaler.inverse_transform(S_denorm.T).T + else: + raise RuntimeError(" [!] 
Mean-Var stats does not match the given feature dimensions.") + if self.symmetric_norm: + if self.clip_norm: + S_denorm = np.clip( + S_denorm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type + ) + S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db + return S_denorm + self.ref_level_db + else: + if self.clip_norm: + S_denorm = np.clip(S_denorm, 0, self.max_norm) + S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db + return S_denorm + self.ref_level_db + else: + return S_denorm + + ### Mean-STD scaling ### + def load_stats(self, stats_path: str) -> Tuple[np.array, np.array, np.array, np.array, Dict]: + """Loading mean and variance statistics from a `npy` file. + + Args: + stats_path (str): Path to the `npy` file containing + + Returns: + Tuple[np.array, np.array, np.array, np.array, Dict]: loaded statistics and the config used to + compute them. + """ + stats = np.load(stats_path, allow_pickle=True).item() # pylint: disable=unexpected-keyword-arg + mel_mean = stats["mel_mean"] + mel_std = stats["mel_std"] + linear_mean = stats["linear_mean"] + linear_std = stats["linear_std"] + stats_config = stats["audio_config"] + # check all audio parameters used for computing stats + skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"] + for key in stats_config.keys(): + if key in skip_parameters: + continue + if key not in ["sample_rate", "trim_db"]: + assert ( + stats_config[key] == self.__dict__[key] + ), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}" + return mel_mean, mel_std, linear_mean, linear_std, stats_config + + # pylint: disable=attribute-defined-outside-init + def setup_scaler( + self, mel_mean: np.ndarray, mel_std: np.ndarray, linear_mean: np.ndarray, linear_std: np.ndarray + ) -> None: + """Initialize scaler objects used in mean-std normalization. + + Args: + mel_mean (np.ndarray): Mean for melspectrograms. + mel_std (np.ndarray): STD for melspectrograms. + linear_mean (np.ndarray): Mean for full scale spectrograms. + linear_std (np.ndarray): STD for full scale spectrograms. + """ + self.mel_scaler = StandardScaler() + self.mel_scaler.set_stats(mel_mean, mel_std) + self.linear_scaler = StandardScaler() + self.linear_scaler.set_stats(linear_mean, linear_std) + + ### DB and AMP conversion ### + # pylint: disable=no-self-use + def _amp_to_db(self, x: np.ndarray) -> np.ndarray: + """Convert amplitude values to decibels. + + Args: + x (np.ndarray): Amplitude spectrogram. + + Returns: + np.ndarray: Decibels spectrogram. + """ + return self.spec_gain * _log(np.maximum(1e-5, x), self.base) + + # pylint: disable=no-self-use + def _db_to_amp(self, x: np.ndarray) -> np.ndarray: + """Convert decibels spectrogram to amplitude spectrogram. + + Args: + x (np.ndarray): Decibels spectrogram. + + Returns: + np.ndarray: Amplitude spectrogram. + """ + return _exp(x / self.spec_gain, self.base) + + ### Preemphasis ### + def apply_preemphasis(self, x: np.ndarray) -> np.ndarray: + """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values. + + Args: + x (np.ndarray): Audio signal. + + Raises: + RuntimeError: Preemphasis coeff is set to 0. + + Returns: + np.ndarray: Decorrelated audio signal. + """ + if self.preemphasis == 0: + raise RuntimeError(" [!] 
Preemphasis is set 0.0.") + return scipy.signal.lfilter([1, -self.preemphasis], [1], x) + + def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray: + """Reverse pre-emphasis.""" + if self.preemphasis == 0: + raise RuntimeError(" [!] Preemphasis is set 0.0.") + return scipy.signal.lfilter([1], [1, -self.preemphasis], x) + + ### SPECTROGRAMs ### + def _linear_to_mel(self, spectrogram: np.ndarray) -> np.ndarray: + """Project a full scale spectrogram to a melspectrogram. + + Args: + spectrogram (np.ndarray): Full scale spectrogram. + + Returns: + np.ndarray: Melspectrogram + """ + return np.dot(self.mel_basis, spectrogram) + + def _mel_to_linear(self, mel_spec: np.ndarray) -> np.ndarray: + """Convert a melspectrogram to full scale spectrogram.""" + return np.maximum(1e-10, np.dot(self.inv_mel_basis, mel_spec)) + + def spectrogram(self, y: np.ndarray) -> np.ndarray: + """Compute a spectrogram from a waveform. + + Args: + y (np.ndarray): Waveform. + + Returns: + np.ndarray: Spectrogram. + """ + if self.preemphasis != 0: + D = self._stft(self.apply_preemphasis(y)) + else: + D = self._stft(y) + if self.do_amp_to_db_linear: + S = self._amp_to_db(np.abs(D)) + else: + S = np.abs(D) + return self.normalize(S).astype(np.float32) + + def melspectrogram(self, y: np.ndarray) -> np.ndarray: + """Compute a melspectrogram from a waveform.""" + if self.preemphasis != 0: + D = self._stft(self.apply_preemphasis(y)) + else: + D = self._stft(y) + if self.do_amp_to_db_mel: + S = self._amp_to_db(self._linear_to_mel(np.abs(D))) + else: + S = self._linear_to_mel(np.abs(D)) + return self.normalize(S).astype(np.float32) + + def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray: + """Convert a spectrogram to a waveform using Griffi-Lim vocoder.""" + S = self.denormalize(spectrogram) + S = self._db_to_amp(S) + # Reconstruct phase + if self.preemphasis != 0: + return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) + return self._griffin_lim(S ** self.power) + + def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray: + """Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" + D = self.denormalize(mel_spectrogram) + S = self._db_to_amp(D) + S = self._mel_to_linear(S) # Convert back to linear + if self.preemphasis != 0: + return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) + return self._griffin_lim(S ** self.power) + + def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray: + """Convert a full scale linear spectrogram output of a network to a melspectrogram. + + Args: + linear_spec (np.ndarray): Normalized full scale linear spectrogram. + + Returns: + np.ndarray: Normalized melspectrogram. + """ + S = self.denormalize(linear_spec) + S = self._db_to_amp(S) + S = self._linear_to_mel(np.abs(S)) + S = self._amp_to_db(S) + mel = self.normalize(S) + return mel + + ### STFT and ISTFT ### + def _stft(self, y: np.ndarray) -> np.ndarray: + """Librosa STFT wrapper. + + Args: + y (np.ndarray): Audio signal. + + Returns: + np.ndarray: Complex number array. 
+ """ + return librosa.stft( + y=y, + n_fft=self.fft_size, + hop_length=self.hop_length, + win_length=self.win_length, + pad_mode=self.stft_pad_mode, + window="hann", + center=True, + ) + + def _istft(self, y: np.ndarray) -> np.ndarray: + """Librosa iSTFT wrapper.""" + return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length) + + def _griffin_lim(self, S): + angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) + S_complex = np.abs(S).astype(np.complex) + y = self._istft(S_complex * angles) + if not np.isfinite(y).all(): + print(" [!] Waveform is not finite everywhere. Skipping the GL.") + return np.array([0.0]) + for _ in range(self.griffin_lim_iters): + angles = np.exp(1j * np.angle(self._stft(y))) + y = self._istft(S_complex * angles) + return y + + def compute_stft_paddings(self, x, pad_sides=1): + """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding + (first and final frames)""" + assert pad_sides in (1, 2) + pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0] + if pad_sides == 1: + return 0, pad + return pad // 2, pad // 2 + pad % 2 + + def compute_f0(self, x: np.ndarray) -> np.ndarray: + """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. + + Args: + x (np.ndarray): Waveform. + + Returns: + np.ndarray: Pitch. + + Examples: + >>> WAV_FILE = filename = librosa.util.example_audio_file() + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio import AudioProcessor + >>> conf = BaseAudioConfig(mel_fmax=8000) + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050] + >>> pitch = ap.compute_f0(wav) + """ + f0, t = pw.dio( + x.astype(np.double), + fs=self.sample_rate, + f0_ceil=self.mel_fmax, + frame_period=1000 * self.hop_length / self.sample_rate, + ) + f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate) + # pad = int((self.win_length / self.hop_length) / 2) + # f0 = [0.0] * pad + f0 + [0.0] * pad + # f0 = np.pad(f0, (pad, pad), mode="constant", constant_values=0) + # f0 = np.array(f0, dtype=np.float32) + + # f01, _, _ = librosa.pyin( + # x, + # fmin=65 if self.mel_fmin == 0 else self.mel_fmin, + # fmax=self.mel_fmax, + # frame_length=self.win_length, + # sr=self.sample_rate, + # fill_na=0.0, + # ) + + # spec = self.melspectrogram(x) + return f0 + + ### Audio Processing ### + def find_endpoint(self, wav: np.ndarray, threshold_db=-40, min_silence_sec=0.8) -> int: + """Find the last point without silence at the end of a audio signal. + + Args: + wav (np.ndarray): Audio signal. + threshold_db (int, optional): Silence threshold in decibels. Defaults to -40. + min_silence_sec (float, optional): Ignore silences that are shorter then this in secs. Defaults to 0.8. + + Returns: + int: Last point without silence. 
+ """ + window_length = int(self.sample_rate * min_silence_sec) + hop_length = int(window_length / 4) + threshold = self._db_to_amp(threshold_db) + for x in range(hop_length, len(wav) - window_length, hop_length): + if np.max(wav[x : x + window_length]) < threshold: + return x + hop_length + return len(wav) + + def trim_silence(self, wav): + """Trim silent parts with a threshold and 0.01 sec margin""" + margin = int(self.sample_rate * 0.01) + wav = wav[margin:-margin] + return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[ + 0 + ] + + @staticmethod + def sound_norm(x: np.ndarray) -> np.ndarray: + """Normalize the volume of an audio signal. + + Args: + x (np.ndarray): Raw waveform. + + Returns: + np.ndarray: Volume normalized waveform. + """ + return x / abs(x).max() * 0.95 + + ### save and load ### + def load_wav(self, filename: str, sr: int = None) -> np.ndarray: + """Read a wav file using Librosa and optionally resample, silence trim, volume normalize. + + Args: + filename (str): Path to the wav file. + sr (int, optional): Sampling rate for resampling. Defaults to None. + + Returns: + np.ndarray: Loaded waveform. + """ + if self.resample: + x, sr = librosa.load(filename, sr=self.sample_rate) + elif sr is None: + x, sr = sf.read(filename) + assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr) + else: + x, sr = librosa.load(filename, sr=sr) + if self.do_trim_silence: + try: + x = self.trim_silence(x) + except ValueError: + print(f" [!] File cannot be trimmed for silence - {filename}") + if self.do_sound_norm: + x = self.sound_norm(x) + return x + + def save_wav(self, wav: np.ndarray, path: str, sr: int = None) -> None: + """Save a waveform to a file using Scipy. + + Args: + wav (np.ndarray): Waveform to save. + path (str): Path to a output file. + sr (int, optional): Sampling rate used for saving to the file. Defaults to None. + """ + wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) + scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm.astype(np.int16)) + + @staticmethod + def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray: + mu = 2 ** qc - 1 + # wav_abs = np.minimum(np.abs(wav), 1.0) + signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu) + # Quantize signal to the specified number of levels. + signal = (signal + 1) / 2 * mu + 0.5 + return np.floor( + signal, + ) + + @staticmethod + def mulaw_decode(wav, qc): + """Recovers waveform from quantized values.""" + mu = 2 ** qc - 1 + x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) + return x + + @staticmethod + def encode_16bits(x): + return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16) + + @staticmethod + def quantize(x: np.ndarray, bits: int) -> np.ndarray: + """Quantize a waveform to a given number of bits. + + Args: + x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`. + bits (int): Number of quantization bits. + + Returns: + np.ndarray: Quantized waveform. 
+ """ + return (x + 1.0) * (2 ** bits - 1) / 2 + + @staticmethod + def dequantize(x, bits): + """Dequantize a waveform from the given number of bits.""" + return 2 * x / (2 ** bits - 1) - 1 + + +def _log(x, base): + if base == 10: + return np.log10(x) + return np.log(x) + + +def _exp(x, base): + if base == 10: + return np.power(10, x) + return np.exp(x) diff --git a/speaker/utils/coqpit.py b/speaker/utils/coqpit.py new file mode 100644 index 0000000..e214c8b --- /dev/null +++ b/speaker/utils/coqpit.py @@ -0,0 +1,954 @@ +import argparse +import functools +import json +import operator +import os +from collections.abc import MutableMapping +from dataclasses import MISSING as _MISSING +from dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace +from pathlib import Path +from pprint import pprint +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, get_type_hints + +T = TypeVar("T") +MISSING: Any = "???" + + +class _NoDefault(Generic[T]): + pass + + +NoDefaultVar = Union[_NoDefault[T], T] +no_default: NoDefaultVar = _NoDefault() + + +def is_primitive_type(arg_type: Any) -> bool: + """Check if the input type is one of `int, float, str, bool`. + + Args: + arg_type (typing.Any): input type to check. + + Returns: + bool: True if input type is one of `int, float, str, bool`. + """ + try: + return isinstance(arg_type(), (int, float, str, bool)) + except (AttributeError, TypeError): + return False + + +def is_list(arg_type: Any) -> bool: + """Check if the input type is `list` + + Args: + arg_type (typing.Any): input type. + + Returns: + bool: True if input type is `list` + """ + try: + return arg_type is list or arg_type is List or arg_type.__origin__ is list or arg_type.__origin__ is List + except AttributeError: + return False + + +def is_dict(arg_type: Any) -> bool: + """Check if the input type is `dict` + + Args: + arg_type (typing.Any): input type. + + Returns: + bool: True if input type is `dict` + """ + try: + return arg_type is dict or arg_type is Dict or arg_type.__origin__ is dict + except AttributeError: + return False + + +def is_union(arg_type: Any) -> bool: + """Check if the input type is `Union`. + + Args: + arg_type (typing.Any): input type. + + Returns: + bool: True if input type is `Union` + """ + try: + return safe_issubclass(arg_type.__origin__, Union) + except AttributeError: + return False + + +def safe_issubclass(cls, classinfo) -> bool: + """Check if the input type is a subclass of the given class. + + Args: + cls (type): input type. + classinfo (type): parent class. + + Returns: + bool: True if the input type is a subclass of the given class + """ + try: + r = issubclass(cls, classinfo) + except Exception: # pylint: disable=broad-except + return cls is classinfo + else: + return r + + +def _coqpit_json_default(obj: Any) -> Any: + if isinstance(obj, Path): + return str(obj) + raise TypeError(f"Can't encode object of type {type(obj).__name__}") + + +def _default_value(x: Field): + """Return the default value of the input Field. + + Args: + x (Field): input Field. + + Returns: + object: default value of the input Field. + """ + if x.default not in (MISSING, _MISSING): + return x.default + if x.default_factory not in (MISSING, _MISSING): + return x.default_factory() + return x.default + + +def _is_optional_field(field) -> bool: + """Check if the input field is optional. + + Args: + field (Field): input Field to check. + + Returns: + bool: True if the input field is optional. 
+ """ + # return isinstance(field.type, _GenericAlias) and type(None) in getattr(field.type, "__args__") + return type(None) in getattr(field.type, "__args__") + + +def my_get_type_hints( + cls, +): + """Custom `get_type_hints` dealing with https://github.com/python/typing/issues/737 + + Returns: + [dataclass]: dataclass to get the type hints of its fields. + """ + r_dict = {} + for base in cls.__class__.__bases__: + if base == object: + break + r_dict.update(my_get_type_hints(base)) + r_dict.update(get_type_hints(cls)) + return r_dict + + +def _serialize(x): + """Pick the right serialization for the datatype of the given input. + + Args: + x (object): input object. + + Returns: + object: serialized object. + """ + if isinstance(x, Path): + return str(x) + if isinstance(x, dict): + return {k: _serialize(v) for k, v in x.items()} + if isinstance(x, list): + return [_serialize(xi) for xi in x] + if isinstance(x, Serializable) or issubclass(type(x), Serializable): + return x.serialize() + if isinstance(x, type) and issubclass(x, Serializable): + return x.serialize(x) + return x + + +def _deserialize_dict(x: Dict) -> Dict: + """Deserialize dict. + + Args: + x (Dict): value to deserialized. + + Returns: + Dict: deserialized dictionary. + """ + out_dict = {} + for k, v in x.items(): + if v is None: # if {'key':None} + out_dict[k] = None + else: + out_dict[k] = _deserialize(v, type(v)) + return out_dict + + +def _deserialize_list(x: List, field_type: Type) -> List: + """Deserialize values for List typed fields. + + Args: + x (List): value to be deserialized + field_type (Type): field type. + + Raises: + ValueError: Coqpit does not support multi type-hinted lists. + + Returns: + [List]: deserialized list. + """ + field_args = None + if hasattr(field_type, "__args__") and field_type.__args__: + field_args = field_type.__args__ + elif hasattr(field_type, "__parameters__") and field_type.__parameters__: + # bandaid for python 3.6 + field_args = field_type.__parameters__ + if field_args: + if len(field_args) > 1: + raise ValueError(" [!] Coqpit does not support multi-type hinted 'List'") + field_arg = field_args[0] + # if field type is TypeVar set the current type by the value's type. + if isinstance(field_arg, TypeVar): + field_arg = type(x) + return [_deserialize(xi, field_arg) for xi in x] + return x + + +def _deserialize_union(x: Any, field_type: Type) -> Any: + """Deserialize values for Union typed fields + + Args: + x (Any): value to be deserialized. + field_type (Type): field type. + + Returns: + [Any]: desrialized value. + """ + for arg in field_type.__args__: + # stop after first matching type in Union + try: + x = _deserialize(x, arg) + break + except ValueError: + pass + return x + + +def _deserialize_primitive_types(x: Union[int, float, str, bool], field_type: Type) -> Union[int, float, str, bool]: + """Deserialize python primitive types (float, int, str, bool). + It handles `inf` values exclusively and keeps them float against int fields since int does not support inf values. + + Args: + x (Union[int, float, str, bool]): value to be deserialized. + field_type (Type): field type. + + Returns: + Union[int, float, str, bool]: deserialized value. + """ + + if isinstance(x, (str, bool)): + return x + if isinstance(x, (int, float)): + if x == float("inf") or x == float("-inf"): + # if value type is inf return regardless. + return x + x = field_type(x) + return x + # TODO: Raise an error when x does not match the types. 
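+ # any value that is not str/bool/int/float falls through to here and deserializes to None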
+ return None + + +def _deserialize(x: Any, field_type: Any) -> Any: + """Pick the right desrialization for the given object and the corresponding field type. + + Args: + x (object): object to be deserialized. + field_type (type): expected type after deserialization. + + Returns: + object: deserialized object + + """ + # pylint: disable=too-many-return-statements + if is_dict(field_type): + return _deserialize_dict(x) + if is_list(field_type): + return _deserialize_list(x, field_type) + if is_union(field_type): + return _deserialize_union(x, field_type) + if issubclass(field_type, Serializable): + return field_type.deserialize_immutable(x) + if is_primitive_type(field_type): + return _deserialize_primitive_types(x, field_type) + raise ValueError(f" [!] '{type(x)}' value type of '{x}' does not match '{field_type}' field type.") + + +# Recursive setattr (supports dotted attr names) +def rsetattr(obj, attr, val): + def _setitem(obj, attr, val): + return operator.setitem(obj, int(attr), val) + + pre, _, post = attr.rpartition(".") + setfunc = _setitem if post.isnumeric() else setattr + + return setfunc(rgetattr(obj, pre) if pre else obj, post, val) + + +# Recursive getattr (supports dotted attr names) +def rgetattr(obj, attr, *args): + def _getitem(obj, attr): + return operator.getitem(obj, int(attr), *args) + + def _getattr(obj, attr): + getfunc = _getitem if attr.isnumeric() else getattr + return getfunc(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split(".")) + + +# Recursive setitem (supports dotted attr names) +def rsetitem(obj, attr, val): + pre, _, post = attr.rpartition(".") + return operator.setitem(rgetitem(obj, pre) if pre else obj, post, val) + + +# Recursive getitem (supports dotted attr names) +def rgetitem(obj, attr, *args): + def _getitem(obj, attr): + return operator.getitem(obj, int(attr) if attr.isnumeric() else attr, *args) + + return functools.reduce(_getitem, [obj] + attr.split(".")) + + +@dataclass +class Serializable: + """Gives serialization ability to any inheriting dataclass.""" + + def __post_init__(self): + self._validate_contracts() + for key, value in self.__dict__.items(): + if value is no_default: + raise TypeError(f"__init__ missing 1 required argument: '{key}'") + + def _validate_contracts(self): + dataclass_fields = fields(self) + + for field in dataclass_fields: + + value = getattr(self, field.name) + + if value is None: + if not _is_optional_field(field): + raise TypeError(f"{field.name} is not optional") + + contract = field.metadata.get("contract", None) + + if contract is not None: + if value is not None and not contract(value): + raise ValueError(f"break the contract for {field.name}, {self.__class__.__name__}") + + def validate(self): + """validate if object can serialize / deserialize correctly.""" + self._validate_contracts() + if self != self.__class__.deserialize( # pylint: disable=no-value-for-parameter + json.loads(json.dumps(self.serialize())) + ): + raise ValueError("could not be deserialized with same value") + + def to_dict(self) -> dict: + """Transform serializable object to dict.""" + cls_fields = fields(self) + o = {} + for cls_field in cls_fields: + o[cls_field.name] = getattr(self, cls_field.name) + return o + + def serialize(self) -> dict: + """Serialize object to be json serializable representation.""" + if not is_dataclass(self): + raise TypeError("need to be decorated as dataclass") + + dataclass_fields = fields(self) + + o = {} + + for field in dataclass_fields: + value = getattr(self, field.name) + value = 
_serialize(value) + o[field.name] = value + return o + + def deserialize(self, data: dict) -> "Serializable": + """Parse input dictionary and desrialize its fields to a dataclass. + + Returns: + self: deserialized `self`. + """ + if not isinstance(data, dict): + raise ValueError() + data = data.copy() + init_kwargs = {} + for field in fields(self): + # if field.name == 'dataset_config': + if field.name not in data: + if field.name in vars(self): + init_kwargs[field.name] = vars(self)[field.name] + continue + raise ValueError(f' [!] Missing required field "{field.name}"') + value = data.get(field.name, _default_value(field)) + if value is None: + init_kwargs[field.name] = value + continue + if value == MISSING: + raise ValueError(f"deserialized with unknown value for {field.name} in {self.__name__}") + value = _deserialize(value, field.type) + init_kwargs[field.name] = value + for k, v in init_kwargs.items(): + setattr(self, k, v) + return self + + @classmethod + def deserialize_immutable(cls, data: dict) -> "Serializable": + """Parse input dictionary and desrialize its fields to a dataclass. + + Returns: + Newly created deserialized object. + """ + if not isinstance(data, dict): + raise ValueError() + data = data.copy() + init_kwargs = {} + for field in fields(cls): + # if field.name == 'dataset_config': + if field.name not in data: + if field.name in vars(cls): + init_kwargs[field.name] = vars(cls)[field.name] + continue + # if not in cls and the default value is not Missing use it + default_value = _default_value(field) + if default_value not in (MISSING, _MISSING): + init_kwargs[field.name] = default_value + continue + raise ValueError(f' [!] Missing required field "{field.name}"') + value = data.get(field.name, _default_value(field)) + if value is None: + init_kwargs[field.name] = value + continue + if value == MISSING: + raise ValueError(f"Deserialized with unknown value for {field.name} in {cls.__name__}") + value = _deserialize(value, field.type) + init_kwargs[field.name] = value + return cls(**init_kwargs) + + +# ---------------------------------------------------------------------------- # +# Argument Parsing from `argparse` # +# ---------------------------------------------------------------------------- # + + +def _get_help(field): + try: + field_help = field.metadata["help"] + except KeyError: + field_help = "" + return field_help + + +def _init_argparse( + parser, + field_name, + field_type, + field_default, + field_default_factory, + field_help, + arg_prefix="", + help_prefix="", + relaxed_parser=False, +): + has_default = False + default = None + if field_default: + has_default = True + default = field_default + elif field_default_factory not in (None, _MISSING): + has_default = True + default = field_default_factory() + + if not has_default and not is_primitive_type(field_type) and not is_list(field_type): + # aggregate types (fields with a Coqpit subclass as type) are not supported without None + return parser + arg_prefix = field_name if arg_prefix == "" else f"{arg_prefix}.{field_name}" + help_prefix = field_help if help_prefix == "" else f"{help_prefix} - {field_help}" + if is_dict(field_type): # pylint: disable=no-else-raise + # NOTE: accept any string in json format as input to dict field. + parser.add_argument( + f"--{arg_prefix}", + dest=arg_prefix, + default=json.dumps(field_default) if field_default else None, + type=json.loads, + ) + elif is_list(field_type): + # TODO: We need a more clear help msg for lists. 
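+ # a hinted list with no default becomes a single --arg taking multiple values; a list with defaults is exposed element-wise as --arg.0, --arg.1, ...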
+ if hasattr(field_type, "__args__"): # if the list is hinted + if len(field_type.__args__) > 1 and not relaxed_parser: + raise ValueError(" [!] Coqpit does not support multi-type hinted 'List'") + list_field_type = field_type.__args__[0] + else: + raise ValueError(" [!] Coqpit does not support un-hinted 'List'") + + # TODO: handle list of lists + if is_list(list_field_type) and relaxed_parser: + return parser + + if not has_default or field_default_factory is list: + if not is_primitive_type(list_field_type) and not relaxed_parser: + raise NotImplementedError(" [!] Empty list with non primitive inner type is currently not supported.") + + # If the list's default value is None, the user can specify the entire list by passing multiple parameters + parser.add_argument( + f"--{arg_prefix}", + nargs="*", + type=list_field_type, + help=f"Coqpit Field: {help_prefix}", + ) + else: + # If a default value is defined, just enable editing the values from argparse + # TODO: allow inserting a new value/obj to the end of the list. + for idx, fv in enumerate(default): + parser = _init_argparse( + parser, + str(idx), + list_field_type, + fv, + field_default_factory, + field_help="", + help_prefix=f"{help_prefix} - ", + arg_prefix=f"{arg_prefix}", + relaxed_parser=relaxed_parser, + ) + elif is_union(field_type): + # TODO: currently I don't know how to handle Union type on argparse + if not relaxed_parser: + raise NotImplementedError( + " [!] Parsing `Union` field from argparse is not yet implemented. Please create an issue." + ) + elif issubclass(field_type, Serializable): + return default.init_argparse( + parser, arg_prefix=arg_prefix, help_prefix=help_prefix, relaxed_parser=relaxed_parser + ) + elif isinstance(field_type(), bool): + + def parse_bool(x): + if x not in ("true", "false"): + raise ValueError(f' [!] Value for boolean field must be either "true" or "false". Got "{x}".') + return x == "true" + + parser.add_argument( + f"--{arg_prefix}", + type=parse_bool, + default=field_default, + help=f"Coqpit Field: {help_prefix}", + metavar="true/false", + ) + elif is_primitive_type(field_type): + parser.add_argument( + f"--{arg_prefix}", + default=field_default, + type=field_type, + help=f"Coqpit Field: {help_prefix}", + ) + else: + if not relaxed_parser: + raise NotImplementedError(f" [!] '{field_type}' is not supported by arg_parser. Please file a bug report.") + return parser + + +# ---------------------------------------------------------------------------- # +# Main Coqpit Class # +# ---------------------------------------------------------------------------- # + + +@dataclass +class Coqpit(Serializable, MutableMapping): + """Coqpit base class to be inherited by any Coqpit dataclasses. + It overrides Python `dict` interface and provides `dict` compatible API. + It also enables serializing/deserializing a dataclass to/from a json file, plus some semi-dynamic type and value check. + Note that it does not support all datatypes and likely to fail in some cases. + """ + + _initialized = False + + def _is_initialized(self): + """Check if Coqpit is initialized. 
Useful to prevent running some aux functions + at the initialization when no attribute has been defined.""" + return "_initialized" in vars(self) and self._initialized + + def __post_init__(self): + self._initialized = True + try: + self.check_values() + except AttributeError: + pass + + ## `dict` API functions + + def __iter__(self): + return iter(asdict(self)) + + def __len__(self): + return len(fields(self)) + + def __setitem__(self, arg: str, value: Any): + setattr(self, arg, value) + + def __getitem__(self, arg: str): + """Access class attributes with ``[arg]``.""" + return self.__dict__[arg] + + def __delitem__(self, arg: str): + delattr(self, arg) + + def _keytransform(self, key): # pylint: disable=no-self-use + return key + + ## end `dict` API functions + + def __getattribute__(self, arg: str): # pylint: disable=no-self-use + """Check if the mandatory field is defined when accessing it.""" + value = super().__getattribute__(arg) + if isinstance(value, str) and value == "???": + raise AttributeError(f" [!] MISSING field {arg} must be defined.") + return value + + def __contains__(self, arg: str): + return arg in self.to_dict() + + def get(self, key: str, default: Any = None): + if self.has(key): + return asdict(self)[key] + return default + + def items(self): + return asdict(self).items() + + def merge(self, coqpits: Union["Coqpit", List["Coqpit"]]): + """Merge a coqpit instance or a list of coqpit instances to self. + Note that it does not pass the fields and overrides attributes with + the last Coqpit instance in the given List. + TODO: find a way to merge instances with all the class internals. + + Args: + coqpits (Union[Coqpit, List[Coqpit]]): coqpit instance or list of instances to be merged. + """ + + def _merge(coqpit): + self.__dict__.update(coqpit.__dict__) + self.__annotations__.update(coqpit.__annotations__) + self.__dataclass_fields__.update(coqpit.__dataclass_fields__) + + if isinstance(coqpits, list): + for coqpit in coqpits: + _merge(coqpit) + else: + _merge(coqpits) + + def check_values(self): + pass + + def has(self, arg: str) -> bool: + return arg in vars(self) + + def copy(self): + return replace(self) + + def update(self, new: dict, allow_new=False) -> None: + """Update Coqpit fields by the input ```dict```. + + Args: + new (dict): dictionary with new values. + allow_new (bool, optional): allow new fields to add. Defaults to False. + """ + for key, value in new.items(): + if allow_new: + setattr(self, key, value) + else: + if hasattr(self, key): + setattr(self, key, value) + else: + raise KeyError(f" [!] No key - {key}") + + def pprint(self) -> None: + """Print Coqpit fields in a format.""" + pprint(asdict(self)) + + def to_dict(self) -> dict: + # return asdict(self) + return self.serialize() + + def from_dict(self, data: dict) -> None: + self = self.deserialize(data) # pylint: disable=self-cls-assignment + + @classmethod + def new_from_dict(cls: Serializable, data: dict) -> "Coqpit": + return cls.deserialize_immutable(data) + + def to_json(self) -> str: + """Returns a JSON string representation.""" + return json.dumps(asdict(self), indent=4, default=_coqpit_json_default) + + def save_json(self, file_name: str) -> None: + """Save Coqpit to a json file. + + Args: + file_name (str): path to the output json file. + """ + with open(file_name, "w", encoding="utf8") as f: + json.dump(asdict(self), f, indent=4) + + def load_json(self, file_name: str) -> None: + """Load a json file and update matching config fields with type checking. 
+ Non-matching parameters in the json file are ignored. + + Args: + file_name (str): path to the json file. + + Returns: + Coqpit: new Coqpit with updated config fields. + """ + with open(file_name, "r", encoding="utf8") as f: + input_str = f.read() + dump_dict = json.loads(input_str) + # TODO: this looks stupid 💆 + self = self.deserialize(dump_dict) # pylint: disable=self-cls-assignment + self.check_values() + + @classmethod + def init_from_argparse( + cls, args: Optional[Union[argparse.Namespace, List[str]]] = None, arg_prefix: str = "coqpit" + ) -> "Coqpit": + """Create a new Coqpit instance from argparse input. + + Args: + args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```. + arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed. + """ + if not args: + # If args was not specified, parse from sys.argv + parser = cls.init_argparse(cls, arg_prefix=arg_prefix) + args = parser.parse_args() # pylint: disable=E1120, E1111 + if isinstance(args, list): + # If a list was passed in (eg. the second result of `parse_known_args`, run that through argparse first to get a parsed Namespace + parser = cls.init_argparse(cls, arg_prefix=arg_prefix) + args = parser.parse_args(args) # pylint: disable=E1120, E1111 + + # Handle list and object attributes with defaults, which can be modified + # directly (eg. --coqpit.list.0.val_a 1), by constructing real objects + # from defaults and passing those to `cls.__init__` + args_with_lists_processed = {} + class_fields = fields(cls) + for field in class_fields: + has_default = False + default = None + field_default = field.default if field.default is not _MISSING else None + field_default_factory = field.default_factory if field.default_factory is not _MISSING else None + if field_default: + has_default = True + default = field_default + elif field_default_factory: + has_default = True + default = field_default_factory() + + if has_default and (not is_primitive_type(field.type) or is_list(field.type)): + args_with_lists_processed[field.name] = default + + args_dict = vars(args) + for k, v in args_dict.items(): + # Remove argparse prefix (eg. "--coqpit." if present) + if k.startswith(f"{arg_prefix}."): + k = k[len(f"{arg_prefix}.") :] + + rsetitem(args_with_lists_processed, k, v) + + return cls(**args_with_lists_processed) + + def parse_args( + self, args: Optional[Union[argparse.Namespace, List[str]]] = None, arg_prefix: str = "coqpit" + ) -> None: + """Update config values from argparse arguments with some meta-programming ✨. + + Args: + args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```. + arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed. + """ + if not args: + # If args was not specified, parse from sys.argv + parser = self.init_argparse(arg_prefix=arg_prefix) + args = parser.parse_args() + if isinstance(args, list): + # If a list was passed in (eg. 
the second result of `parse_known_args`, run that through argparse first to get a parsed Namespace + parser = self.init_argparse(arg_prefix=arg_prefix) + args = parser.parse_args(args) + + args_dict = vars(args) + + for k, v in args_dict.items(): + if k.startswith(f"{arg_prefix}."): + k = k[len(f"{arg_prefix}.") :] + try: + rgetattr(self, k) + except (TypeError, AttributeError) as e: + raise Exception(f" [!] '{k}' not exist to override from argparse.") from e + + rsetattr(self, k, v) + + self.check_values() + + def parse_known_args( + self, + args: Optional[Union[argparse.Namespace, List[str]]] = None, + arg_prefix: str = "coqpit", + relaxed_parser=False, + ) -> List[str]: + """Update config values from argparse arguments. Ignore unknown arguments. + This is analog to argparse.ArgumentParser.parse_known_args (vs parse_args). + + Args: + args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```. + arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed. + relaxed_parser (bool, optional): If True, do not force all the fields to have compatible types with the argparser. Defaults to False. + + Returns: + List of unknown parameters. + """ + if not args: + # If args was not specified, parse from sys.argv + parser = self.init_argparse(arg_prefix=arg_prefix, relaxed_parser=relaxed_parser) + args, unknown = parser.parse_known_args() + if isinstance(args, list): + # If a list was passed in (eg. the second result of `parse_known_args`, run that through argparse first to get a parsed Namespace + parser = self.init_argparse(arg_prefix=arg_prefix, relaxed_parser=relaxed_parser) + args, unknown = parser.parse_known_args(args) + + self.parse_args(args) + return unknown + + def init_argparse( + self, + parser: Optional[argparse.ArgumentParser] = None, + arg_prefix="coqpit", + help_prefix="", + relaxed_parser=False, + ) -> argparse.ArgumentParser: + """Pass Coqpit fields as argparse arguments. This allows to edit values through command-line. + + Args: + parser (argparse.ArgumentParser, optional): argparse.ArgumentParser instance. If unspecified a new one will be created. + arg_prefix (str, optional): Prefix to be used for the argument name. Defaults to 'coqpit'. + help_prefix (str, optional): Prefix to be used for the argument description. Defaults to ''. + relaxed_parser (bool, optional): If True, do not force all the fields to have compatible types with the argparser. Defaults to False. + + Returns: + argparse.ArgumentParser: parser instance with the new arguments. 
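+
+        Example (illustrative; ``config`` is a Coqpit instance assumed to define a ``batch_size`` field):
+            >>> parser = config.init_argparse(arg_prefix="coqpit")
+            >>> args = parser.parse_args(["--coqpit.batch_size", "32"])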
+ """ + if not parser: + parser = argparse.ArgumentParser() + class_fields = fields(self) + for field in class_fields: + if field.name in vars(self): + # use the current value of the field + # prevent dropping the current value + field_default = vars(self)[field.name] + else: + # use the default value of the field + field_default = field.default if field.default is not _MISSING else None + field_type = field.type + field_default_factory = field.default_factory + field_help = _get_help(field) + _init_argparse( + parser, + field.name, + field_type, + field_default, + field_default_factory, + field_help, + arg_prefix, + help_prefix, + relaxed_parser, + ) + return parser + + +def check_argument( + name, + c, + is_path: bool = False, + prerequest: str = None, + enum_list: list = None, + max_val: float = None, + min_val: float = None, + restricted: bool = False, + alternative: str = None, + allow_none: bool = True, +) -> None: + """Simple type and value checking for Coqpit. + It is intended to be used under ```__post_init__()``` of config dataclasses. + + Args: + name (str): name of the field to be checked. + c (dict): config dictionary. + is_path (bool, optional): if ```True``` check if the path is exist. Defaults to False. + prerequest (list or str, optional): a list of field name that are prerequestedby the target field name. + Defaults to ```[]```. + enum_list (list, optional): list of possible values for the target field. Defaults to None. + max_val (float, optional): maximum possible value for the target field. Defaults to None. + min_val (float, optional): minimum possible value for the target field. Defaults to None. + restricted (bool, optional): if ```True``` the target field has to be defined. Defaults to False. + alternative (str, optional): a field name superceding the target field. Defaults to None. + allow_none (bool, optional): if ```True``` allow the target field to be ```None```. Defaults to False. + + + Example: + >>> num_mels = 5 + >>> check_argument('num_mels', c, restricted=True, min_val=10, max_val=2056) + >>> fft_size = 128 + >>> check_argument('fft_size', c, restricted=True, min_val=128, max_val=4058) + """ + # check if None allowed + if allow_none and c[name] is None: + return + if not allow_none: + assert c[name] is not None, f" [!] None value is not allowed for {name}." + # check if restricted and it it is check if it exists + if isinstance(restricted, bool) and restricted: + assert name in c.keys(), f" [!] {name} not defined in config.json" + # check prerequest fields are defined + if isinstance(prerequest, list): + assert any( + f not in c.keys() for f in prerequest + ), f" [!] prequested fields {prerequest} for {name} are not defined." + else: + assert ( + prerequest is None or prerequest in c.keys() + ), f" [!] prequested fields {prerequest} for {name} are not defined." + # check if the path exists + if is_path: + assert os.path.exists(c[name]), f' [!] path for {name} ("{c[name]}") does not exist.' + # skip the rest if the alternative field is defined. + if alternative in c.keys() and c[alternative] is not None: + return + # check value constraints + if name in c.keys(): + if max_val is not None: + assert c[name] <= max_val, f" [!] {name} is larger than max value {max_val}" + if min_val is not None: + assert c[name] >= min_val, f" [!] {name} is smaller than min value {min_val}" + if enum_list is not None: + assert c[name].lower() in enum_list, f" [!] 
{name} is not a valid value" diff --git a/speaker/utils/io.py b/speaker/utils/io.py new file mode 100644 index 0000000..1d4c079 --- /dev/null +++ b/speaker/utils/io.py @@ -0,0 +1,198 @@ +import datetime +import json +import os +import pickle as pickle_tts +import shutil +from typing import Any, Callable, Dict, Union + +import fsspec +import torch +from .coqpit import Coqpit + + +class RenamingUnpickler(pickle_tts.Unpickler): + """Overload default pickler to solve module renaming problem""" + + def find_class(self, module, name): + return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name) + + +class AttrDict(dict): + """A custom dict which converts dict keys + to class attributes""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + +def copy_model_files(config: Coqpit, out_path, new_fields): + """Copy config.json and other model files to training folder and add + new fields. + + Args: + config (Coqpit): Coqpit config defining the training run. + out_path (str): output path to copy the file. + new_fields (dict): new fileds to be added or edited + in the config file. + """ + copy_config_path = os.path.join(out_path, "config.json") + # add extra information fields + config.update(new_fields, allow_new=True) + # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths. + with fsspec.open(copy_config_path, "w", encoding="utf8") as f: + json.dump(config.to_dict(), f, indent=4) + + # copy model stats file if available + if config.audio.stats_path is not None: + copy_stats_path = os.path.join(out_path, "scale_stats.npy") + filesystem = fsspec.get_mapper(copy_stats_path).fs + if not filesystem.exists(copy_stats_path): + with fsspec.open(config.audio.stats_path, "rb") as source_file: + with fsspec.open(copy_stats_path, "wb") as target_file: + shutil.copyfileobj(source_file, target_file) + + +def load_fsspec( + path: str, + map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None, + **kwargs, +) -> Any: + """Like torch.load but can load from other locations (e.g. s3:// , gs://). + + Args: + path: Any path or url supported by fsspec. + map_location: torch.device or str. + **kwargs: Keyword arguments forwarded to torch.load. + + Returns: + Object stored in path. + """ + with fsspec.open(path, "rb") as f: + return torch.load(f, map_location=map_location, **kwargs) + + +def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin + try: + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + except ModuleNotFoundError: + pickle_tts.Unpickler = RenamingUnpickler + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts) + model.load_state_dict(state["model"]) + if use_cuda: + model.cuda() + if eval: + model.eval() + return model, state + + +def save_fsspec(state: Any, path: str, **kwargs): + """Like torch.save but can save to other locations (e.g. s3:// , gs://). + + Args: + state: State object to save + path: Any path or url supported by fsspec. + **kwargs: Keyword arguments forwarded to torch.save. 
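+
+    Example (bucket name is illustrative):
+        >>> save_fsspec({"step": 1000}, "s3://my-bucket/checkpoint_1000.pth.tar")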
+ """ + with fsspec.open(path, "wb") as f: + torch.save(state, f, **kwargs) + + +def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs): + if hasattr(model, "module"): + model_state = model.module.state_dict() + else: + model_state = model.state_dict() + if isinstance(optimizer, list): + optimizer_state = [optim.state_dict() for optim in optimizer] + else: + optimizer_state = optimizer.state_dict() if optimizer is not None else None + + if isinstance(scaler, list): + scaler_state = [s.state_dict() for s in scaler] + else: + scaler_state = scaler.state_dict() if scaler is not None else None + + if isinstance(config, Coqpit): + config = config.to_dict() + + state = { + "config": config, + "model": model_state, + "optimizer": optimizer_state, + "scaler": scaler_state, + "step": current_step, + "epoch": epoch, + "date": datetime.date.today().strftime("%B %d, %Y"), + } + state.update(kwargs) + save_fsspec(state, output_path) + + +def save_checkpoint( + config, + model, + optimizer, + scaler, + current_step, + epoch, + output_folder, + **kwargs, +): + file_name = "checkpoint_{}.pth.tar".format(current_step) + checkpoint_path = os.path.join(output_folder, file_name) + print("\n > CHECKPOINT : {}".format(checkpoint_path)) + save_model( + config, + model, + optimizer, + scaler, + current_step, + epoch, + checkpoint_path, + **kwargs, + ) + + +def save_best_model( + current_loss, + best_loss, + config, + model, + optimizer, + scaler, + current_step, + epoch, + out_path, + keep_all_best=False, + keep_after=10000, + **kwargs, +): + if current_loss < best_loss: + best_model_name = f"best_model_{current_step}.pth.tar" + checkpoint_path = os.path.join(out_path, best_model_name) + print(" > BEST MODEL : {}".format(checkpoint_path)) + save_model( + config, + model, + optimizer, + scaler, + current_step, + epoch, + checkpoint_path, + model_loss=current_loss, + **kwargs, + ) + fs = fsspec.get_mapper(out_path).fs + # only delete previous if current is saved successfully + if not keep_all_best or (current_step < keep_after): + model_names = fs.glob(os.path.join(out_path, "best_model*.pth.tar")) + for model_name in model_names: + if os.path.basename(model_name) != best_model_name: + fs.rm(model_name) + # create a shortcut which always points to the currently best model + shortcut_name = "best_model.pth.tar" + shortcut_path = os.path.join(out_path, shortcut_name) + fs.copy(checkpoint_path, shortcut_path) + best_loss = current_loss + return best_loss diff --git a/speaker/utils/shared_configs.py b/speaker/utils/shared_configs.py new file mode 100644 index 0000000..a89d3a9 --- /dev/null +++ b/speaker/utils/shared_configs.py @@ -0,0 +1,342 @@ +from dataclasses import asdict, dataclass +from typing import List + +from .coqpit import Coqpit, check_argument + + +@dataclass +class BaseAudioConfig(Coqpit): + """Base config to definge audio processing parameters. It is used to initialize + ```TTS.utils.audio.AudioProcessor.``` + + Args: + fft_size (int): + Number of STFT frequency levels aka.size of the linear spectogram frame. Defaults to 1024. + + win_length (int): + Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match + ```fft_size```. Defaults to 1024. + + hop_length (int): + Number of audio samples between adjacent STFT columns. Defaults to 1024. + + frame_shift_ms (int): + Set ```hop_length``` based on milliseconds and sampling rate. + + frame_length_ms (int): + Set ```win_length``` based on milliseconds and sampling rate. 
+ + stft_pad_mode (str): + Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'. + + sample_rate (int): + Audio sampling rate. Defaults to 22050. + + resample (bool): + Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```. + + preemphasis (float): + Preemphasis coefficient. Defaults to 0.0. + + ref_level_db (int): 20 + Reference Db level to rebase the audio signal and ignore the level below. 20Db is assumed the sound of air. + Defaults to 20. + + do_sound_norm (bool): + Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False. + + log_func (str): + Numpy log function used for amplitude to DB conversion. Defaults to 'np.log10'. + + do_trim_silence (bool): + Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True. + + do_amp_to_db_mel (bool, optional): + enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + + trim_db (int): + Silence threshold used for silence trimming. Defaults to 45. + + power (float): + Exponent used for expanding spectrogra levels before running Griffin Lim. It helps to reduce the + artifacts in the synthesized voice. Defaults to 1.5. + + griffin_lim_iters (int): + Number of Griffing Lim iterations. Defaults to 60. + + num_mels (int): + Number of mel-basis frames that defines the frame lengths of each mel-spectrogram frame. Defaults to 80. + + mel_fmin (float): Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices. + It needs to be adjusted for a dataset. Defaults to 0. + + mel_fmax (float): + Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset. + + spec_gain (int): + Gain applied when converting amplitude to DB. Defaults to 20. + + signal_norm (bool): + enable/disable signal normalization. Defaults to True. + + min_level_db (int): + minimum db threshold for the computed melspectrograms. Defaults to -100. + + symmetric_norm (bool): + enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else + [0, k], Defaults to True. + + max_norm (float): + ```k``` defining the normalization range. Defaults to 4.0. + + clip_norm (bool): + enable/disable clipping the our of range values in the normalized audio signal. Defaults to True. + + stats_path (str): + Path to the computed stats file. Defaults to None. 
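+
+    Example (values are illustrative):
+        >>> audio_config = BaseAudioConfig(sample_rate=16000, num_mels=80, hop_length=256)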
+ """ + + # stft parameters + fft_size: int = 1024 + win_length: int = 1024 + hop_length: int = 256 + frame_shift_ms: int = None + frame_length_ms: int = None + stft_pad_mode: str = "reflect" + # audio processing parameters + sample_rate: int = 22050 + resample: bool = False + preemphasis: float = 0.0 + ref_level_db: int = 20 + do_sound_norm: bool = False + log_func: str = "np.log10" + # silence trimming + do_trim_silence: bool = True + trim_db: int = 45 + # griffin-lim params + power: float = 1.5 + griffin_lim_iters: int = 60 + # mel-spec params + num_mels: int = 80 + mel_fmin: float = 0.0 + mel_fmax: float = None + spec_gain: int = 20 + do_amp_to_db_linear: bool = True + do_amp_to_db_mel: bool = True + # normalization params + signal_norm: bool = True + min_level_db: int = -100 + symmetric_norm: bool = True + max_norm: float = 4.0 + clip_norm: bool = True + stats_path: str = None + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + check_argument("num_mels", c, restricted=True, min_val=10, max_val=2056) + check_argument("fft_size", c, restricted=True, min_val=128, max_val=4058) + check_argument("sample_rate", c, restricted=True, min_val=512, max_val=100000) + check_argument( + "frame_length_ms", + c, + restricted=True, + min_val=10, + max_val=1000, + alternative="win_length", + ) + check_argument("frame_shift_ms", c, restricted=True, min_val=1, max_val=1000, alternative="hop_length") + check_argument("preemphasis", c, restricted=True, min_val=0, max_val=1) + check_argument("min_level_db", c, restricted=True, min_val=-1000, max_val=10) + check_argument("ref_level_db", c, restricted=True, min_val=0, max_val=1000) + check_argument("power", c, restricted=True, min_val=1, max_val=5) + check_argument("griffin_lim_iters", c, restricted=True, min_val=10, max_val=1000) + + # normalization parameters + check_argument("signal_norm", c, restricted=True) + check_argument("symmetric_norm", c, restricted=True) + check_argument("max_norm", c, restricted=True, min_val=0.1, max_val=1000) + check_argument("clip_norm", c, restricted=True) + check_argument("mel_fmin", c, restricted=True, min_val=0.0, max_val=1000) + check_argument("mel_fmax", c, restricted=True, min_val=500.0, allow_none=True) + check_argument("spec_gain", c, restricted=True, min_val=1, max_val=100) + check_argument("do_trim_silence", c, restricted=True) + check_argument("trim_db", c, restricted=True) + + +@dataclass +class BaseDatasetConfig(Coqpit): + """Base config for TTS datasets. + + Args: + name (str): + Dataset name that defines the preprocessor in use. Defaults to None. + + path (str): + Root path to the dataset files. Defaults to None. + + meta_file_train (str): + Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets. + Defaults to None. + + unused_speakers (List): + List of speakers IDs that are not used at the training. Default None. + + meta_file_val (str): + Name of the dataset meta file that defines the instances used at validation. + + meta_file_attn_mask (str): + Path to the file that lists the attention mask files used with models that require attention masks to + train the duration predictor. 
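+
+    Example (name and paths are illustrative):
+        >>> dataset_config = BaseDatasetConfig(name="my_dataset", path="/data/my_dataset", meta_file_train="metadata.csv")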
+ """ + + name: str = "" + path: str = "" + meta_file_train: str = "" + ununsed_speakers: List[str] = None + meta_file_val: str = "" + meta_file_attn_mask: str = "" + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + check_argument("name", c, restricted=True) + check_argument("path", c, restricted=True) + check_argument("meta_file_train", c, restricted=True) + check_argument("meta_file_val", c, restricted=False) + check_argument("meta_file_attn_mask", c, restricted=False) + + +@dataclass +class BaseTrainingConfig(Coqpit): + """Base config to define the basic training parameters that are shared + among all the models. + + Args: + model (str): + Name of the model that is used in the training. + + run_name (str): + Name of the experiment. This prefixes the output folder name. Defaults to `coqui_tts`. + + run_description (str): + Short description of the experiment. + + epochs (int): + Number training epochs. Defaults to 10000. + + batch_size (int): + Training batch size. + + eval_batch_size (int): + Validation batch size. + + mixed_precision (bool): + Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however + it may also cause numerical unstability in some cases. + + scheduler_after_epoch (bool): + If true, run the scheduler step after each epoch else run it after each model step. + + run_eval (bool): + Enable / Disable evaluation (validation) run. Defaults to True. + + test_delay_epochs (int): + Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful + results, hence waiting for a couple of epochs might save some time. + + print_eval (bool): + Enable / Disable console logging for evalutaion steps. If disabled then it only shows the final values at + the end of the evaluation. Default to ```False```. + + print_step (int): + Number of steps required to print the next training log. + + log_dashboard (str): "tensorboard" or "wandb" + Set the experiment tracking tool + + plot_step (int): + Number of steps required to log training on Tensorboard. + + model_param_stats (bool): + Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging. + Defaults to ```False```. + + project_name (str): + Name of the project. Defaults to config.model + + wandb_entity (str): + Name of W&B entity/team. Enables collaboration across a team or org. + + log_model_step (int): + Number of steps required to log a checkpoint as W&B artifact + + save_step (int):ipt + Number of steps required to save the next checkpoint. + + checkpoint (bool): + Enable / Disable checkpointing. + + keep_all_best (bool): + Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults + to ```False```. + + keep_after (int): + Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults + to 10000. + + num_loader_workers (int): + Number of workers for training time dataloader. + + num_eval_loader_workers (int): + Number of workers for evaluation time dataloader. + + output_path (str): + Path for training output folder, either a local file path or other + URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or + S3 (s3://) paths. The nonexist part of the given path is created + automatically. All training artefacts are saved there. 
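+
+    Example (values are illustrative):
+        >>> config = BaseTrainingConfig(model="speaker_encoder", batch_size=32, eval_batch_size=16, output_path="./runs")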
+ """ + + model: str = None + run_name: str = "coqui_tts" + run_description: str = "" + # training params + epochs: int = 10000 + batch_size: int = None + eval_batch_size: int = None + mixed_precision: bool = False + scheduler_after_epoch: bool = False + # eval params + run_eval: bool = True + test_delay_epochs: int = 0 + print_eval: bool = False + # logging + dashboard_logger: str = "tensorboard" + print_step: int = 25 + plot_step: int = 100 + model_param_stats: bool = False + project_name: str = None + log_model_step: int = None + wandb_entity: str = None + # checkpointing + save_step: int = 10000 + checkpoint: bool = True + keep_all_best: bool = False + keep_after: int = 10000 + # dataloading + num_loader_workers: int = 0 + num_eval_loader_workers: int = 0 + use_noise_augment: bool = False + # paths + output_path: str = None + # distributed + distributed_backend: str = "nccl" + distributed_url: str = "tcp://localhost:54321" diff --git a/speaker_pretrain/README.md b/speaker_pretrain/README.md new file mode 100644 index 0000000..1cc5960 --- /dev/null +++ b/speaker_pretrain/README.md @@ -0,0 +1,5 @@ +Path for: + + best_model.pth.tar + + config.json \ No newline at end of file diff --git a/speaker_pretrain/config.json b/speaker_pretrain/config.json new file mode 100644 index 0000000..e330aab --- /dev/null +++ b/speaker_pretrain/config.json @@ -0,0 +1,104 @@ +{ + "model_name": "lstm", + "run_name": "mueller91", + "run_description": "train speaker encoder with voxceleb1, voxceleb2 and libriSpeech ", + "audio":{ + // Audio processing parameters + "num_mels": 80, // size of the mel spec frame. + "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame. + "sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. + "win_length": 1024, // stft window length in ms. + "hop_length": 256, // stft window hop-lengh in ms. + "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. + "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. + "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. + "min_level_db": -100, // normalization range + "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. + "power": 1.5, // value to sharpen wav signals after GL algorithm. + "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. + // Normalization parameters + "signal_norm": true, // normalize the spec values in range [0, 1] + "symmetric_norm": true, // move normalization to range [-1, 1] + "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "clip_norm": true, // clip normalized values into the range. + "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! + "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! + "do_trim_silence": true, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) + "trim_db": 60 // threshold for timming silence. Set this according to your dataset. + }, + "reinit_layers": [], + "loss": "angleproto", // "ge2e" to use Generalized End-to-End loss and "angleproto" to use Angular Prototypical loss (new SOTA) + "grad_clip": 3.0, // upper limit for gradients for clipping. + "epochs": 1000, // total number of epochs to train. 
+ "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. + "lr_decay": false, // if true, Noam learning rate decaying is applied through training. + "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" + "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. + "steps_plot_stats": 10, // number of steps to plot embeddings. + "num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + "voice_len": 2.0, // size of the voice + "num_utters_per_speaker": 10, // + "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values. + "wd": 0.000001, // Weight decay weight. + "checkpoint": true, // If true, it saves checkpoints per "save_step" + "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints. + "print_step": 20, // Number of steps to log traning on console. + "output_path": "../../OutputsMozilla/checkpoints/speaker_encoder/", // DATASET-RELATED: output path for all training outputs. + "model": { + "input_dim": 80, + "proj_dim": 256, + "lstm_dim": 768, + "num_lstm_layers": 3, + "use_lstm_with_projection": true + }, + "storage": { + "sample_from_storage_p": 0.9, // the probability with which we'll sample from the DataSet in-memory storage + "storage_size": 25, // the size of the in-memory storage with respect to a single batch + "additive_noise": 1e-5 // add very small gaussian noise to the data in order to increase robustness + }, + "datasets": + [ + { + "name": "vctk_slim", + "path": "../../../audio-datasets/en/VCTK-Corpus/", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "libri_tts", + "path": "../../../audio-datasets/en/LibriTTS/train-clean-100", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "libri_tts", + "path": "../../../audio-datasets/en/LibriTTS/train-clean-360", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "libri_tts", + "path": "../../../audio-datasets/en/LibriTTS/train-other-500", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "voxceleb1", + "path": "../../../audio-datasets/en/voxceleb1/", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "voxceleb2", + "path": "../../../audio-datasets/en/voxceleb2/", + "meta_file_train": null, + "meta_file_val": null + }, + { + "name": "common_voice", + "path": "../../../audio-datasets/en/MozillaCommonVoice", + "meta_file_train": "train.tsv", + "meta_file_val": "test.tsv" + } + ] +} \ No newline at end of file diff --git a/spec/inference.py b/spec/inference.py new file mode 100644 index 0000000..6cc4042 --- /dev/null +++ b/spec/inference.py @@ -0,0 +1,113 @@ +import argparse +import torch +import torch.utils.data +import numpy as np +import librosa +from omegaconf import OmegaConf +from librosa.filters import mel as librosa_mel_fn + + +MAX_WAV_VALUE = 32768.0 + + +def load_wav_to_torch(full_path, sample_rate): + wav, _ = librosa.load(full_path, sr=sample_rate) + wav = wav / np.abs(wav).max() * 0.6 + return torch.FloatTensor(wav) + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, 
min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global mel_basis, hann_window + if fmax not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = y.squeeze(1) + + # complex tensor as default, then use view_as_real for future pytorch compatibility + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) + spec = torch.view_as_real(spec) + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + + spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +def mel_spectrogram_file(path, hps): + audio = load_wav_to_torch(path, hps.data.sampling_rate) + audio = audio.unsqueeze(0) + + # match audio length to self.hop_length * n for evaluation + if (audio.size(1) % hps.data.hop_length) != 0: + audio = audio[:, :-(audio.size(1) % hps.data.hop_length)] + mel = mel_spectrogram(audio, hps.data.filter_length, hps.data.mel_channels, hps.data.sampling_rate, + hps.data.hop_length, hps.data.win_length, hps.data.mel_fmin, hps.data.mel_fmax, center=False) + return mel + + +def print_mel(mel, path="mel.png"): + import matplotlib.pyplot as plt + fig = plt.figure(figsize=(12, 4)) + if isinstance(mel, torch.Tensor): + mel = mel.cpu().numpy() + plt.pcolor(mel) + plt.savefig(path, format="png") + plt.close(fig) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-w", "--wav", help="wav", dest="wav") + parser.add_argument("-m", "--mel", help="mel", dest="mel") # csv for excel + args = parser.parse_args() + print(args.wav) + print(args.mel) + + hps = OmegaConf.load(f"./configs/base.yaml") + + mel = mel_spectrogram_file(args.wav, hps) + # TODO + mel = torch.squeeze(mel, 0) + # [100, length] + torch.save(mel, args.mel) + print_mel(mel, "debug.mel.png")
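+    # Example invocation (file names are illustrative):
+    #   python spec/inference.py -w test.wav -m test.mel.pt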