Skip to content

Commit

Permalink
Fix multiprocess (#201)
Browse files Browse the repository at this point in the history
* fix nultiprocess

* fix format

* fix multiprocess=0
  • Loading branch information
Isotr0py authored Nov 28, 2023
1 parent e92f67f commit edb5d05
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
14 changes: 8 additions & 6 deletions bert_gen.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import argparse
from multiprocessing import Pool, cpu_count

import torch
from multiprocessing import Pool
import torch.multiprocessing as mp
from tqdm import tqdm

import commons
import utils
from tqdm import tqdm
from text import cleaned_text_to_sequence, get_bert
import argparse
import torch.multiprocessing as mp
from config import config
from text import cleaned_text_to_sequence, get_bert


def process_line(line):
Expand Down Expand Up @@ -64,7 +66,7 @@ def process_line(line):
with open(hps.data.validation_files, encoding="utf-8") as f:
lines.extend(f.readlines())
if len(lines) != 0:
num_processes = args.num_processes
num_processes = min(args.num_processes, cpu_count())
with Pool(processes=num_processes) as pool:
for _ in tqdm(pool.imap_unordered(process_line, lines), total=len(lines)):
pass
Expand Down
7 changes: 6 additions & 1 deletion emo_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,12 @@ def process_func(

wavnames = [line.split("|")[0] for line in lines]
dataset = AudioDataset(wavnames, 16000, processor)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=16)
data_loader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=min(args.num_processes, os.cpu_count() - 1),
)

with torch.no_grad():
for i, data in tqdm(enumerate(data_loader), total=len(data_loader)):
Expand Down
7 changes: 7 additions & 0 deletions text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,11 @@ def check_bert_models():
_check_bert(v["repo_id"], v["files"], local_path)


def init_openjtalk():
import pyopenjtalk

pyopenjtalk.g2p("こんにちは,世界。")


init_openjtalk()
check_bert_models()

0 comments on commit edb5d05

Please sign in to comment.