Skip to content

Commit

Permalink
fix emo
Browse files Browse the repository at this point in the history
  • Loading branch information
Stardust-minus authored Nov 25, 2023
1 parent afa21ff commit 2e73312
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
特殊版本说明:
1.1.1-fix: 1.1.1版本训练的模型,但是在推理时使用dev的日语修复
1.1.1-dev: dev开发
2.0:当前版本
2.1:当前版本
"""
import torch
import commons
Expand Down Expand Up @@ -256,47 +256,54 @@ def infer_multilang(
hps,
net_g,
device,
reference_audio=None,
emotion=None,
skip_start=False,
skip_end=False,
):
bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
bert, ja_bert, en_bert, emo, phones, tones, lang_ids = [], [], [], [], [], [], []
# bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
# text, language, hps, device
# )
for idx, (t, l) in enumerate(zip(text, language)):
for idx, (txt, lang, ref, emot) in enumerate(zip(text, language, reference_audio, emotion)):
skip_start = (idx != 0) or (skip_start and idx == 0)
skip_end = (idx != len(text) - 1) or (skip_end and idx == len(text) - 1)
(
temp_bert,
temp_ja_bert,
temp_en_bert,
temp_emo,
temp_phones,
temp_tones,
temp_lang_ids,
) = get_text(t, l, hps, device)
) = get_text(txt, ref, emot, lang, hps, device)
if skip_start:
temp_bert = temp_bert[:, 1:]
temp_ja_bert = temp_ja_bert[:, 1:]
temp_en_bert = temp_en_bert[:, 1:]
temp_emo = temp_emo[:, 1:]
temp_phones = temp_phones[1:]
temp_tones = temp_tones[1:]
temp_lang_ids = temp_lang_ids[1:]
if skip_end:
temp_bert = temp_bert[:, :-1]
temp_ja_bert = temp_ja_bert[:, :-1]
temp_en_bert = temp_en_bert[:, :-1]
temp_emo = temp_emo[:, :-1]
temp_phones = temp_phones[:-1]
temp_tones = temp_tones[:-1]
temp_lang_ids = temp_lang_ids[:-1]
bert.append(temp_bert)
ja_bert.append(temp_ja_bert)
en_bert.append(temp_en_bert)
emo.append(temo_emo)
phones.append(temp_phones)
tones.append(temp_tones)
lang_ids.append(temp_lang_ids)
bert = torch.concatenate(bert, dim=1)
ja_bert = torch.concatenate(ja_bert, dim=1)
en_bert = torch.concatenate(en_bert, dim=1)
emo = torch.concatenate(emo, dim=1)
phones = torch.concatenate(phones, dim=0)
tones = torch.concatenate(tones, dim=0)
lang_ids = torch.concatenate(lang_ids, dim=0)
Expand Down

0 comments on commit 2e73312

Please sign in to comment.