forked from atomicoo/FCH-TTS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsynthesize.wave.py
121 lines (101 loc) · 4.55 KB
/
synthesize.wave.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   synthesize.wave.py
@Date : 2021/01/05, Tue
@Author : Atomicoo
@Version : 1.0
@Contact : [email protected]
@License : (C)Copyright 2020-2021, ShiGroup-NLP-XMU
@Desc    :   Synthesize sentences into speech.
'''
__author__ = 'Atomicoo'
import argparse
import os
import os.path as osp
import time
from scipy.io.wavfile import write
import torch
from utils.hparams import HParam
from utils.transform import StandardNorm
from helpers.synthesizer import Synthesizer
import vocoder.models
from vocoder.layers import PQMF
from utils.audio import dynamic_range_decompression
from datasets.dataset import TextProcessor
from models import ParallelText2Mel
from utils.utils import select_device, get_last_chkpt_path
# Optional GPU auto-selection helper. When helpers.manager is not importable
# (e.g. a CPU-only install), fall back to gm = None, which makes the device
# selection below default to GPU index 0.
try:
    from helpers.manager import GPUManager
except ImportError as err:
    # Best-effort: report the missing helper, but keep the script usable.
    print(err)
    gm = None
else:
    gm = GPUManager()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--batch_size", default=8, type=int, help="Batch size")
    parser.add_argument("--checkpoint", default=None, type=str, help="Checkpoint file path")
    parser.add_argument("--melgan_checkpoint", default=None, type=str, help="Checkpoint file path of melgan")
    parser.add_argument("--input_texts", default=None, type=str, help="Input text file path")
    parser.add_argument("--outputs_dir", default=None, type=str, help="Output wave file directory")
    parser.add_argument("--device", default=None, help="cuda device or cpu")
    parser.add_argument("--name", default="parallel", type=str, help="Append to logdir name")
    parser.add_argument("--config", default=None, type=str, help="Config file path")
    args = parser.parse_args()

    # Device selection: honor an explicit --device; otherwise let GPUManager
    # pick a free GPU, defaulting to GPU 0 when the helper is unavailable.
    if torch.cuda.is_available():
        index = args.device if args.device else str(0 if gm is None else gm.auto_choice())
    else:
        index = 'cpu'
    device = select_device(index)

    # Hyperparameters: explicit --config wins, else <cwd>/config/default.yaml.
    hparams = HParam(args.config) \
        if args.config else HParam(osp.join(osp.abspath(os.getcwd()), "config", "default.yaml"))

    # Fix: plain %-formatting (the original had a pointless f-string prefix
    # on a %-formatted template).
    logdir = osp.join(hparams.trainer.logdir, "%s-%s" % (hparams.data.dataset, args.name))
    checkpoint = args.checkpoint or get_last_chkpt_path(logdir)

    # Text-to-mel front end.
    normalizer = StandardNorm(hparams.audio.spec_mean, hparams.audio.spec_std)
    processor = TextProcessor(hparams.text)
    text2mel = ParallelText2Mel(hparams.parallel)
    text2mel.eval()

    synthesizer = Synthesizer(
        model=text2mel,
        checkpoint=checkpoint,
        processor=processor,
        normalizer=normalizer,
        device=device
    )

    print('Synthesizing...')
    since = time.time()
    text_file = args.input_texts or hparams.synthesizer.inputs_file_path
    with open(text_file, 'r', encoding='utf-8') as fr:
        texts = fr.read().strip().split('\n')
    melspecs = synthesizer.inference(texts)
    print(f"Inference {len(texts)} spectrograms, total elapsed {time.time()-since:.3f}s. Done.")

    vocoder_checkpoint = args.melgan_checkpoint or \
        osp.join(hparams.trainer.logdir, f"{hparams.data.dataset}-melgan", hparams.melgan.checkpoint)
    ckpt = torch.load(vocoder_checkpoint, map_location=device)

    # Re-normalize the mels to the log10 statistics the vocoder was trained on.
    # Ref: https://github.com/kan-bayashi/ParallelWaveGAN/issues/169
    decompressed = dynamic_range_decompression(melspecs)
    decompressed_log10 = torch.log10(decompressed)
    mu = torch.tensor(ckpt['stats']['mu']).to(device).unsqueeze(1)
    var = torch.tensor(ckpt['stats']['var']).to(device).unsqueeze(1)
    sigma = torch.sqrt(var)
    melspecs = (decompressed_log10 - mu) / sigma

    # Instantiate the generator class recorded in the checkpoint. Fix: name the
    # instance `generator`, not `vocoder`, so it no longer shadows the imported
    # `vocoder.models` module.
    Generator = getattr(vocoder.models, ckpt['gtype'])
    generator = Generator(**ckpt['config']).to(device)
    generator.remove_weight_norm()
    if ckpt['config']['out_channels'] > 1:
        # Multi-band model: PQMF must be attached before load_state_dict so the
        # checkpoint's keys resolve.
        generator.pqmf = PQMF().to(device)
    generator.load_state_dict(ckpt['model'])

    # Fix: inference only — disable autograd bookkeeping to cut memory use.
    with torch.no_grad():
        if ckpt['config']['out_channels'] > 1:
            waves = generator.pqmf.synthesis(generator(melspecs)).squeeze(1)
        else:
            waves = generator(melspecs).squeeze(1)
    print(f"Generate {len(texts)} audios, total elapsed {time.time()-since:.3f}s. Done.")

    print('Saving audio...')
    outputs_dir = args.outputs_dir or hparams.synthesizer.outputs_dir
    os.makedirs(outputs_dir, exist_ok=True)
    for i, wav in enumerate(waves, start=1):
        wav = wav.cpu().detach().numpy()
        filename = osp.join(outputs_dir, f"{time.strftime('%Y-%m-%d')}_{i:03d}.wav")
        write(filename, hparams.audio.sampling_rate, wav)
    print(f"Audios saved to {outputs_dir}. Done.")
    print(f'Done. ({time.time()-since:.3f}s)')