Fix: support infer 2.2 models (fishaudio#244)
* Fix: support infer 2.2 models

* Fix imports

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
litagin02 and pre-commit-ci[bot] authored Dec 21, 2023
1 parent 7ebc1aa commit 98f5917
Showing 30 changed files with 133,753 additions and 78 deletions.
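In short: 2.3 is now the current inference path, and checkpoints that declare "version": "2.2" in their config are routed to a new oldVersion.V220 compatibility package (CLAP emotion features, different argument order). A minimal sketch of the dispatch pattern infer.py uses, with placeholder names standing in for the real tables:

# Hedged sketch of the version dispatch in infer.py; the map arguments below
# are placeholders for the actual SynthesizerTrnMap / inferMap_* tables.
latest_version = "2.3"

def pick_infer(version, infer_map_v4, infer_map_v3, current_infer):
    # Older checkpoints route to their compatibility wrapper; anything
    # declaring the latest version falls through to the current code path.
    if version != latest_version:
        if version in infer_map_v4:  # v2.2: new argument order, CLAP emotion
            return infer_map_v4[version]
        if version in infer_map_v3:  # v2.1: emotion/reference_audio/skip flags
            return infer_map_v3[version]
    return current_infer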
1 change: 0 additions & 1 deletion data_utils.py
@@ -3,7 +3,6 @@
import torch
import torch.utils.data
from tqdm import tqdm
import numpy as np
from tools.log import logger
import commons
from mel_processing import spectrogram_torch, mel_spectrogram_torch
3 changes: 1 addition & 2 deletions for_deploy/infer.py
@@ -9,8 +9,7 @@
"""
import torch
import commons
from text import cleaned_text_to_sequence, get_bert
from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
import utils
import numpy as np
32 changes: 29 additions & 3 deletions infer.py
@@ -5,7 +5,7 @@
2. Please declare the version number explicitly in the model's config.json by adding a field "version": "your version number"
Special version notes:
1.1.1-fix: models trained on version 1.1.1, but using the dev branch's Japanese fix at inference time
2.2: current version
2.3: current version
"""
import torch
import commons
@@ -14,11 +14,12 @@
# from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
from text.cleaner import clean_text
import utils
import numpy as np

from models import SynthesizerTrn
from text.symbols import symbols

from oldVersion.V220.models import SynthesizerTrn as V220SynthesizerTrn
from oldVersion.V220.text import symbols as V220symbols
from oldVersion.V210.models import SynthesizerTrn as V210SynthesizerTrn
from oldVersion.V210.text import symbols as V210symbols
from oldVersion.V200.models import SynthesizerTrn as V200SynthesizerTrn
@@ -30,13 +31,14 @@
from oldVersion.V101.models import SynthesizerTrn as V101SynthesizerTrn
from oldVersion.V101.text import symbols as V101symbols

from oldVersion import V111, V110, V101, V200, V210
from oldVersion import V111, V110, V101, V200, V210, V220

# Current version info
latest_version = "2.3"

# Version compatibility
SynthesizerTrnMap = {
"2.2": V220SynthesizerTrn,
"2.1": V210SynthesizerTrn,
"2.0.2-fix": V200SynthesizerTrn,
"2.0.1": V200SynthesizerTrn,
@@ -51,6 +53,7 @@
}

symbolsMap = {
"2.2": V220symbols,
"2.1": V210symbols,
"2.0.2-fix": V200symbols,
"2.0.1": V200symbols,
@@ -162,6 +165,9 @@ def infer(
    style_weight=0.7,
):
    # v2.2 changed the parameter order
    inferMap_V4 = {
        "2.2": V220.infer,
    }
    # v2.1 added parameters: emotion, reference_audio, skip_start, skip_end
    inferMap_V3 = {
        "2.1": V210.infer,
@@ -186,6 +192,26 @@
    version = hps.version if hasattr(hps, "version") else latest_version
    # Not the current version: pick the matching infer by version number
    if version != latest_version:
        if version in inferMap_V4.keys():
            emotion = ""  # Use empty emotion prompt
            return inferMap_V4[version](
                text,
                emotion,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                sid,
                language,
                hps,
                net_g,
                device,
                reference_audio,
                skip_start,
                skip_end,
                style_text,
                style_weight,
            )
        if version in inferMap_V3.keys():
            emotion = 0
            return inferMap_V3[version](
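Per the docstring above, a model declares the version it was trained with via a "version" field in its config.json; infer.py reads it as hps.version and falls back to latest_version when the field is absent. A hedged sketch of tagging a v2.2 checkpoint (the config path is hypothetical):

# Tag a v2.2 checkpoint's config so infer.py routes it to oldVersion.V220.
import json

cfg_path = "Data/models/config.json"  # hypothetical path
with open(cfg_path, "r", encoding="utf-8") as f:
    cfg = json.load(f)
cfg["version"] = "2.2"  # omit or set "2.3" for current-version models
with open(cfg_path, "w", encoding="utf-8") as f:
    json.dump(cfg, f, ensure_ascii=False, indent=2)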
2 changes: 0 additions & 2 deletions models.py
@@ -14,8 +14,6 @@
from commons import init_weights, get_padding
from text import symbols, num_tones, num_languages

from vector_quantize_pytorch import VectorQuantize


class DurationDiscriminator(nn.Module): # vits2
    def __init__(
227 changes: 227 additions & 0 deletions oldVersion/V220/__init__.py
@@ -0,0 +1,227 @@
"""
@Desc: Compatibility for version 2.2 (v2.2, CLAP-enhanced prompt audio generation)
"""
import numpy as np
import torch
import commons
from .text import cleaned_text_to_sequence, get_bert
from .text.cleaner import clean_text
from .clap_wrapper import get_clap_audio_feature, get_clap_text_feature


def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
    # This version's get_text implementation
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

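    # With add_blank, a blank token (id 0) is interleaved between symbols,
    # and word2ph is doubled to keep the word-to-phoneme alignment.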
    if hps.data.add_blank:
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        for i in range(len(word2ph)):
            word2ph[i] = word2ph[i] * 2
        word2ph[0] += 1
    bert_ori = get_bert(
        norm_text, word2ph, language_str, device, style_text=style_text, style_weight=style_weight
    )
    del word2ph
    assert bert_ori.shape[-1] == len(phone), phone

    if language_str == "ZH":
        bert = bert_ori
        ja_bert = torch.rand(1024, len(phone))
        en_bert = torch.rand(1024, len(phone))
    elif language_str == "JP":
        bert = torch.rand(1024, len(phone))
        ja_bert = bert_ori
        en_bert = torch.rand(1024, len(phone))
    elif language_str == "EN":
        bert = torch.rand(1024, len(phone))
        ja_bert = torch.rand(1024, len(phone))
        en_bert = bert_ori
    else:
        raise ValueError("language_str should be ZH, JP or EN")

    assert bert.shape[-1] == len(
        phone
    ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"

    phone = torch.LongTensor(phone)
    tone = torch.LongTensor(tone)
    language = torch.LongTensor(language)
    return bert, ja_bert, en_bert, phone, tone, language


def infer(
    text,
    emotion,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    language,
    hps,
    net_g,
    device,
    reference_audio=None,
    skip_start=False,
    skip_end=False,
    style_text=None,
    style_weight=0.7,
):
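    # Emotion embedding via CLAP: from the reference audio when provided,
    # otherwise from the free-form emotion text prompt.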
    if isinstance(reference_audio, np.ndarray):
        emo = get_clap_audio_feature(reference_audio, device)
    else:
        emo = get_clap_text_feature(emotion, device)
    emo = torch.squeeze(emo, dim=1)

    bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
        text, language, hps, device, style_text=style_text, style_weight=style_weight
    )
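    # Optionally trim boundary tokens (three leading, two trailing) so
    # adjacent segments can be stitched together without doubled boundaries.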
    if skip_start:
        phones = phones[3:]
        tones = tones[3:]
        lang_ids = lang_ids[3:]
        bert = bert[:, 3:]
        ja_bert = ja_bert[:, 3:]
        en_bert = en_bert[:, 3:]
    if skip_end:
        phones = phones[:-2]
        tones = tones[:-2]
        lang_ids = lang_ids[:-2]
        bert = bert[:, :-2]
        ja_bert = ja_bert[:, :-2]
        en_bert = en_bert[:, :-2]
    with torch.no_grad():
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        lang_ids = lang_ids.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        ja_bert = ja_bert.to(device).unsqueeze(0)
        en_bert = en_bert.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        emo = emo.to(device).unsqueeze(0)
        del phones
        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
        audio = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                speakers,
                tones,
                lang_ids,
                bert,
                ja_bert,
                en_bert,
                emo,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )
        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return audio


def infer_multilang(
    text,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    language,
    hps,
    net_g,
    device,
    reference_audio=None,
    emotion=None,
    skip_start=False,
    skip_end=False,
):
    bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
    # emo = get_emo_(reference_audio, emotion, sid)
    if isinstance(reference_audio, np.ndarray):
        emo = get_clap_audio_feature(reference_audio, device)
    else:
        emo = get_clap_text_feature(emotion, device)
    emo = torch.squeeze(emo, dim=1)
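    # Encode each (text, language) segment, trimming the inner boundaries so
    # consecutive segments concatenate cleanly.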
    for idx, (txt, lang) in enumerate(zip(text, language)):
        skip_start = (idx != 0) or (skip_start and idx == 0)
        skip_end = (idx != len(text) - 1) or (skip_end and idx == len(text) - 1)
        (
            temp_bert,
            temp_ja_bert,
            temp_en_bert,
            temp_phones,
            temp_tones,
            temp_lang_ids,
        ) = get_text(txt, lang, hps, device)
        if skip_start:
            temp_bert = temp_bert[:, 3:]
            temp_ja_bert = temp_ja_bert[:, 3:]
            temp_en_bert = temp_en_bert[:, 3:]
            temp_phones = temp_phones[3:]
            temp_tones = temp_tones[3:]
            temp_lang_ids = temp_lang_ids[3:]
        if skip_end:
            temp_bert = temp_bert[:, :-2]
            temp_ja_bert = temp_ja_bert[:, :-2]
            temp_en_bert = temp_en_bert[:, :-2]
            temp_phones = temp_phones[:-2]
            temp_tones = temp_tones[:-2]
            temp_lang_ids = temp_lang_ids[:-2]
        bert.append(temp_bert)
        ja_bert.append(temp_ja_bert)
        en_bert.append(temp_en_bert)
        phones.append(temp_phones)
        tones.append(temp_tones)
        lang_ids.append(temp_lang_ids)
    bert = torch.concatenate(bert, dim=1)
    ja_bert = torch.concatenate(ja_bert, dim=1)
    en_bert = torch.concatenate(en_bert, dim=1)
    phones = torch.concatenate(phones, dim=0)
    tones = torch.concatenate(tones, dim=0)
    lang_ids = torch.concatenate(lang_ids, dim=0)
    with torch.no_grad():
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        lang_ids = lang_ids.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        ja_bert = ja_bert.to(device).unsqueeze(0)
        en_bert = en_bert.to(device).unsqueeze(0)
        emo = emo.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        del phones
        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
        audio = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                speakers,
                tones,
                lang_ids,
                bert,
                ja_bert,
                en_bert,
                emo,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )
        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return audio
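
For reference, a hedged sketch of calling the wrapper above directly; hps, net_g, and device are assumed to be prepared by the caller's model-loading code, and the speaker id is hypothetical:

from oldVersion import V220

# Assumes hps, net_g, and device were set up elsewhere by the repo's model
# loader; "speaker" must be a key of hps.data.spk2id.
audio = V220.infer(
    "こんにちは",  # text
    "happy",       # emotion prompt for CLAP (used when no reference audio)
    0.2,           # sdp_ratio
    0.6,           # noise_scale
    0.8,           # noise_scale_w
    1.0,           # length_scale
    "speaker",     # sid (hypothetical)
    "JP",          # language
    hps,
    net_g,
    device,
)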
4 changes: 2 additions & 2 deletions clap_gen.py → oldVersion/V220/clap_gen.py
@@ -7,7 +7,7 @@

import utils
from config import config
from clap_wrapper import get_clap_audio_feature
from .clap_wrapper import get_clap_audio_feature
import librosa
import os

@@ -27,7 +27,7 @@ def process_line(line):
    device = torch.device("cpu")
    wav_path, _, language_str, text, phones, tone, word2ph = line.strip().split("|")

    clap_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".emo.pt")
    clap_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".emo.npy")
    if os.path.isfile(clap_path):
        return
File renamed without changes.