forked from fishaudio/Bert-VITS2
Fix: support infer 2.2 models (fishaudio#244)
* Fix: support infer 2.2 models
* Fix imports
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7ebc1aa, commit 98f5917
Showing 30 changed files with 133,753 additions and 78 deletions.
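The diff below adds a compatibility module so that v2.2 models can still be run for inference. As a quick orientation, here is a minimal sketch of how its infer entry point might be called, assuming hps and net_g were already loaded elsewhere (for example with the repository's usual hyperparameter and checkpoint helpers); the module path oldVersion.V220 and the speaker name are illustrative assumptions, not part of this commit.

import torch

from oldVersion.V220 import infer  # assumed location of the module added in this diff

device = "cuda" if torch.cuda.is_available() else "cpu"

audio = infer(
    text="你好，世界。",
    emotion="happy",     # CLAP text prompt, used when no reference audio is given
    sdp_ratio=0.2,
    noise_scale=0.6,
    noise_scale_w=0.8,
    length_scale=1.0,
    sid="speaker_name",  # must be a key in hps.data.spk2id
    language="ZH",
    hps=hps,             # assumed pre-loaded hyperparameters
    net_g=net_g,         # assumed pre-loaded v2.2 generator checkpoint
    device=device,
)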
@@ -0,0 +1,227 @@
"""
@Desc: v2.2 compatibility; target version v2.2, CLAP-enhanced prompt audio generation
"""
import numpy as np
import torch
import commons
from .text import cleaned_text_to_sequence, get_bert
from .text.cleaner import clean_text
from .clap_wrapper import get_clap_audio_feature, get_clap_text_feature


def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
    # get_text as implemented for this version
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    if hps.data.add_blank:
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        for i in range(len(word2ph)):
            word2ph[i] = word2ph[i] * 2
        word2ph[0] += 1
    # Pass the caller's style prompt through instead of hardcoding the defaults
    bert_ori = get_bert(
        norm_text, word2ph, language_str, device, style_text, style_weight
    )
    del word2ph
    assert bert_ori.shape[-1] == len(phone), phone

    # Real BERT features for the segment's language; random placeholders for the rest
    if language_str == "ZH":
        bert = bert_ori
        ja_bert = torch.rand(1024, len(phone))
        en_bert = torch.rand(1024, len(phone))
    elif language_str == "JP":
        bert = torch.rand(1024, len(phone))
        ja_bert = bert_ori
        en_bert = torch.rand(1024, len(phone))
    elif language_str == "EN":
        bert = torch.rand(1024, len(phone))
        ja_bert = torch.rand(1024, len(phone))
        en_bert = bert_ori
    else:
        raise ValueError("language_str should be ZH, JP or EN")

    assert bert.shape[-1] == len(
        phone
    ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"

    phone = torch.LongTensor(phone)
    tone = torch.LongTensor(tone)
    language = torch.LongTensor(language)
    return bert, ja_bert, en_bert, phone, tone, language


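# For T phoneme tokens, get_text returns three BERT feature tensors of shape
# (1024, T): the real feature for the requested language and random
# placeholders for the other two. It also returns LongTensors phone, tone and
# language of shape (T,). The 1024 hidden size is assumed to match the
# placeholder tensors above.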
def infer(
    text,
    emotion,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    language,
    hps,
    net_g,
    device,
    reference_audio=None,
    skip_start=False,
    skip_end=False,
    style_text=None,
    style_weight=0.7,
):
    # Prefer a CLAP feature from the reference audio; fall back to the emotion text prompt
    if isinstance(reference_audio, np.ndarray):
        emo = get_clap_audio_feature(reference_audio, device)
    else:
        emo = get_clap_text_feature(emotion, device)
    emo = torch.squeeze(emo, dim=1)

    bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
        text, language, hps, device, style_text=style_text, style_weight=style_weight
    )
    if skip_start:
        phones = phones[3:]
        tones = tones[3:]
        lang_ids = lang_ids[3:]
        bert = bert[:, 3:]
        ja_bert = ja_bert[:, 3:]
        en_bert = en_bert[:, 3:]
    if skip_end:
        phones = phones[:-2]
        tones = tones[:-2]
        lang_ids = lang_ids[:-2]
        bert = bert[:, :-2]
        ja_bert = ja_bert[:, :-2]
        en_bert = en_bert[:, :-2]
    with torch.no_grad():
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        lang_ids = lang_ids.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        ja_bert = ja_bert.to(device).unsqueeze(0)
        en_bert = en_bert.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        emo = emo.to(device).unsqueeze(0)
        del phones
        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
        audio = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                speakers,
                tones,
                lang_ids,
                bert,
                ja_bert,
                en_bert,
                emo,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )
        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return audio


def infer_multilang(
    text,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    language,
    hps,
    net_g,
    device,
    reference_audio=None,
    emotion=None,
    skip_start=False,
    skip_end=False,
):
    bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
    # emo = get_emo_(reference_audio, emotion, sid)
    if isinstance(reference_audio, np.ndarray):
        emo = get_clap_audio_feature(reference_audio, device)
    else:
        emo = get_clap_text_feature(emotion, device)
    emo = torch.squeeze(emo, dim=1)
    for idx, (txt, lang) in enumerate(zip(text, language)):
        # Always trim padding at interior segment boundaries; honor the
        # caller's skip flags only at the outer edges
        skip_start = (idx != 0) or (skip_start and idx == 0)
        skip_end = (idx != len(text) - 1) or (skip_end and idx == len(text) - 1)
        (
            temp_bert,
            temp_ja_bert,
            temp_en_bert,
            temp_phones,
            temp_tones,
            temp_lang_ids,
        ) = get_text(txt, lang, hps, device)
        if skip_start:
            temp_bert = temp_bert[:, 3:]
            temp_ja_bert = temp_ja_bert[:, 3:]
            temp_en_bert = temp_en_bert[:, 3:]
            temp_phones = temp_phones[3:]
            temp_tones = temp_tones[3:]
            temp_lang_ids = temp_lang_ids[3:]
        if skip_end:
            temp_bert = temp_bert[:, :-2]
            temp_ja_bert = temp_ja_bert[:, :-2]
            temp_en_bert = temp_en_bert[:, :-2]
            temp_phones = temp_phones[:-2]
            temp_tones = temp_tones[:-2]
            temp_lang_ids = temp_lang_ids[:-2]
        bert.append(temp_bert)
        ja_bert.append(temp_ja_bert)
        en_bert.append(temp_en_bert)
        phones.append(temp_phones)
        tones.append(temp_tones)
        lang_ids.append(temp_lang_ids)
    bert = torch.concatenate(bert, dim=1)
    ja_bert = torch.concatenate(ja_bert, dim=1)
    en_bert = torch.concatenate(en_bert, dim=1)
    phones = torch.concatenate(phones, dim=0)
    tones = torch.concatenate(tones, dim=0)
    lang_ids = torch.concatenate(lang_ids, dim=0)
    with torch.no_grad():
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        lang_ids = lang_ids.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        ja_bert = ja_bert.to(device).unsqueeze(0)
        en_bert = en_bert.to(device).unsqueeze(0)
        emo = emo.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        del phones
        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
        audio = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                speakers,
                tones,
                lang_ids,
                bert,
                ja_bert,
                en_bert,
                emo,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )
        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return audio
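infer_multilang handles mixed-language input: it takes parallel lists of text segments and language codes, trims the padding phonemes at interior segment boundaries, concatenates the per-segment features, and runs a single net_g.infer call. A sketch of a mixed-language call, reusing the hypothetical hps and net_g from the example above:

# Hypothetical call; text and language are parallel lists, one entry per segment.
audio = infer_multilang(
    text=["你好。", "Hello."],
    sdp_ratio=0.2,
    noise_scale=0.6,
    noise_scale_w=0.8,
    length_scale=1.0,
    sid="speaker_name",
    language=["ZH", "EN"],
    hps=hps,
    net_g=net_g,
    device=device,
    emotion="calm",  # used for the CLAP text feature since reference_audio is None
)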
File renamed without changes.