Merge pull request #13 from sensein/dev-mbsilva
Add scripts for speech-to-text using whisper and stt+forced alignment with whisperX
satra authored Mar 28, 2024
2 parents 8a83a06 + 4a31e77 commit 1c1ed04
Showing 3 changed files with 93 additions and 1 deletion.
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -19,13 +19,14 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
]
-requires-python = ">=3.10"
+requires-python = ">=3.10, <3.12"
dependencies = [
"speechbrain>=1.0.0",
"torchaudio>=2.0.0",
"opensmile>=2.3.0",
"matplotlib>=3.8.3",
"click",
"whisperx @ git+https://github.com/m-bain/whisperx.git@f2da2f858e99e4211fe4f64b5f2938b007827e17",
"pydra~=0.23",
"TTS",
"accelerate",
@@ -35,6 +36,7 @@ dependencies = [
[project.optional-dependencies]
dev = [
"pytest",
"pre-commit"
]

[project.scripts]
74 changes: 74 additions & 0 deletions src/b2aiprep/speech2text.py
@@ -0,0 +1,74 @@
import typing as ty

import torch
import whisperx

from .process import Audio


# Transcribes speech to text using the whisperX model
def transcribe_audio_whisperx(
    audio: Audio,
    hf_token: ty.Optional[str] = None,
    model: str = "base",
    device: ty.Optional[str] = None,
    batch_size: int = 16,
    compute_type: ty.Optional[str] = None,
    force_alignment: bool = True,
    return_char_alignments: bool = False,
    diarize: bool = False,
    min_speakers: ty.Optional[int] = None,
    max_speakers: ty.Optional[int] = None,
):
"""
Transcribes audio to text using OpenAI's whisper model.
Args:
audio (audio): Audio object.
model (str): Model to use for transcription. Defaults to "base".
See https://github.com/openai/whisper/ for a list of all available models.
device (str): Device to use for computation. Defaults to "cuda".
batch_size (int): Batch size for transcription. Defaults to 16.
compute_type (str): Type of computation to use. Defaults to "float16".
Change to "int8" if low on GPU mem (may reduce accuracy)
force_alignment (bool): Whether or not to perform forced alignment of the
speech-to-text output
diarize (bool): Whether or not to assign speaker labels to the text
hf_token (str): A Huggingface auth token, required to perform speaker diarization
Returns:
Result of the transcription.
"""

    # 1. Transcribe with original whisper (batched)
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    # whisperx.load_model expects a string compute type; fall back to its
    # "float16" default rather than passing None through.
    compute_type = compute_type or "float16"
    model = whisperx.load_model(model, device, compute_type=compute_type)

    if audio.sample_rate != 16000:
        audio = audio.to_16khz()
    audio = audio.signal.squeeze().numpy()
    result = model.transcribe(audio, batch_size=batch_size)

    if force_alignment:
        # 2. Align whisper output
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=device
        )
        result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            device,
            return_char_alignments=return_char_alignments,
        )

    if diarize:
        # 3. Assign speaker labels, passing the min/max number of speakers if known
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)

        diarize_segments = diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
        result = whisperx.assign_word_speakers(diarize_segments, result)
    return result
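
For context (not part of the diff): a minimal usage sketch of the new helper, assuming the repository's data/vc_source.wav sample that the test below also uses.

from b2aiprep.process import Audio
from b2aiprep.speech2text import transcribe_audio_whisperx

# Load the sample recording; the helper resamples to 16 kHz internally if needed.
audio = Audio.from_file("data/vc_source.wav")

# CPU-friendly settings: the "tiny" model and float32 compute, with the
# default forced alignment enabled.
result = transcribe_audio_whisperx(audio, model="tiny", device="cpu", compute_type="float32")
for segment in result["segments"]:
    print(segment["start"], segment["end"], segment["text"])
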
16 changes: 16 additions & 0 deletions src/tests/test_speech_to_text.py
@@ -3,6 +3,7 @@
import pytest

from b2aiprep.process import Audio, SpeechToText
from b2aiprep.speech2text import transcribe_audio_whisperx


def test_transcribe():
@@ -22,6 +23,21 @@ def test_transcribe():
assert text.strip() == audio_content


def test_transcribe_whisperx():
    """
    Validates transcribe_audio_whisperx's ability to convert audio to text accurately.
    Checks if the transcription matches the expected output, considering known model
    discrepancies.
    """
    audio_path = str((Path(__file__).parent.parent.parent / "data/vc_source.wav").absolute())
    # Note: the recording says "If it didn't, it didn't.", but this is what the
    # tiny model hears.
    audio_content = "If it isn't, it isn't."

    audio = Audio.from_file(audio_path)

    result = transcribe_audio_whisperx(audio, model="tiny", device="cpu", compute_type="float32")
    assert result["segments"][0]["text"].strip() == audio_content


def test_cuda_not_available():
"""
Test behavior when CUDA is not available.
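
Since the tests do not exercise diarization, here is a hedged sketch of that path as well. The HF_TOKEN environment variable name is illustrative, and whisperX's diarization pipeline requires a Hugging Face token with access to the pyannote models.

import os

from b2aiprep.process import Audio
from b2aiprep.speech2text import transcribe_audio_whisperx

# HF_TOKEN is a hypothetical variable name chosen for this example.
token = os.environ.get("HF_TOKEN")

audio = Audio.from_file("data/vc_source.wav")
result = transcribe_audio_whisperx(
    audio,
    hf_token=token,
    model="tiny",
    device="cpu",
    compute_type="float32",
    diarize=True,
    min_speakers=1,
    max_speakers=2,
)

# assign_word_speakers attaches a "speaker" key to each segment (and word).
for segment in result["segments"]:
    print(segment.get("speaker"), segment["text"])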
