Skip to content

Commit

Permalink
Add model download to train_ms.py (#150)
Browse files Browse the repository at this point in the history
* download base model

* fix download

* fix download repo

* set HF default
  • Loading branch information
Isotr0py authored Nov 5, 2023
1 parent 0dbe362 commit 80c418e
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 2 deletions.
2 changes: 2 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,11 @@ def __init__(
self,
config_path: str,
env: Dict[str, any],
base: Dict[str, any],
model: str,
):
self.env = env # 需要加载的环境变量
self.base = base # 底模配置
self.model = model # 训练模型存储目录,该路径为相对于dataset_path的路径,而非项目根目录
self.config_path = config_path # 配置文件路径

Expand Down
11 changes: 9 additions & 2 deletions default_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
# 每个数据集与其对应的模型存放至统一路径下,后续所有的路径配置均为相对于datasetPath的路径
# 不填或者填空则路径为相对于项目根目录的路径
dataset_path: "Data/你的数据集"
mirror: "openi" # 模型镜像源
openi_token: "1145141919810" # openi token

# 模型镜像源,默认huggingface,使用openi镜像源需指定openi_token
mirror: ""
openi_token: "" # openi token

# resample 音频重采样配置
# 注意, “:” 后需要加空格
Expand Down Expand Up @@ -67,6 +69,11 @@ train_ms:
RANK: 0
# 可以填写任意名的环境变量
THE_ENV_VAR_YOU_NEED_TO_USE: "1234567"
# 底模设置
base:
use_base_model: false
repo_id: "Stardust_minus/Bert-VITS2"
model_image: "Bert-VITS2中日底模" # openi网页的模型名
# 训练模型存储目录:与旧版本的区别,原先模型是存放在logs/model_name下的,现在改为统一存放在Data/你的数据集/models下
model: "models"
# 配置文件路径
Expand Down
10 changes: 10 additions & 0 deletions train_ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,16 @@ def run():
dur_resume_lr = None
if net_dur_disc is not None:
net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True)

# 下载底模
if config.train_ms_config.base["use_base_model"]:
utils.download_checkpoint(
hps.model_dir,
config.train_ms_config.base,
token=config.openi_token,
mirror=config.mirror,
)

try:
if net_dur_disc is not None:
_, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
Expand Down
30 changes: 30 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import argparse
import logging
import json
import shutil
import subprocess
import numpy as np
from huggingface_hub import hf_hub_download
from scipy.io.wavfile import read
import torch

Expand All @@ -13,6 +15,34 @@
logger = logging.getLogger(__name__)


def download_checkpoint(
    dir_path, repo_config, token=None, regex="G_*.pth", mirror="openi"
):
    """Download the base-model checkpoints into ``dir_path`` unless one exists.

    Args:
        dir_path: Directory that should end up holding the ``*.pth`` files.
        repo_config: Mapping with a ``"repo_id"`` entry and, when the openi
            mirror is used, a ``"model_image"`` entry.
        token: Optional openi token; forwarded to ``openi.login`` only when
            it is truthy.
        regex: Glob pattern (not a regular expression, despite the name)
            used to detect an already-present checkpoint in ``dir_path``.
        mirror: ``"openi"`` (case-insensitive) selects the openi mirror.
            Any other value, including ``None`` or an empty string, falls
            back to the Hugging Face Hub.

    Returns:
        None. The function is called for its side effects on ``dir_path``.

    Raises:
        KeyError: If ``repo_config`` lacks a required entry.
    """
    repo_id = repo_config["repo_id"]

    # Anything matching the pattern means a checkpoint is already in place
    # (either a previous download or a training run to resume from).
    if glob.glob(os.path.join(dir_path, regex)):
        print("Use existed model, skip downloading.")
        return

    # A blank ``mirror:`` entry in the YAML config is parsed as None; treat
    # it like the empty string so the Hugging Face default applies instead
    # of failing on ``None.lower()``.
    mirror_name = (mirror or "").lower()

    if mirror_name == "openi":
        # Imported lazily so the openi package is only required when this
        # mirror is actually selected.
        import openi

        login_kwargs = {"token": token} if token else {}
        openi.login(**login_kwargs)

        model_image = repo_config["model_image"]
        openi.model.download_model(repo_id, model_image, dir_path)

        # The checkpoints are looked up in a sub-directory named after the
        # model (presumably where openi places them -- confirm against the
        # openi docs); move them up into dir_path and drop that directory.
        image_dir = os.path.join(dir_path, model_image)
        for ckpt_path in glob.glob(os.path.join(image_dir, "*.pth")):
            shutil.move(ckpt_path, dir_path)
        shutil.rmtree(image_dir)
        return

    for filename in ("DUR_0.pth", "D_0.pth", "G_0.pth"):
        hf_hub_download(
            repo_id, filename, local_dir=dir_path, local_dir_use_symlinks=False
        )


def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
Expand Down

0 comments on commit 80c418e

Please sign in to comment.