Fix backend (#627)
* Linux pyaudio dependencies

* revert generate.py

* Better bug report & feat request

* Auto-select torchaudio backend

* safety

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat: manual seed for restore

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Gradio > 5

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
AnyaCoder and pre-commit-ci[bot] authored Oct 19, 2024
1 parent ecaa69e commit e37a445
Showing 16 changed files with 185 additions and 109 deletions.
2 changes: 1 addition & 1 deletion docs/en/inference.md
@@ -122,7 +122,7 @@ python -m tools.webui \
```

!!! note
-    You can save the label file and reference audio file in advance to the examples folder in the main directory (which you need to create yourself), so that you can directly call them in the WebUI.
+    You can save the label file and reference audio file in advance to the `references` folder in the main directory (which you need to create yourself), so that you can directly call them in the WebUI.

!!! note
You can use Gradio environment variables, such as `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` to configure WebUI.
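For reference, the Gradio variables mentioned in the note above must be in the environment before the WebUI process starts; a minimal sketch (the values shown are illustrative assumptions):

```python
import os

# Gradio reads GRADIO_* variables at launch, so set them before
# `python -m tools.webui` (or a programmatic launch) runs.
os.environ["GRADIO_SERVER_NAME"] = "0.0.0.0"  # listen on all interfaces
os.environ["GRADIO_SERVER_PORT"] = "7860"
os.environ["GRADIO_SHARE"] = "true"           # request a public share link
```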
2 changes: 1 addition & 1 deletion docs/ja/inference.md
@@ -152,7 +152,7 @@ python -m tools.webui \
```

!!! note
-    ラベルファイルと参照音声ファイルをメインディレクトリの examples フォルダ(自分で作成する必要があります)に事前に保存しておくことで、WebUI で直接呼び出すことができます。
+    ラベルファイルと参照音声ファイルをメインディレクトリの `references` フォルダ(自分で作成する必要があります)に事前に保存しておくことで、WebUI で直接呼び出すことができます。

!!! note
Gradio 環境変数(`GRADIO_SHARE`、`GRADIO_SERVER_PORT`、`GRADIO_SERVER_NAME` など)を使用して WebUI を構成できます。
2 changes: 1 addition & 1 deletion docs/pt/inference.md
@@ -148,7 +148,7 @@ python -m tools.webui \
```

!!! note
-    Você pode salvar antecipadamente o arquivo de rótulos e o arquivo de áudio de referência na pasta examples do diretório principal (que você precisa criar), para que possa chamá-los diretamente na WebUI.
+    Você pode salvar antecipadamente o arquivo de rótulos e o arquivo de áudio de referência na pasta `references` do diretório principal (que você precisa criar), para que possa chamá-los diretamente na WebUI.

!!! note
É possível usar variáveis de ambiente do Gradio, como `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME`, para configurar a WebUI.
2 changes: 1 addition & 1 deletion docs/zh/inference.md
@@ -132,7 +132,7 @@ python -m tools.webui \
```

!!! note
-    你可以提前将label文件和参考音频文件保存到主目录下的examples文件夹(需要自行创建),这样你可以直接在WebUI中调用它们。
+    你可以提前将label文件和参考音频文件保存到主目录下的 `references` 文件夹(需要自行创建),这样你可以直接在WebUI中调用它们。

!!! note
你可以使用 Gradio 环境变量, 如 `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` 来配置 WebUI.
5 changes: 4 additions & 1 deletion fish_speech/models/text2semantic/llama.py
@@ -369,7 +369,10 @@ def from_pretrained(
model = simple_quantizer.convert_for_runtime()

weights = torch.load(
-        Path(path) / "model.pth", map_location="cpu", mmap=True
+        Path(path) / "model.pth",
+        map_location="cpu",
+        mmap=True,
+        weights_only=True,
)

if "state_dict" in weights:
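The safer loading pattern above, as a standalone sketch; `weights_only` requires PyTorch 1.13+, the `mmap` argument a newer release, and the checkpoint path is hypothetical:

```python
import torch

# weights_only=True restricts unpickling to tensors and primitive
# containers, so an untrusted checkpoint cannot execute arbitrary code;
# mmap=True maps the file lazily instead of reading it fully into RAM.
weights = torch.load(
    "checkpoints/model.pth",  # hypothetical path
    map_location="cpu",
    mmap=True,
    weights_only=True,
)
if "state_dict" in weights:
    weights = weights["state_dict"]
```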
3 changes: 2 additions & 1 deletion fish_speech/utils/__init__.py
@@ -5,7 +5,7 @@
from .logger import RankedLogger
from .logging_utils import log_hyperparameters
from .rich_utils import enforce_tags, print_config_tree
-from .utils import extras, get_metric_value, task_wrapper
+from .utils import extras, get_metric_value, set_seed, task_wrapper

__all__ = [
"enforce_tags",
@@ -20,4 +20,5 @@
"braceexpand",
"get_latest_checkpoint",
"autocast_exclude_mps",
"set_seed",
]
22 changes: 22 additions & 0 deletions fish_speech/utils/utils.py
@@ -1,7 +1,10 @@
+import random
import warnings
from importlib.util import find_spec
from typing import Callable

+import numpy as np
+import torch
from omegaconf import DictConfig

from .logger import RankedLogger
@@ -112,3 +115,22 @@ def get_metric_value(metric_dict: dict, metric_name: str) -> float:
log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

return metric_value


+def set_seed(seed: int):
+    if seed < 0:
+        seed = -seed
+    if seed > (1 << 31):
+        seed = 1 << 31
+
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+
+    if torch.backends.cudnn.is_available():
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
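A quick usage sketch of the new helper (seed values are arbitrary):

```python
from fish_speech.utils import set_seed

# Seeds Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices)
# in one call; negative inputs are folded to their absolute value and
# anything above 2**31 is clamped so every RNG accepts the value.
set_seed(42)  # deterministic sampling from here on
set_seed(-7)  # treated as 7
```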
2 changes: 1 addition & 1 deletion fish_speech/webui/launch_utils.py
@@ -114,7 +114,7 @@ def __init__(
block_title_text_weight="600",
block_border_width="3px",
block_shadow="*shadow_drop_lg",
-    button_shadow="*shadow_drop_lg",
+    # button_shadow="*shadow_drop_lg",
button_small_padding="0px",
button_large_padding="3px",
)
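Presumably `button_shadow` was dropped from Gradio 5 theme settings, hence the commented-out argument. A hypothetical defensive variant that tolerates both major versions:

```python
import gradio as gr

# Assumption: older Gradio accepts button_shadow, while newer releases
# raise TypeError for the removed parameter.
common = dict(block_shadow="*shadow_drop_lg", button_small_padding="0px")
try:
    theme = gr.themes.Base().set(button_shadow="*shadow_drop_lg", **common)
except TypeError:
    theme = gr.themes.Base().set(**common)
```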
2 changes: 1 addition & 1 deletion fish_speech/webui/manage.py
@@ -794,7 +794,7 @@ def llama_quantify(llama_weight, quantify_mode):
value="VQGAN",
)
with gr.Row():
-        with gr.Tabs():
+        with gr.Column():
with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
gr.HTML("You don't need to train this model!")

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,7 +23,7 @@ dependencies = [
"einops>=0.7.0",
"librosa>=0.10.1",
"rich>=13.5.3",
"gradio<5.0.0",
"gradio>5.0.0",
"wandb>=0.15.11",
"grpcio>=1.58.0",
"kui>=1.6.0",
19 changes: 14 additions & 5 deletions tools/api.py
@@ -35,7 +35,7 @@
# from fish_speech.models.vqgan.lit_module import VQGAN
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
-from fish_speech.utils import autocast_exclude_mps
+from fish_speech.utils import autocast_exclude_mps, set_seed
from tools.commons import ServeTTSRequest
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
from tools.llama.generate import (
@@ -46,6 +46,14 @@
)
from tools.vqgan.inference import load_model as load_decoder_model

+backends = torchaudio.list_audio_backends()
+if "sox" in backends:
+    backend = "sox"
+elif "ffmpeg" in backends:
+    backend = "ffmpeg"
+else:
+    backend = "soundfile"


def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
buffer = io.BytesIO()
@@ -88,10 +96,7 @@ def load_audio(reference_audio, sr):
audio_data = reference_audio
reference_audio = io.BytesIO(audio_data)

-    waveform, original_sr = torchaudio.load(
-        reference_audio,
-        backend="soundfile",  # not every linux release supports 'sox' or 'ffmpeg'
-    )
+    waveform, original_sr = torchaudio.load(reference_audio, backend=backend)

if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
@@ -215,6 +220,10 @@ def inference(req: ServeTTSRequest):
else:
logger.info("Use same references")

+    if req.seed is not None:
+        set_seed(req.seed)
+        logger.warning(f"set seed: {req.seed}")

# LLAMA Inference
request = dict(
device=decoder_model.device,
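The fallback chain added above can also be written compactly; a sketch under the same preference order (the audio file name is hypothetical):

```python
import torchaudio

# Prefer "sox", then "ffmpeg", then "soundfile"; the last ships with its
# own wheels and works without extra system libraries.
available = torchaudio.list_audio_backends()  # e.g. ["ffmpeg", "soundfile"]
backend = next(
    (b for b in ("sox", "ffmpeg", "soundfile") if b in available),
    "soundfile",
)

waveform, sample_rate = torchaudio.load("reference.wav", backend=backend)
```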
1 change: 1 addition & 0 deletions tools/commons.py
@@ -20,6 +20,7 @@ class ServeTTSRequest(BaseModel):
# For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
# Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
reference_id: str | None = None
+    seed: int | None = None
use_memory_cache: Literal["on-demand", "never"] = "never"
# Normalize text for en & zh, this increase stability for numbers
normalize: bool = True
7 changes: 7 additions & 0 deletions tools/post_api.py
@@ -109,6 +109,12 @@ def parse_args():
default="never",
help="Cache encoded references codes in memory",
)
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="None means randomized inference, otherwise deterministic",
+    )

return parser.parse_args()

@@ -155,6 +161,7 @@ def parse_args():
"emotion": args.emotion,
"streaming": args.streaming,
"use_memory_cache": args.use_memory_cache,
"seed": args.seed,
}

pydantic_data = ServeTTSRequest(**data)
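With the new field in place, a deterministic request can also be sent straight to the HTTP API; a hypothetical client sketch (host, port, and endpoint path are assumptions to adapt to your deployment):

```python
import requests

payload = {
    "text": "Hello, world!",
    "seed": 42,                   # same seed + same inputs -> same audio
    "use_memory_cache": "never",
}
response = requests.post("http://127.0.0.1:8080/v1/tts", json=payload)  # assumed URL
with open("generated.wav", "wb") as f:
    f.write(response.content)
```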
9 changes: 8 additions & 1 deletion tools/vqgan/extract_vq.py
@@ -24,6 +24,13 @@
# This file is used to convert the audio files to text files using the Whisper model.
# It's mainly used to generate the training data for the VQ model.

+backends = torchaudio.list_audio_backends()
+if "sox" in backends:
+    backend = "sox"
+elif "ffmpeg" in backends:
+    backend = "ffmpeg"
+else:
+    backend = "soundfile"

RANK = int(os.environ.get("SLURM_PROCID", 0))
WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
@@ -81,7 +88,7 @@ def process_batch(files: list[Path], model) -> float:
for file in files:
try:
wav, sr = torchaudio.load(
str(file), backend="sox" if sys.platform == "linux" else "soundfile"
str(file), backend=backend
) # Need to install libsox-dev
except Exception as e:
logger.error(f"Error reading {file}: {e}")
5 changes: 2 additions & 3 deletions tools/vqgan/inference.py
@@ -24,8 +24,7 @@ def load_model(config_name, checkpoint_path, device="cuda"):

model = instantiate(cfg)
state_dict = torch.load(
-        checkpoint_path,
-        map_location=device,
+        checkpoint_path, map_location=device, mmap=True, weights_only=True
)
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
@@ -37,7 +36,7 @@ def load_model(config_name, checkpoint_path, device="cuda"):
if "generator." in k
}

-    result = model.load_state_dict(state_dict, strict=False)
+    result = model.load_state_dict(state_dict, strict=False, assign=True)
model.eval()
model.to(device)

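`assign=True` pairs with the memory-mapped `torch.load` above: checkpoint tensors are attached to the module directly rather than copied into pre-allocated parameters. A self-contained sketch of the behaviour (needs a recent PyTorch):

```python
import torch
from torch import nn

model = nn.Linear(4, 4)
state = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}

# assign=True swaps in the checkpoint tensor's storage instead of
# copying, avoiding a second in-memory copy of the weights.
model.load_state_dict(state, strict=False, assign=True)
assert model.weight.data_ptr() == state["weight"].data_ptr()
```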