diff --git a/MANIFEST.in b/MANIFEST.in index 8a103dd6..6f6187c0 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,3 @@ include faster_whisper/assets/silero_vad.onnx include requirements.txt include requirements.conversion.txt -include faster_whisper/assets/pyannote_vad_model.bin diff --git a/README.md b/README.md index f7d54ee4..e57edbf3 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,7 @@ segments, info = model.transcribe("audio.mp3", beam_size=5, language="en") * Python 3.8 or greater +Unlike openai-whisper, FFmpeg does **not** need to be installed on the system. The audio is decoded with the Python library [PyAV](https://github.com/PyAV-Org/PyAV) which bundles the FFmpeg libraries in its package. ### GPU @@ -165,35 +166,6 @@ for segment in segments: segments, _ = model.transcribe("audio.mp3") segments = list(segments) # The transcription will actually run here. ``` - -### multi-segment language detection - -To directly use the model for improved language detection, the following code snippet can be used: - -```python -from faster_whisper import WhisperModel -model = WhisperModel("medium", device="cuda", compute_type="float16") -language_info = model.detect_language_multi_segment("audio.mp3") -``` - -### Batched faster-whisper - - -The batched version of faster-whisper is inspired by [whisper-x](https://github.com/m-bain/whisperX) licensed under the BSD-2 Clause license and integrates its VAD model to this library. We modify this implementation and also replaced the feature extraction with a faster torch-based implementation. Batched version improves the speed upto 10-12x compared to openAI implementation and 3-4x compared to the sequential faster_whisper version. It works by transcribing semantically meaningful audio chunks as batches leading to faster inference. - -The following code snippet illustrates how to run inference with batched version on an example audio file. Please also refer to the test scripts of batched faster whisper. - -```python -from faster_whisper import WhisperModel, BatchedInferencePipeline - -model = WhisperModel("medium", device="cuda", compute_type="float16") -batched_model = BatchedInferencePipeline(model=model) -segments, info = batched_model.transcribe("audio.mp3", batch_size=16) - -for segment in segments: - print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text)) -``` - ### Faster Distil-Whisper The Distil-Whisper checkpoints are compatible with the Faster-Whisper package. In particular, the latest [distil-large-v3](https://huggingface.co/distil-whisper/distil-large-v3) diff --git a/benchmark/wer_benchmark.py b/benchmark/wer_benchmark.py index f7a0b792..bf0a1e0e 100644 --- a/benchmark/wer_benchmark.py +++ b/benchmark/wer_benchmark.py @@ -1,6 +1,5 @@ import argparse import json -import os from datasets import load_dataset from evaluate import load @@ -27,9 +26,7 @@ # define the evaluation metric wer_metric = load("wer") - -with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f: - normalizer = EnglishTextNormalizer(json.load(f)) +normalizer = EnglishTextNormalizer(json.load(open("normalizer.json"))) def inference(batch): diff --git a/faster_whisper/__init__.py b/faster_whisper/__init__.py index ad692778..9b56a393 100644 --- a/faster_whisper/__init__.py +++ b/faster_whisper/__init__.py @@ -1,5 +1,5 @@ from faster_whisper.audio import decode_audio -from faster_whisper.transcribe import BatchedInferencePipeline, WhisperModel +from faster_whisper.transcribe import WhisperModel from faster_whisper.utils import available_models, download_model, format_timestamp from faster_whisper.version import __version__ @@ -7,7 +7,6 @@ "available_models", "decode_audio", "WhisperModel", - "BatchedInferencePipeline", "download_model", "format_timestamp", "__version__", diff --git a/faster_whisper/assets/pyannote_vad_model.bin b/faster_whisper/assets/pyannote_vad_model.bin deleted file mode 100644 index 75c92f09..00000000 Binary files a/faster_whisper/assets/pyannote_vad_model.bin and /dev/null differ diff --git a/faster_whisper/audio.py b/faster_whisper/audio.py index 7ae68d40..a597fd83 100644 --- a/faster_whisper/audio.py +++ b/faster_whisper/audio.py @@ -1,7 +1,19 @@ +"""We use the PyAV library to decode the audio: https://github.com/PyAV-Org/PyAV + +The advantage of PyAV is that it bundles the FFmpeg libraries so there is no additional +system dependencies. FFmpeg does not need to be installed on the system. + +However, the API is quite low-level so we need to manipulate audio frames directly. +""" + +import gc +import io +import itertools + from typing import BinaryIO, Union -import torch -import torchaudio +import av +import numpy as np def decode_audio( @@ -17,42 +29,91 @@ def decode_audio( split_stereo: Return separate left and right channels. Returns: - A float32 Torch Tensor. + A float32 Numpy array. If `split_stereo` is enabled, the function returns a 2-tuple with the separated left and right channels. """ + resampler = av.audio.resampler.AudioResampler( + format="s16", + layout="mono" if not split_stereo else "stereo", + rate=sampling_rate, + ) + + raw_buffer = io.BytesIO() + dtype = None - waveform, audio_sf = torchaudio.load(input_file) # waveform: channels X T + with av.open(input_file, mode="r", metadata_errors="ignore") as container: + frames = container.decode(audio=0) + frames = _ignore_invalid_frames(frames) + frames = _group_frames(frames, 500000) + frames = _resample_frames(frames, resampler) + + for frame in frames: + array = frame.to_ndarray() + dtype = array.dtype + raw_buffer.write(array) + + # It appears that some objects related to the resampler are not freed + # unless the garbage collector is manually run. + del resampler + gc.collect() + + audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype) + + # Convert s16 back to f32. + audio = audio.astype(np.float32) / 32768.0 - if audio_sf != sampling_rate: - waveform = torchaudio.functional.resample( - waveform, orig_freq=audio_sf, new_freq=sampling_rate - ) if split_stereo: - return waveform[0], waveform[1] + left_channel = audio[0::2] + right_channel = audio[1::2] + return left_channel, right_channel + + return audio + + +def _ignore_invalid_frames(frames): + iterator = iter(frames) + + while True: + try: + yield next(iterator) + except StopIteration: + break + except av.error.InvalidDataError: + continue + + +def _group_frames(frames, num_samples=None): + fifo = av.audio.fifo.AudioFifo() + + for frame in frames: + frame.pts = None # Ignore timestamp check. + fifo.write(frame) + + if num_samples is not None and fifo.samples >= num_samples: + yield fifo.read() + + if fifo.samples > 0: + yield fifo.read() + - return waveform.mean(0) +def _resample_frames(frames, resampler): + # Add None to flush the resampler. + for frame in itertools.chain(frames, [None]): + yield from resampler.resample(frame) def pad_or_trim(array, length: int, *, axis: int = -1): """ Pad or trim the audio array to N_SAMPLES, as expected by the encoder. """ - axis = axis % array.ndim if array.shape[axis] > length: - idx = [Ellipsis] * axis + [slice(length)] + [Ellipsis] * (array.ndim - axis - 1) - return array[idx] + array = array.take(indices=range(length), axis=axis) if array.shape[axis] < length: - pad_widths = ( - [ - 0, - ] - * array.ndim - * 2 - ) - pad_widths[2 * axis] = length - array.shape[axis] - array = torch.nn.functional.pad(array, tuple(pad_widths[::-1])) + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = np.pad(array, pad_widths) return array diff --git a/faster_whisper/feature_extractor.py b/faster_whisper/feature_extractor.py index 6371d5ef..0aa15070 100644 --- a/faster_whisper/feature_extractor.py +++ b/faster_whisper/feature_extractor.py @@ -1,21 +1,16 @@ -import torch +import numpy as np # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501 class FeatureExtractor: def __init__( self, - device: str = "auto", feature_size=80, sampling_rate=16000, hop_length=160, chunk_length=30, n_fft=400, ): - if device == "auto": - self.device = "cuda" if torch.cuda.is_available() else "cpu" - else: - self.device = device self.n_fft = n_fft self.hop_length = hop_length self.chunk_length = chunk_length @@ -27,22 +22,21 @@ def __init__( sampling_rate, n_fft, n_mels=feature_size ) - @staticmethod - def get_mel_filters(sr, n_fft, n_mels=128): - """ - Implementation of librosa.filters.mel in Pytorch - """ + def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32): # Initialize the weights n_mels = int(n_mels) + weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype) # Center freqs of each FFT bin - fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr) + fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr) # 'Center freqs' of mel bands - uniformly spaced between limits min_mel = 0.0 max_mel = 45.245640471924965 - mels = torch.linspace(min_mel, max_mel, n_mels + 2) + mels = np.linspace(min_mel, max_mel, n_mels + 2) + + mels = np.asanyarray(mels) # Fill in the linear scale f_min = 0.0 @@ -52,63 +46,125 @@ def get_mel_filters(sr, n_fft, n_mels=128): # And now the nonlinear scale min_log_hz = 1000.0 # beginning of log region (Hz) min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) - logstep = torch.log(torch.tensor(6.4)) / 27.0 # step size for log region + logstep = np.log(6.4) / 27.0 # step size for log region # If we have vector data, vectorize log_t = mels >= min_log_mel - freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel)) + freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel)) mel_f = freqs - fdiff = torch.diff(mel_f) - ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1) + fdiff = np.diff(mel_f) + ramps = np.subtract.outer(mel_f, fftfreqs) - lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1) - upper = ramps[2:] / fdiff[1:].unsqueeze(1) + for i in range(n_mels): + # lower and upper slopes for all bins + lower = -ramps[i] / fdiff[i] + upper = ramps[i + 2] / fdiff[i + 1] - # Intersect them with each other and zero, vectorized across all i - weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper)) + # .. then intersect them with each other and zero + weights[i] = np.maximum(0, np.minimum(lower, upper)) # Slaney-style mel is scaled to be approx constant energy per channel enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels]) - weights *= enorm.unsqueeze(1) + weights *= enorm[:, np.newaxis] return weights - def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False): + def fram_wave(self, waveform, center=True): + """ + Transform a raw waveform into a list of smaller waveforms. + The window length defines how much of the signal is + contain in each frame (smalle waveform), while the hope length defines the step + between the beginning of each new frame. + Centering is done by reflecting the waveform which is first centered around + `frame_idx * hop_length`. """ - Compute the log-Mel spectrogram of the provided audio. + frames = [] + for i in range(0, waveform.shape[0] + 1, self.hop_length): + half_window = (self.n_fft - 1) // 2 + 1 + if center: + start = i - half_window if i > half_window else 0 + end = ( + i + half_window + if i < waveform.shape[0] - half_window + else waveform.shape[0] + ) + + frame = waveform[start:end] + + if start == 0: + padd_width = (-i + half_window, 0) + frame = np.pad(frame, pad_width=padd_width, mode="reflect") + + elif end == waveform.shape[0]: + padd_width = (0, (i - waveform.shape[0] + half_window)) + frame = np.pad(frame, pad_width=padd_width, mode="reflect") + + else: + frame = waveform[i : i + self.n_fft] + frame_width = frame.shape[0] + if frame_width < waveform.shape[0]: + frame = np.lib.pad( + frame, + pad_width=(0, self.n_fft - frame_width), + mode="constant", + constant_values=0, + ) + + frames.append(frame) + return np.stack(frames, 0) + + def stft(self, frames, window): """ + Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. + Should give the same results as `torch.stft`. + """ + frame_size = frames.shape[1] + fft_size = self.n_fft + + if fft_size is None: + fft_size = frame_size + + if fft_size < frame_size: + raise ValueError("FFT size must greater or equal the frame size") + # number of FFT bins to store + num_fft_bins = (fft_size >> 1) + 1 + + data = np.empty((len(frames), num_fft_bins), dtype=np.complex64) + fft_signal = np.zeros(fft_size) + for f, frame in enumerate(frames): + if window is not None: + np.multiply(frame, window, out=fft_signal[:frame_size]) + else: + fft_signal[:frame_size] = frame + data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins] + return data.T + + def __call__(self, waveform, padding=True, chunk_length=None): + """ + Compute the log-Mel spectrogram of the provided audio, gives similar results + whisper's original torch implementation with 1e-5 tolerance. + """ if chunk_length is not None: self.n_samples = chunk_length * self.sampling_rate self.nb_max_frames = self.n_samples // self.hop_length - if waveform.dtype is not torch.float32: - waveform = waveform.to(torch.float32) - - waveform = ( - waveform.to(self.device) - if self.device == "cuda" and not waveform.is_cuda - else waveform - ) - if padding: - waveform = torch.nn.functional.pad(waveform, (0, self.n_samples)) + waveform = np.pad(waveform, [(0, self.n_samples)]) - window = torch.hann_window(self.n_fft).to(waveform.device) + window = np.hanning(self.n_fft + 1)[:-1] - stft = torch.stft( - waveform, self.n_fft, self.hop_length, window=window, return_complex=True - ) - magnitudes = stft[..., :-1].abs() ** 2 + frames = self.fram_wave(waveform) + stft = self.stft(frames, window=window) + magnitudes = np.abs(stft[:, :-1]) ** 2 - mel_spec = self.mel_filters.to(waveform.device) @ magnitudes + filters = self.mel_filters + mel_spec = filters @ magnitudes - log_spec = torch.clamp(mel_spec, min=1e-10).log10() - log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None)) + log_spec = np.maximum(log_spec, log_spec.max() - 8.0) log_spec = (log_spec + 4.0) / 4.0 - # When the model is running on multiple GPUs, the output should be moved - # to the CPU since we don't know which GPU will handle the next job. - return log_spec.cpu() if to_cpu else log_spec + return log_spec diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index a37ae388..44959b16 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -2,38 +2,24 @@ import json import logging import os -import random import zlib -from collections import Counter, defaultdict from inspect import signature from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union import ctranslate2 import numpy as np import tokenizers -import torch - -from pyannote.audio import Model -from tqdm import tqdm from faster_whisper.audio import decode_audio, pad_or_trim from faster_whisper.feature_extractor import FeatureExtractor from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer -from faster_whisper.utils import ( - download_model, - format_timestamp, - get_assets_path, - get_end, - get_logger, -) +from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger from faster_whisper.vad import ( SpeechTimestampsMap, VadOptions, - VoiceActivitySegmentation, collect_chunks, get_speech_timestamps, - merge_chunks, ) @@ -51,14 +37,13 @@ class Segment(NamedTuple): end: float text: str tokens: List[int] + temperature: float avg_logprob: float compression_ratio: float no_speech_prob: float words: Optional[List[Word]] - temperature: Optional[float] = 1.0 -# Added additional parameters for multilingual videos and fixes below class TranscriptionOptions(NamedTuple): beam_size: int best_of: int @@ -67,7 +52,6 @@ class TranscriptionOptions(NamedTuple): repetition_penalty: float no_repeat_ngram_size: int log_prob_threshold: Optional[float] - log_prob_low_threshold: Optional[float] no_speech_threshold: Optional[float] compression_ratio_threshold: Optional[float] condition_on_previous_text: bool @@ -82,8 +66,6 @@ class TranscriptionOptions(NamedTuple): word_timestamps: bool prepend_punctuations: str append_punctuations: str - multilingual: bool - output_language: Optional[str] max_new_tokens: Optional[int] clip_timestamps: Union[str, List[float]] hallucination_silence_threshold: Optional[float] @@ -100,486 +82,6 @@ class TranscriptionInfo(NamedTuple): vad_options: VadOptions -# The code below is originally from HF pipeline and is used in whisper-x -# (https://github.com/m-bain/whisperX) and adapted for faster_whisper - - -class BatchedInferencePipeline: - """ - Huggingface Pipeline wrapper for WhisperModel. - Copyright (c) 2022, Max Bain - All rights reserved. - Modified by Mobius Labs GmbH - """ - - def __init__( - self, - model, - use_vad_model: bool = True, - options: Optional[NamedTuple] = None, - tokenizer=None, - chunk_length: int = 30, - vad_device: Union[int, str, "torch.device"] = "auto", - vad_onset: float = 0.500, - vad_offset: float = 0.363, - language: Optional[str] = None, - ): - self.model: WhisperModel = model - self.tokenizer = tokenizer - self.options = options - self.preset_language = language - self.use_vad_model = use_vad_model - self.vad_onset = vad_onset - self.vad_offset = vad_offset - self.vad_model_path = os.path.join(get_assets_path(), "pyannote_vad_model.bin") - if self.use_vad_model: - self.vad_device = self.get_device(vad_device) - self.vad_model = self.load_vad_model( - vad_onset=self.vad_onset, vad_offset=self.vad_offset - ) - else: - self.vad_model = None - self.chunk_length = chunk_length # VAD merging size - self.last_speech_timestamp = 0.0 - - def get_device(self, device: Union[int, str, "torch.device"]): - """ - Converts the input device into a torch.device object. - - The input can be an integer, a string, or a `torch.device` object. - - The function handles a special case where the input device is "auto". - When "auto" is specified, the device will default to the - device of the model (self.model.device). If the model's device is also "auto", - it selects "cuda" if a CUDA-capable device is available; otherwise, it selects "cpu". - """ - if isinstance(device, torch.device): - return device - elif isinstance(device, str): - if device == "auto" and self.model.device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - elif device == "auto": - device = self.model.device - return torch.device(device) - elif device < 0: - return torch.device("cpu") - else: - return torch.device(f"cuda:{device}") - - def forward(self, features, segments_metadata, **forward_params): - encoder_output, outputs = self.model.generate_segment_batched( - features, self.tokenizer, forward_params - ) - - segmented_outputs = [] - segment_sizes = [] - for segment_metadata, output in zip(segments_metadata, outputs): - duration = segment_metadata["end_time"] - segment_metadata["start_time"] - segment_size = int(duration * self.model.frames_per_second) - segment_sizes.append(segment_size) - ( - subsegments, - seek, - single_timestamp_ending, - ) = self.model._split_segments_by_timestamps( - tokenizer=self.tokenizer, - tokens=output["tokens"], - time_offset=segment_metadata["start_time"], - segment_size=segment_size, - segment_duration=duration, - seek=0, - ) - segmented_outputs.append( - [ - dict( - text=self.tokenizer.decode(subsegment["tokens"]), - avg_logprob=output["avg_logprob"], - no_speech_prob=output["no_speech_prob"], - tokens=subsegment["tokens"], - start=subsegment["start"], - end=subsegment["end"], - compression_ratio=get_compression_ratio( - self.tokenizer.decode(subsegment["tokens"]) - ), - ) - for subsegment in subsegments - ] - ) - if forward_params["word_timestamps"]: - self.last_speech_timestamp = self.model.add_word_timestamps( - segmented_outputs, - self.tokenizer, - encoder_output, - segment_sizes, - forward_params["prepend_punctuations"], - forward_params["append_punctuations"], - self.last_speech_timestamp, - ) - - return segmented_outputs - - def get_language_and_tokenizer( - self, audio, task: Optional[str] = None, language: Optional[str] = None - ): - all_language_probs = None - language_probability = 1.0 - - if self.tokenizer is None: - if not language: - ( - language, - language_probability, - all_language_probs, - ) = self.model.detect_language(audio) - task = task or "transcribe" - self.tokenizer = Tokenizer( - self.model.hf_tokenizer, - self.model.model.is_multilingual, - task=task, - language=language, - ) - else: - if task is not None: - self.tokenizer.task = self.tokenizer.tokenizer.token_to_id( - f"<|{task}|>" - ) - - if language is not None: - self.tokenizer.language = self.tokenizer.tokenizer.token_to_id( - f"<|{language}|>" - ) - self.tokenizer.language_code = language - - return language, language_probability, task, all_language_probs - - @staticmethod - def audio_split(audio, segments, sampling_rate): - """Returns splitted audio chunks as iterator""" - audio_segments = [] - segments_metadata = [] - for seg in segments: - f1 = int(seg["start"] * sampling_rate) - f2 = int(seg["end"] * sampling_rate) - seg_metadata = { - "start_time": seg["start"], - "end_time": seg["end"], - "stitched_seg": seg["segments"], - } - audio_segments.append(audio[f1:f2]) - segments_metadata.append(seg_metadata) - return audio_segments, segments_metadata - - def load_vad_model(self, vad_onset=0.500, vad_offset=0.363): - vad_model = Model.from_pretrained(self.vad_model_path) - hyperparameters = { - "onset": vad_onset, - "offset": vad_offset, - "min_duration_on": 0.1, - "min_duration_off": 0.1, - } - - vad_pipeline = VoiceActivitySegmentation( - segmentation=vad_model, device=torch.device(self.vad_device) - ) - vad_pipeline.instantiate(hyperparameters) - return vad_pipeline - - def transcribe( - self, - audio: Union[str, torch.Tensor, np.ndarray], - vad_segments: Optional[List[dict]] = None, - batch_size: int = 16, - language: Optional[str] = None, - task: str = None, - log_progress: bool = False, - beam_size: int = 5, - best_of: int = 5, - patience: float = 1, - length_penalty: float = 1, - repetition_penalty: float = 1, - no_repeat_ngram_size: int = 0, - temperature: Union[float, List[float], Tuple[float, ...]] = [ - 0.0, - 0.2, - 0.4, - 0.6, - 0.8, - 1.0, - ], - compression_ratio_threshold: Optional[float] = 2.4, - log_prob_threshold: Optional[float] = -1.0, - log_prob_low_threshold: Optional[float] = None, - no_speech_threshold: Optional[float] = 0.6, - initial_prompt: Optional[Union[str, Iterable[int]]] = None, - prefix: Optional[str] = None, - suppress_blank: bool = True, - suppress_tokens: Optional[List[int]] = [-1], - prepend_punctuations: str = "\"'“¿([{-", - append_punctuations: str = "\"'.。,,!!??::”)]}、", - max_new_tokens: Optional[int] = None, - hotwords: Optional[str] = None, - word_timestamps: bool = False, - without_timestamps: bool = True, - ) -> Tuple[Iterable[Segment], TranscriptionInfo]: - """transcribe audio in chunks in batched fashion and return with language info. - - Arguments: - audio: audio file as numpy array/path for batched transcription. - vad_segments: Optionally provide list of dictionaries each containing "start", "end", - and "segments" keys. - "start" and "end" keys specify the start and end of the voiced region within - 30 sec boundary. An additional key "segments" contains all the start - and end of voiced regions within that 30sec boundary as a list of tuples. - If no vad_segments specified, it uses internal vad model automatically segment them. - batch_size: the maximum number of parallel requests to model for decoding. - language: The language spoken in the audio. - task: either "transcribe" or "translate". - log_progress: whether to show progress bar or not. - beam_size: Beam size to use for decoding. - best_of: Number of candidates when sampling with non-zero temperature. - patience: Beam search patience factor. - length_penalty: Exponential length penalty constant. - repetition_penalty: Penalty applied to the score of previously generated tokens - (set > 1 to penalize). - no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable). - temperature: Temperature for sampling. It can be a tuple of temperatures, - which will be successively used upon failures according to either - `compression_ratio_threshold` or `log_prob_threshold`. - compression_ratio_threshold: If the gzip compression ratio is above this value, - treat as failed. - log_prob_threshold: If the average log probability over sampled tokens is - below this value, treat as failed. - log_prob_low_threshold: This parameter alone is sufficient to skip an output text, - whereas log_prob_threshold also looks for appropriate no_speech_threshold value. - This value should be less than log_prob_threshold. - no_speech_threshold: If the no_speech probability is higher than this value AND - the average log probability over sampled tokens is below `log_prob_threshold`, - consider the segment as silent. - initial_prompt: Optional text string or iterable of token ids to provide as a - prompt for the first window. - prefix: Optional text to provide as a prefix for the first window. - suppress_blank: Suppress blank outputs at the beginning of the sampling. - suppress_tokens: List of token IDs to suppress. -1 will suppress a default set - of symbols as defined in `tokenizer.non_speech_tokens()`. - prepend_punctuations: If word_timestamps is True, merge these punctuation symbols - with the next word - append_punctuations: If word_timestamps is True, merge these punctuation symbols - with the previous word - max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set, - the maximum will be set by the default max_length. - hotwords: - Hotwords/hint phrases to the model. Has no effect if prefix is not None. - word_timestamps: Extract word-level timestamps using the cross-attention pattern - and dynamic time warping, and include the timestamps for each word in each segment. - Set as False. - without_timestamps: Only sample text tokens. - - Static params: (Fixed for batched version) - max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0. - multilingual: If True, perform transcription on multilingual videos. Set as False. - output_language: Valid only if multilingual is set to True. - Specifies the string representing the output language. One of - 'en' (English) or 'hybrid' (code-switched transcription). set as None. - condition_on_previous_text: If True, the previous output of the model is provided - as a prompt for the next window; disabling may make the text inconsistent across - windows, but the model becomes less prone to getting stuck in a failure loop, - such as repetition looping or timestamps going out of sync. Set as False - prompt_reset_on_temperature: Resets prompt if temperature is above this value. - Arg has effect only if condition_on_previous_text is True. Set at 0.5 - #TODO: support "hallucination_silence_threshold" when "word_timestamps=True" - hallucination_silence_threshold: Optional[float] - When word_timestamps is True, skip silent periods longer than this threshold - (in seconds) when a possible hallucination is detected. set as None. - clip_timestamps: - Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to - process. The last end timestamp defaults to the end of the file. Set as "0". - - unused: - language_detection_threshold: If the maximum probability of the language tokens is - higher than this value, the language is detected. - language_detection_segments: Number of segments to consider for the language detection. - vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio - without speech. This step is using the Silero VAD model - https://github.com/snakers4/silero-vad. - vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available - parameters and default values in the class `VadOptions`). - chunk_length: The length of audio segments. If it is not None, it will overwrite the - default chunk_length of the FeatureExtractor. - - - Returns: - A tuple with: - - - a generator over transcribed batched segments. - - an instance of TranscriptionInfo. - """ - - sampling_rate = self.model.feature_extractor.sampling_rate - - if isinstance(audio, np.ndarray): - audio = torch.from_numpy(audio) - elif not isinstance(audio, torch.Tensor): - audio = decode_audio(audio, sampling_rate=sampling_rate) - duration = audio.shape[0] / sampling_rate - - # if no segment split is provided, use vad_model and generate segments - if not vad_segments: - # run the audio if it is less than 30 sec even without vad_segments - if self.use_vad_model: - vad_segments = self.vad_model( - { - "waveform": audio.unsqueeze(0), - "sample_rate": 16000, - } - ) - vad_segments = merge_chunks( - vad_segments, - self.chunk_length, - onset=self.vad_onset, - offset=self.vad_offset, - ) - elif duration < self.chunk_length: - vad_segments = [ - {"start": 0.0, "end": duration, "segments": [(0.0, duration)]} - ] - else: - raise RuntimeError( - "No vad segments found. Set 'use_vad_model' to True while loading the model" - ) - if self.model.model.is_multilingual: - language = language or self.preset_language - elif language != "en": - if language is not None: - self.model.logger.warning( - f"English-only model is used, but {language} language is" - "chosen, setting language to 'en'." - ) - language = "en" - - ( - language, - language_probability, - task, - all_language_probs, - ) = self.get_language_and_tokenizer(audio, task, language) - - duration_after_vad = sum( - segment["end"] - segment["start"] for segment in vad_segments - ) - - # batched options: see the difference with default options in WhisperModel - batched_options = TranscriptionOptions( - beam_size=beam_size, - best_of=best_of, - patience=patience, - length_penalty=length_penalty, - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, - log_prob_threshold=log_prob_threshold, - log_prob_low_threshold=log_prob_low_threshold, - no_speech_threshold=no_speech_threshold, - compression_ratio_threshold=compression_ratio_threshold, - temperatures=( - temperature if isinstance(temperature, (list, tuple)) else [temperature] - ), - initial_prompt=initial_prompt, - prefix=prefix, - suppress_blank=suppress_blank, - suppress_tokens=get_suppressed_tokens(self.tokenizer, suppress_tokens), - prepend_punctuations=prepend_punctuations, - append_punctuations=append_punctuations, - max_new_tokens=max_new_tokens, - hotwords=hotwords, - word_timestamps=word_timestamps, - hallucination_silence_threshold=None, - condition_on_previous_text=False, - clip_timestamps="0", - prompt_reset_on_temperature=0.5, - multilingual=False, - output_language=None, - without_timestamps=without_timestamps, - max_initial_timestamp=0.0, - ) - - info = TranscriptionInfo( - language=language, - language_probability=language_probability, - duration=duration, - duration_after_vad=duration_after_vad, - transcription_options=batched_options, - vad_options=None, - all_language_probs=all_language_probs, - ) - - audio_segments, segments_metadata = self.audio_split( - audio, vad_segments, sampling_rate - ) - to_cpu = ( - self.model.model.device == "cuda" and len(self.model.model.device_index) > 1 - ) - audio_segments = torch.nested.nested_tensor(audio_segments).to_padded_tensor( - padding=0 - ) - features = torch.stack( - [ - self.model.feature_extractor(audio_segment, to_cpu=to_cpu)[ - ..., : self.model.feature_extractor.nb_max_frames - ] - for audio_segment in audio_segments - ] - ) - - segments = self._batched_segments_generator( - features, - segments_metadata, - batch_size, - batched_options, - log_progress, - ) - - return segments, info - - def _batched_segments_generator( - self, features, segments_metadata, batch_size, options, log_progress - ): - pbar = tqdm(total=len(features), disable=not log_progress, position=0) - seg_idx = 0 - for i in range(0, len(features), batch_size): - results = self.forward( - features[i : i + batch_size], - segments_metadata[i : i + batch_size], - **options._asdict(), - ) - - for result in results: - for segment in result: - seg_idx += 1 - yield Segment( - seek=int(result[-1]["end"] * self.model.frames_per_second), - id=seg_idx, - text=segment["text"], - start=round(segment["start"], 3), - end=round(segment["end"], 3), - words=( - None - if not options.word_timestamps - else [Word(**word) for word in segment["words"]] - ), - tokens=segment["tokens"], - avg_logprob=segment["avg_logprob"], - no_speech_prob=segment["no_speech_prob"], - compression_ratio=segment["compression_ratio"], - ) - - pbar.update(1) - - pbar.close() - # revert the tokenizer if multilingual inference is enabled - if self.preset_language is None: - self.tokenizer = None - self.last_speech_timestamp = 0.0 - - class WhisperModel: def __init__( self, @@ -587,7 +89,7 @@ def __init__( device: str = "auto", device_index: Union[int, List[int]] = 0, compute_type: str = "default", - cpu_threads: int = 16, + cpu_threads: int = 0, num_workers: int = 1, download_root: Optional[str] = None, local_files_only: bool = False, @@ -639,12 +141,10 @@ def __init__( local_files_only=local_files_only, cache_dir=download_root, ) - self.device = device - # set the random seed to make sure consistency across runs - ctranslate2.set_random_seed(42) + self.model = ctranslate2.models.Whisper( model_path, - device=self.device, + device=device, device_index=device_index, compute_type=compute_type, intra_threads=cpu_threads, @@ -663,19 +163,15 @@ def __init__( "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en") ) self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes) - self.feature_extractor = FeatureExtractor( - **self.feat_kwargs, device=self.device - ) - self.input_stride = 2 - self.num_samples_per_token = ( - self.feature_extractor.hop_length * self.input_stride - ) + self.feature_extractor = FeatureExtractor(**self.feat_kwargs) + self.num_samples_per_token = self.feature_extractor.hop_length * 2 self.frames_per_second = ( self.feature_extractor.sampling_rate // self.feature_extractor.hop_length ) self.tokens_per_second = ( self.feature_extractor.sampling_rate // self.num_samples_per_token ) + self.input_stride = 2 self.time_precision = 0.02 self.max_length = 448 @@ -704,7 +200,7 @@ def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict: def transcribe( self, - audio: Union[str, BinaryIO, torch.Tensor, np.ndarray], + audio: Union[str, BinaryIO, np.ndarray], language: Optional[str] = None, task: str = "transcribe", beam_size: int = 5, @@ -723,7 +219,6 @@ def transcribe( ], compression_ratio_threshold: Optional[float] = 2.4, log_prob_threshold: Optional[float] = -1.0, - log_prob_low_threshold: Optional[float] = None, no_speech_threshold: Optional[float] = 0.6, condition_on_previous_text: bool = True, prompt_reset_on_temperature: float = 0.5, @@ -736,8 +231,6 @@ def transcribe( word_timestamps: bool = False, prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", - multilingual: bool = False, - output_language: Optional[str] = None, vad_filter: bool = False, vad_parameters: Optional[Union[dict, VadOptions]] = None, max_new_tokens: Optional[int] = None, @@ -770,9 +263,6 @@ def transcribe( treat as failed. log_prob_threshold: If the average log probability over sampled tokens is below this value, treat as failed. - log_prob_low_threshold: This parameter alone is sufficient to skip an output text, - wheras log_prob_threshold also looks for appropriate no_speech_threshold value. - This value should be less than log_prob_threshold. no_speech_threshold: If the no_speech probability is higher than this value AND the average log probability over sampled tokens is below `log_prob_threshold`, consider the segment as silent. @@ -787,7 +277,7 @@ def transcribe( prefix: Optional text to provide as a prefix for the first window. suppress_blank: Suppress blank outputs at the beginning of the sampling. suppress_tokens: List of token IDs to suppress. -1 will suppress a default set - of symbols as defined in `tokenizer.non_speech_tokens()`. + of symbols as defined in `tokenizer.non_speech_tokens()` without_timestamps: Only sample text tokens. max_initial_timestamp: The initial timestamp cannot be later than this. word_timestamps: Extract word-level timestamps using the cross-attention pattern @@ -796,12 +286,6 @@ def transcribe( with the next word append_punctuations: If word_timestamps is True, merge these punctuation symbols with the previous word - multilingual: If True, perform transcription on multilingual videos - and return the transcript based - on the 'output_language' flag. - output_language: Valid only if multilingual is set to True. - Specifies the string representing the output language. One of - 'en' (English) or 'hybrid' (code-switched transcription). vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio without speech. This step is using the Silero VAD model https://github.com/snakers4/silero-vad. @@ -829,12 +313,9 @@ def transcribe( - a generator over transcribed segments - an instance of TranscriptionInfo """ - sampling_rate = self.feature_extractor.sampling_rate - if isinstance(audio, np.ndarray): - audio = torch.from_numpy(audio) - elif not isinstance(audio, torch.Tensor): + if not isinstance(audio, np.ndarray): audio = decode_audio(audio, sampling_rate=sampling_rate) duration = audio.shape[0] / sampling_rate @@ -874,22 +355,11 @@ def transcribe( else: speech_chunks = None - to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 - features = self.feature_extractor( - audio, chunk_length=chunk_length, to_cpu=to_cpu - ) + features = self.feature_extractor(audio, chunk_length=chunk_length) encoder_output = None all_language_probs = None - # setting output_language for multilingual videos - if multilingual: - if output_language is None: - output_language = "en" - elif output_language not in ["en", "hybrid"]: - raise ValueError("Output language needs to be one of 'en'/'hybrid'.") - - # detecting the language if not provided if language is None: if not self.model.is_multilingual: language = "en" @@ -982,7 +452,6 @@ def transcribe( repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, log_prob_threshold=log_prob_threshold, - log_prob_low_threshold=log_prob_low_threshold, no_speech_threshold=no_speech_threshold, compression_ratio_threshold=compression_ratio_threshold, condition_on_previous_text=condition_on_previous_text, @@ -1003,8 +472,6 @@ def transcribe( word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, - multilingual=multilingual, - output_language=output_language, max_new_tokens=max_new_tokens, clip_timestamps=clip_timestamps, hallucination_silence_threshold=hallucination_silence_threshold, @@ -1027,88 +494,9 @@ def transcribe( ) return segments, info - def _split_segments_by_timestamps( - self, - tokenizer: Tokenizer, - tokens: List[int], - time_offset: float, - segment_size: int, - segment_duration: float, - seek: int, - ) -> List[List[int]]: - current_segments = [] - single_timestamp_ending = ( - len(tokens) >= 2 and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1] - ) - - consecutive_timestamps = [ - i - for i in range(len(tokens)) - if i > 0 - and tokens[i] >= tokenizer.timestamp_begin - and tokens[i - 1] >= tokenizer.timestamp_begin - ] - - if len(consecutive_timestamps) > 0: - slices = list(consecutive_timestamps) - if single_timestamp_ending: - slices.append(len(tokens)) - - last_slice = 0 - for current_slice in slices: - sliced_tokens = tokens[last_slice:current_slice] - start_timestamp_position = sliced_tokens[0] - tokenizer.timestamp_begin - end_timestamp_position = sliced_tokens[-1] - tokenizer.timestamp_begin - start_time = ( - time_offset + start_timestamp_position * self.time_precision - ) - end_time = time_offset + end_timestamp_position * self.time_precision - - current_segments.append( - dict( - seek=seek, - start=start_time, - end=end_time, - tokens=sliced_tokens, - ) - ) - last_slice = current_slice - - if single_timestamp_ending: - # single timestamp at the end means no speech after the last timestamp. - seek += segment_size - else: - # otherwise, ignore the unfinished segment and seek to the last timestamp - last_timestamp_position = ( - tokens[last_slice - 1] - tokenizer.timestamp_begin - ) - seek += last_timestamp_position * self.input_stride - - else: - duration = segment_duration - timestamps = [ - token for token in tokens if token >= tokenizer.timestamp_begin - ] - if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin: - last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin - duration = last_timestamp_position * self.time_precision - - current_segments.append( - dict( - seek=seek, - start=time_offset, - end=time_offset + duration, - tokens=tokens, - ) - ) - - seek += segment_size - - return current_segments, seek, single_timestamp_ending - def generate_segments( self, - features: torch.Tensor, + features: np.ndarray, tokenizer: Tokenizer, options: TranscriptionOptions, encoder_output: Optional[ctranslate2.StorageView] = None, @@ -1191,28 +579,6 @@ def generate_segments( ) previous_tokens = all_tokens[prompt_reset_since:] - - if encoder_output is None: - encoder_output = self.encode(segment) - - # Perform language detection at every segment to update task based on output language, - # if the language is english, task is transcribe, - # else the task is translate to english (default) - # or transcribe if 'output_language' is 'hybrid'. - if options.multilingual: - results = self.model.detect_language(encoder_output) - language_token, language_probability = results[0][0] - language = language_token[2:-2] - if options.output_language == "en" and language != "en": - task = "translate" - else: - task = "transcribe" - - # Update tokenizer based on task and language - tokenizer.task = tokenizer.tokenizer.token_to_id(f"<|{task}|>") - tokenizer.language = tokenizer.tokenizer.token_to_id(language_token) - tokenizer.language_code = language - # Update prompt based on task and language prompt = self.get_prompt( tokenizer, previous_tokens, @@ -1249,18 +615,6 @@ def generate_segments( options.no_speech_threshold, ) - # Skip if the logprob is very low (below the threshold value), - # despite no_speech_prob being low (ex: Too ambiguous outputs) - if options.log_prob_low_threshold: - if avg_logprob < options.log_prob_low_threshold: - should_skip = True - self.logger.debug( - "log prob low threshold is met (%f > %f)", - avg_logprob, - options.log_prob_low_threshold, - ) - - if should_skip: # fast-forward to the next segment boundary seek += segment_size continue @@ -1268,6 +622,7 @@ def generate_segments( tokens = result.sequences_ids[0] previous_seek = seek + current_segments = [] # anomalous words are very long/short/improbable def word_anomaly_score(word: dict) -> float: @@ -1293,22 +648,83 @@ def is_segment_anomaly(segment: Optional[dict]) -> bool: def next_words_segment(segments: List[dict]) -> Optional[dict]: return next((s for s in segments if s["words"]), None) - ( - current_segments, - seek, - single_timestamp_ending, - ) = self._split_segments_by_timestamps( - tokenizer=tokenizer, - tokens=tokens, - time_offset=time_offset, - segment_size=segment_size, - segment_duration=segment_duration, - seek=seek, + single_timestamp_ending = ( + len(tokens) >= 2 + and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1] ) + consecutive_timestamps = [ + i + for i in range(len(tokens)) + if i > 0 + and tokens[i] >= tokenizer.timestamp_begin + and tokens[i - 1] >= tokenizer.timestamp_begin + ] + + if len(consecutive_timestamps) > 0: + slices = list(consecutive_timestamps) + if single_timestamp_ending: + slices.append(len(tokens)) + + last_slice = 0 + for current_slice in slices: + sliced_tokens = tokens[last_slice:current_slice] + start_timestamp_position = ( + sliced_tokens[0] - tokenizer.timestamp_begin + ) + end_timestamp_position = ( + sliced_tokens[-1] - tokenizer.timestamp_begin + ) + start_time = ( + time_offset + start_timestamp_position * self.time_precision + ) + end_time = ( + time_offset + end_timestamp_position * self.time_precision + ) + + current_segments.append( + dict( + seek=seek, + start=start_time, + end=end_time, + tokens=sliced_tokens, + ) + ) + last_slice = current_slice + + if single_timestamp_ending: + # single timestamp at the end means no speech after the last timestamp. + seek += segment_size + else: + # otherwise, ignore the unfinished segment and seek to the last timestamp + last_timestamp_position = ( + tokens[last_slice - 1] - tokenizer.timestamp_begin + ) + seek += last_timestamp_position * self.input_stride + + else: + duration = segment_duration + timestamps = [ + token for token in tokens if token >= tokenizer.timestamp_begin + ] + if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin: + last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin + duration = last_timestamp_position * self.time_precision + + current_segments.append( + dict( + seek=seek, + start=time_offset, + end=time_offset + duration, + tokens=tokens, + ) + ) + + seek += segment_size + if options.word_timestamps: self.add_word_timestamps( - [current_segments], + current_segments, tokenizer, encoder_output, segment_size, @@ -1316,6 +732,7 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: options.append_punctuations, last_speech_timestamp=last_speech_timestamp, ) + if not single_timestamp_ending: last_word_end = get_end(current_segments) if last_word_end is not None and last_word_end > time_offset: @@ -1372,6 +789,7 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: last_word_end = get_end(current_segments) if last_word_end is not None: last_speech_timestamp = last_word_end + for segment in current_segments: tokens = segment["tokens"] text = tokenizer.decode(tokens) @@ -1421,13 +839,12 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: prompt_reset_since = len(all_tokens) - def encode(self, features: torch.Tensor) -> ctranslate2.StorageView: + def encode(self, features: np.ndarray) -> ctranslate2.StorageView: # When the model is running on multiple GPUs, the encoder output should be moved # to the CPU since we don't know which GPU will handle the next job. to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 - if features.ndim == 2: - features = features.unsqueeze(0) + features = np.expand_dims(features, 0) features = get_ctranslate2_storage(features) return self.model.encode(features, to_cpu=to_cpu) @@ -1606,127 +1023,115 @@ def add_word_timestamps( prepend_punctuations: str, append_punctuations: str, last_speech_timestamp: float, - ) -> float: + ) -> None: if len(segments) == 0: return - text_tokens = [] - text_tokens_per_segment = [] - for segment in segments: - segment_tokens = [ - [token for token in subsegment["tokens"] if token < tokenizer.eot] - for subsegment in segment - ] - text_tokens.append(list(itertools.chain.from_iterable(segment_tokens))) - text_tokens_per_segment.append(segment_tokens) + text_tokens_per_segment = [ + [token for token in segment["tokens"] if token < tokenizer.eot] + for segment in segments + ] - alignments = self.find_alignment( + text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment)) + alignment = self.find_alignment( tokenizer, text_tokens, encoder_output, num_frames ) - median_max_durations = [] - for alignment in alignments: - word_durations = np.array( - [word["end"] - word["start"] for word in alignment] - ) - word_durations = word_durations[word_durations.nonzero()] - median_duration = ( - np.median(word_durations) if len(word_durations) > 0 else 0.0 - ) - median_duration = min(0.7, float(median_duration)) - max_duration = median_duration * 2 + word_durations = np.array([word["end"] - word["start"] for word in alignment]) + word_durations = word_durations[word_durations.nonzero()] + median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 + median_duration = min(0.7, float(median_duration)) + max_duration = median_duration * 2 + + # hack: truncate long words at sentence boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(word_durations) > 0: + sentence_end_marks = ".。!!??" + # ensure words at sentence boundaries + # are not longer than twice the median word duration. + for i in range(1, len(alignment)): + if alignment[i]["end"] - alignment[i]["start"] > max_duration: + if alignment[i]["word"] in sentence_end_marks: + alignment[i]["end"] = alignment[i]["start"] + max_duration + elif alignment[i - 1]["word"] in sentence_end_marks: + alignment[i]["start"] = alignment[i]["end"] - max_duration + + merge_punctuations(alignment, prepend_punctuations, append_punctuations) + + time_offset = ( + segments[0]["seek"] + * self.feature_extractor.hop_length + / self.feature_extractor.sampling_rate + ) - # hack: truncate long words at sentence boundaries. - # a better segmentation algorithm based on VAD should be able to replace this. - if len(word_durations) > 0: - sentence_end_marks = ".。!!??" - # ensure words at sentence boundaries - # are not longer than twice the median word duration. - for i in range(1, len(alignment)): - if alignment[i]["end"] - alignment[i]["start"] > max_duration: - if alignment[i]["word"] in sentence_end_marks: - alignment[i]["end"] = alignment[i]["start"] + max_duration - elif alignment[i - 1]["word"] in sentence_end_marks: - alignment[i]["start"] = alignment[i]["end"] - max_duration - - merge_punctuations(alignment, prepend_punctuations, append_punctuations) - median_max_durations.append((median_duration, max_duration)) - - for segment_idx, segment in enumerate(segments): - word_index = 0 - time_offset = segment[0]["start"] - median_duration, max_duration = median_max_durations[segment_idx] - for subsegment_idx, subsegment in enumerate(segment): - saved_tokens = 0 - words = [] - - while word_index < len(alignments[segment_idx]) and saved_tokens < len( - text_tokens_per_segment[segment_idx][subsegment_idx] - ): - timing = alignments[segment_idx][word_index] - - if timing["word"]: - words.append( - dict( - word=timing["word"], - start=round(time_offset + timing["start"], 2), - end=round(time_offset + timing["end"], 2), - probability=timing["probability"], - ) - ) + word_index = 0 - saved_tokens += len(timing["tokens"]) - word_index += 1 - - # hack: truncate long words at segment boundaries. - # a better segmentation algorithm based on VAD should be able to replace this. - if len(words) > 0: - # ensure the first and second word after a pause is not longer than - # twice the median word duration. - if words[0][ - "end" - ] - last_speech_timestamp > median_duration * 4 and ( - words[0]["end"] - words[0]["start"] > max_duration - or ( - len(words) > 1 - and words[1]["end"] - words[0]["start"] > max_duration * 2 - ) - ): - if ( - len(words) > 1 - and words[1]["end"] - words[1]["start"] > max_duration - ): - boundary = max( - words[1]["end"] / 2, words[1]["end"] - max_duration - ) - words[0]["end"] = words[1]["start"] = boundary - words[0]["start"] = max(0, words[0]["end"] - max_duration) + for segment, text_tokens in zip(segments, text_tokens_per_segment): + saved_tokens = 0 + words = [] - # prefer the segment-level start timestamp if the first word is too long. - if ( - subsegment["start"] < words[0]["end"] - and subsegment["start"] - 0.5 > words[0]["start"] - ): - words[0]["start"] = max( - 0, - min(words[0]["end"] - median_duration, subsegment["start"]), + while word_index < len(alignment) and saved_tokens < len(text_tokens): + timing = alignment[word_index] + + if timing["word"]: + words.append( + dict( + word=timing["word"], + start=round(time_offset + timing["start"], 2), + end=round(time_offset + timing["end"], 2), + probability=timing["probability"], ) - else: - subsegment["start"] = words[0]["start"] + ) - # prefer the segment-level end timestamp if the last word is too long. + saved_tokens += len(timing["tokens"]) + word_index += 1 + + # hack: truncate long words at segment boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(words) > 0: + # ensure the first and second word after a pause is not longer than + # twice the median word duration. + if words[0]["end"] - last_speech_timestamp > median_duration * 4 and ( + words[0]["end"] - words[0]["start"] > max_duration + or ( + len(words) > 1 + and words[1]["end"] - words[0]["start"] > max_duration * 2 + ) + ): if ( - subsegment["end"] > words[-1]["start"] - and subsegment["end"] + 0.5 < words[-1]["end"] + len(words) > 1 + and words[1]["end"] - words[1]["start"] > max_duration ): - words[-1]["end"] = max( - words[-1]["start"] + median_duration, subsegment["end"] + boundary = max( + words[1]["end"] / 2, words[1]["end"] - max_duration ) - else: - subsegment["end"] = words[-1]["end"] + words[0]["end"] = words[1]["start"] = boundary + words[0]["start"] = max(0, words[0]["end"] - max_duration) + + # prefer the segment-level start timestamp if the first word is too long. + if ( + segment["start"] < words[0]["end"] + and segment["start"] - 0.5 > words[0]["start"] + ): + words[0]["start"] = max( + 0, min(words[0]["end"] - median_duration, segment["start"]) + ) + else: + segment["start"] = words[0]["start"] + + # prefer the segment-level end timestamp if the last word is too long. + if ( + segment["end"] > words[-1]["start"] + and segment["end"] + 0.5 < words[-1]["end"] + ): + words[-1]["end"] = max( + words[-1]["start"] + median_duration, segment["end"] + ) + else: + segment["end"] = words[-1]["end"] + + last_speech_timestamp = segment["end"] - last_speech_timestamp = subsegment["end"] - segments[segment_idx][subsegment_idx]["words"] = words - return last_speech_timestamp + segment["words"] = words def find_alignment( self, @@ -1739,332 +1144,51 @@ def find_alignment( if len(text_tokens) == 0: return [] - results = self.model.align( + result = self.model.align( encoder_output, tokenizer.sot_sequence, - text_tokens, + [text_tokens], num_frames, median_filter_width=median_filter_width, - ) - return_list = [] - for result, text_token in zip(results, text_tokens): - text_token_probs = result.text_token_probs - alignments = result.alignments - text_indices = np.array([pair[0] for pair in alignments]) - time_indices = np.array([pair[1] for pair in alignments]) - - words, word_tokens = tokenizer.split_to_word_tokens( - text_token + [tokenizer.eot] - ) - if len(word_tokens) <= 1: - # return on eot only - # >>> np.pad([], (1, 0)) - # array([0.]) - # This results in crashes when we lookup jump_times with float, like - # IndexError: arrays used as indices must be of integer (or boolean) type - return [] - word_boundaries = np.pad( - np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0) - ) - if len(word_boundaries) <= 1: - return [] - - jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype( - bool - ) - jump_times = time_indices[jumps] / self.tokens_per_second - start_times = jump_times[word_boundaries[:-1]] - end_times = jump_times[word_boundaries[1:]] - word_probabilities = [ - np.mean(text_token_probs[i:j]) - for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) - ] - - return_list.append( - [ - dict( - word=word, - tokens=tokens, - start=start, - end=end, - probability=probability, - ) - for word, tokens, start, end, probability in zip( - words, word_tokens, start_times, end_times, word_probabilities - ) - ] - ) - return return_list - - def generate_segment_batched( - self, - features: torch.Tensor, - tokenizer: Tokenizer, - options: dict, - ): - batch_size = features.shape[0] - all_tokens = [] - prompt_reset_since = 0 + )[0] - if options["initial_prompt"] is not None: - initial_prompt = " " + options["initial_prompt"].strip() - initial_prompt_tokens = tokenizer.encode(initial_prompt) - all_tokens.extend(initial_prompt_tokens) - previous_tokens = all_tokens[prompt_reset_since:] - prompt = self.get_prompt( - tokenizer, - previous_tokens, - without_timestamps=options["without_timestamps"], - prefix=options["prefix"], - ) + text_token_probs = result.text_token_probs - encoder_output = self.encode(features) + alignments = result.alignments + text_indices = np.array([pair[0] for pair in alignments]) + time_indices = np.array([pair[1] for pair in alignments]) - result = self.model.generate( - encoder_output, - [prompt] * batch_size, - beam_size=options["beam_size"], - patience=options["patience"], - length_penalty=options["length_penalty"], - max_length=self.max_length, - suppress_blank=options["suppress_blank"], - suppress_tokens=options["suppress_tokens"], - return_scores=True, - return_no_speech_prob=True, + words, word_tokens = tokenizer.split_to_word_tokens( + text_tokens + [tokenizer.eot] ) + if len(word_tokens) <= 1: + # return on eot only + # >>> np.pad([], (1, 0)) + # array([0.]) + # This results in crashes when we lookup jump_times with float, like + # IndexError: arrays used as indices must be of integer (or boolean) type + return [] + word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) + if len(word_boundaries) <= 1: + return [] - output = [] - for res in result: - output.append({}) - # return scores - seq_len = len(res.sequences_ids[0]) - cum_logprob = res.scores[0] * (seq_len ** options["length_penalty"]) - output[-1]["avg_logprob"] = cum_logprob / (seq_len + 1) - - # return no speech prob - output[-1]["no_speech_prob"] = res.no_speech_prob - output[-1]["tokens"] = res.sequences_ids[0] - - return encoder_output, output - - def detect_language(self, audio: torch.Tensor): - to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 - segment = self.feature_extractor(audio, padding=True, to_cpu=to_cpu)[ - :, : self.feature_extractor.nb_max_frames + jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) + jump_times = time_indices[jumps] / self.tokens_per_second + start_times = jump_times[word_boundaries[:-1]] + end_times = jump_times[word_boundaries[1:]] + word_probabilities = [ + np.mean(text_token_probs[i:j]) + for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) ] - encoder_output = self.encode(segment) - results = self.model.detect_language(encoder_output) - language_token, language_probability = results[0][0] - language = language_token[2:-2] - self.logger.info( - f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio..." - ) - all_language_probs = [(token[2:-2], prob) for (token, prob) in results[0]] - return language, language_probability, all_language_probs - def detect_language_multi_segment( - self, audio: Union[str, BinaryIO, torch.Tensor], params: Optional[dict] = None - ): - """ - Detect language based on N highly-confident segments of a language. - """ - # The threshold is used to decide if the audio is silence or not. - # The default is 0.02 (2.0%) i.e, if more than 2.0% of the audio is silent, - # the audio is considered as silence. - if not params: - params = { - "multilingual": False, - "speech_percentage_threshold": 0.02, - "language_detection_segments": 4, - "vad_filter": True, - "vad_min_silence_duration": 2500, - "language_threshold": 0.7, - } - - if params.get("multilingual", False): - logging.warning( - "lang_id is not supported for multilingual audios, detecting the major language." + return [ + dict( + word=word, tokens=tokens, start=start, end=end, probability=probability ) - - speech_percentage_threshold = params.get("speech_percentage_threshold", 0.02) - language_threshold = params.get("language_threshold", 0.7) - num_detection_segments = params.get("language_detection_segments", 4) - vad_filter_enabled = params.get("vad_filter", True) - vad_params = dict( - min_silence_duration_ms=params.get("vad_min_silence_duration", 2500) - ) - - if vad_filter_enabled: - vad_params = VadOptions(**vad_params) - - # decode audio if it is not decoded already - sampling_rate = self.feature_extractor.sampling_rate - if not isinstance(audio, torch.Tensor): - audio: torch.Tensor = decode_audio(audio, sampling_rate=sampling_rate) - - # calculate duration of audio as number of seconds - # audio.shape[0] is the number of samples in the audio - # sampling_rate is the number of samples per second - # if we divide the number of samples by the number of samples per second, - # we get the duration in seconds - duration = audio.shape[0] / sampling_rate - - # Check if vad is enabled, and collect voiced segments - if vad_filter_enabled: - # get chunks of audio that contain speech - speech_chunks = get_speech_timestamps(audio, vad_params) - # merge chunks of audio that contain speech into a single array - audio = collect_chunks(audio, speech_chunks) - - # calculate new duration of audio without silence - duration_vad = audio.shape[0] / sampling_rate - - logging.debug( - f"Lang ID: VAD filter removed {duration - duration_vad} sec of audio" + for word, tokens, start, end, probability in zip( + words, word_tokens, start_times, end_times, word_probabilities ) - - # if the audio after VAD is less than 2% of the original audio, consider it as silence - if duration_vad / duration < speech_percentage_threshold: - return {"language_code": None, "language_confidence": 1.0} - - # update duration to be the duration after VAD - duration = duration_vad - - # if the duration of the audio is less than 1 second, consider it as silence - if duration < 1.0: - return {"language_code": None, "language_confidence": 1.0} - - # number of feature frames in 30 seconds of audio is 3000 - nb_max_frames = self.feature_extractor.nb_max_frames - - # extract features from audio with padding (default) - to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 - features = self.feature_extractor(audio, to_cpu=to_cpu) - - # number of segments in the audio - num_segments = features.shape[-1] // nb_max_frames - # more number of segments than possible with the duration of file - if num_detection_segments > num_segments: - logging.warning( - f"Lang ID: Can not have more segments, setting {num_segments} segments." - ) - num_detection_segments = num_segments - - # create a list of indices to randomly select segments from - indices = list(range(num_detection_segments)) - - # fix seed to get deterministic results - random.seed(0) - random.shuffle(indices) - - detected_languages = [] - all_language_probabilities = defaultdict(list) - confident_language_probabilities = defaultdict(list) - num_confident_segments_per_language = defaultdict(int) - - # Iterate over the randomly selected indices of the segments. - # - # For each segment, extract features and detect language. - # - # If the language is confident, add it to the list of confident segments for that language. - # - # If the number of confident segments for a language - # is greater than or equal to the number of detection segments, - # return the language and the average probability of the language. - # - # If we are unable to get sufficient number of confident predcitions, - # return the most frequently detected language with maximum probability. - # - # We need to get sufficient number of confident predictions per language, not in total. - - for i in indices: - segment_features = features[:, i * nb_max_frames : (i + 1) * nb_max_frames] - try: - encoder_output = self.encode(segment_features) - results = self.model.detect_language(encoder_output)[0] - - except ValueError as e: # or RuntimeError - logging.error(f"Inference error:{e}") - - # results is the list of classes (languages) and their probabilities (descending), - # for eg: [('<|de|>', 0.482177734375),('<|en|>', 0.283447265625),...] - - # take top language token and probability - # and parse language token to strip out markers - # for eg: '<|de|>' -> 'de' - - language_token = results[0][0] - language = language_token[2:-2] - - language_probability = results[0][1] - - detected_languages.append(language) - all_language_probabilities[language].append(language_probability) - - # only consider if the language prediction is confident - if language_probability > language_threshold: - num_confident_segments_per_language[language] += 1 - - # Add language and probability to the list of languages when it is confident - confident_language_probabilities[language].append(language_probability) - - # return the language when sufficient number of confident segments is achieved - if ( - num_confident_segments_per_language[language] - >= num_detection_segments - ): - # Considering the average probability of only confident segments - mean = sum(confident_language_probabilities[language]) / len( - confident_language_probabilities[language] - ) - return { - "language_code": language, - "language_confidence": mean, - } - - # if we are unable to get sufficient number of confident predictions, - # return the most frequently detected language. - # if there is a tie, return the one with maximum average probability. - counter = Counter(detected_languages) - - # Define the key function to select frequent language with attached probabilities - def key_func(language): - # Calculate the frequency of the language - frequency = counter[language] - - # Calculate the average probability of the language - prob_avg = sum(all_language_probabilities[language]) / len( - all_language_probabilities[language] - ) - - return frequency, prob_avg - - if detected_languages: - # Use the key function to find the language with maximum frequency and probability - max_language = max(detected_languages, key=key_func) - max_probability = sum(all_language_probabilities[max_language]) / len( - all_language_probabilities[max_language] - ) - - # Do additional checks for silence for non-confident case - # calculate RMS amplitude and DC offset - dc_offset = audio.mean() - audio_minus_dc_offset = audio - dc_offset - is_silent = ( - torch.all(audio.abs() < 0.01) - or torch.sqrt(torch.mean(audio_minus_dc_offset**2)) < 0.01 - ) - - if is_silent: - return {"language_code": None, "language_confidence": 1.0} - - return { - "language_code": max_language, - "language_confidence": max_probability, - } - - # Language is not detected for any segment and none of prev conditions met - return {"language_code": None, "language_confidence": 1.0} + ] def restore_speech_timestamps( @@ -2102,12 +1226,9 @@ def restore_speech_timestamps( yield segment -def get_ctranslate2_storage(segment: torch.Tensor) -> ctranslate2.StorageView: - segment = segment.contiguous() - segment = ctranslate2.StorageView.from_array( - segment if segment.is_cuda else segment.numpy() - ) # torch cpu tensors don't implement __array_interface__ - # https://github.com/pytorch/pytorch/issues/51156 +def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView: + segment = np.ascontiguousarray(segment) + segment = ctranslate2.StorageView.from_array(segment) return segment @@ -2151,11 +1272,9 @@ def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> if previous["word"].startswith(" ") and previous["word"].strip() in prepended: # prepend it to the following word following["word"] = previous["word"] + following["word"] - if "tokens" in alignment[0].keys(): - following["tokens"] = previous["tokens"] + following["tokens"] - previous["tokens"] = [] + following["tokens"] = previous["tokens"] + following["tokens"] previous["word"] = "" - + previous["tokens"] = [] else: j = i i -= 1 @@ -2169,11 +1288,9 @@ def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> if not previous["word"].endswith(" ") and following["word"] in appended: # append it to the previous word previous["word"] = previous["word"] + following["word"] - if "tokens" in alignment[0].keys(): - previous["tokens"] = previous["tokens"] + following["tokens"] - following["tokens"] = [] + previous["tokens"] = previous["tokens"] + following["tokens"] following["word"] = "" - + following["tokens"] = [] else: i = j j += 1 diff --git a/faster_whisper/vad.py b/faster_whisper/vad.py index 3881fd81..99dfb401 100644 --- a/faster_whisper/vad.py +++ b/faster_whisper/vad.py @@ -2,17 +2,9 @@ import functools import os -from abc import ABC -from collections.abc import Callable -from typing import List, NamedTuple, Optional, Union +from typing import List, NamedTuple, Optional import numpy as np -import torch - -from pyannote.audio.core.io import AudioFile -from pyannote.audio.pipelines import VoiceActivityDetection -from pyannote.audio.pipelines.utils import PipelineModel -from pyannote.core import Annotation, Segment, SlidingWindowFeature from faster_whisper.utils import get_assets_path @@ -43,7 +35,7 @@ class VadOptions(NamedTuple): def get_speech_timestamps( - audio: torch.Tensor, + audio: np.ndarray, vad_options: Optional[VadOptions] = None, **kwargs, ) -> List[dict]: @@ -184,12 +176,12 @@ def get_speech_timestamps( return speeches -def collect_chunks(audio: torch.Tensor, chunks: List[dict]) -> torch.Tensor: +def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray: """Collects and concatenates audio chunks.""" if not chunks: - return torch.tensor([], dtype=torch.float32) + return np.array([], dtype=np.float32) - return torch.cat([audio[chunk["start"] : chunk["end"]] for chunk in chunks]) + return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks]) class SpeechTimestampsMap: @@ -284,313 +276,3 @@ def __call__(self, x, state, context, sr: int): context = x[..., -64:] return out, state, context - - -# BSD 2-Clause License - -# Copyright (c) 2024, Max Bain - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: - -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. - -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. - -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -# The code below is copied from whisper-x (https://github.com/m-bain/whisperX) -# and adapted for faster_whisper. -class SegmentX: - def __init__(self, start, end, speaker=None): - self.start = start - self.end = end - self.speaker = speaker - - -class VoiceActivitySegmentation(VoiceActivityDetection, ABC): - """Pipeline wrapper class for Voice Activity Segmentation based on VAD scores.""" - - def __init__( - self, - segmentation: PipelineModel = "pyannote/segmentation", - device: Optional[Union[str, torch.device]] = None, - fscore: bool = False, - use_auth_token: Optional[str] = None, - **inference_kwargs, - ): - """Initialize the pipeline with the model name and the optional device. - - Args: - dict parameters of VoiceActivityDetection class from pyannote: - segmentation (PipelineModel): Loaded model name. - device (torch.device or None): Device to perform the segmentation. - fscore (bool): Flag indicating whether to compute F-score during inference. - use_auth_token (str or None): Optional authentication token for model access. - inference_kwargs (dict): Additional arguments from VoiceActivityDetection pipeline. - """ - super().__init__( - segmentation=segmentation, - device=device, - fscore=fscore, - use_auth_token=use_auth_token, - **inference_kwargs, - ) - - def apply( - self, file: AudioFile, hook: Optional[Callable] = None - ) -> SlidingWindowFeature: - """Apply voice activity detection on the audio file. - - Args: - file (AudioFile): Processed file. - hook (callable): Hook called with signature: hook("step_name", step_artefact, file=file) - - Returns: - segmentations (SlidingWindowFeature): Voice activity segmentation. - """ - # setup hook (e.g. for debugging purposes) - hook = self.setup_hook(file, hook=hook) - - # apply segmentation model if needed - # output shape is (num_chunks, num_frames, 1) - if self.training: - if self.CACHED_SEGMENTATION in file: - segmentations = file[self.CACHED_SEGMENTATION] - else: - segmentations = self._segmentation(file) - file[self.CACHED_SEGMENTATION] = segmentations - else: - segmentations: SlidingWindowFeature = self._segmentation(file) - - return segmentations - - -class BinarizeVadScores: - """Binarize detection scores using hysteresis thresholding. - - Reference: - Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of - RNN-based Voice Activity Detection", InterSpeech 2015. - - Modified by Max Bain to include WhisperX's min-cut operation - https://arxiv.org/abs/2303.00747 - - """ - - def __init__( - self, - onset: float = 0.5, - offset: Optional[float] = None, - min_duration_on: float = 0.0, - min_duration_off: float = 0.0, - pad_onset: float = 0.0, - pad_offset: float = 0.0, - max_duration: float = float("inf"), - ): - """Initializes the parameters for Binarizing the VAD scores. - - Args: - onset (float, optional): - Onset threshold. Defaults to 0.5. - offset (float, optional): - Offset threshold. Defaults to `onset`. - min_duration_on (float, optional): - Remove active regions shorter than that many seconds. Defaults to 0s. - min_duration_off (float, optional): - Fill inactive regions shorter than that many seconds. Defaults to 0s. - pad_onset (float, optional): - Extend active regions by moving their start time by that many seconds. - Defaults to 0s. - pad_offset (float, optional): - Extend active regions by moving their end time by that many seconds. - Defaults to 0s. - max_duration (float): - The maximum length of an active segment. - """ - super().__init__() - - self.onset = onset - self.offset = offset or onset - - self.pad_onset = pad_onset - self.pad_offset = pad_offset - - self.min_duration_on = min_duration_on - self.min_duration_off = min_duration_off - - self.max_duration = max_duration - - def __get_active_regions(self, scores: SlidingWindowFeature) -> Annotation: - """Extract active regions from VAD scores. - - Args: - scores (SlidingWindowFeature): Detection scores. - - Returns: - active (Annotation): Active regions. - """ - num_frames, num_classes = scores.data.shape - frames = scores.sliding_window - timestamps = [frames[i].middle for i in range(num_frames)] - # annotation meant to store 'active' regions - active = Annotation() - for k, k_scores in enumerate(scores.data.T): - label = k if scores.labels is None else scores.labels[k] - - # initial state - start = timestamps[0] - is_active = k_scores[0] > self.onset - curr_scores = [k_scores[0]] - curr_timestamps = [start] - t = start - # optionally add `strict=False` for python 3.10 or later - for t, y in zip(timestamps[1:], k_scores[1:]): - # currently active - if is_active: - curr_duration = t - start - if curr_duration > self.max_duration: - search_after = len(curr_scores) // 2 - # divide segment - min_score_div_idx = search_after + np.argmin( - curr_scores[search_after:] - ) - min_score_t = curr_timestamps[min_score_div_idx] - region = Segment( - start - self.pad_onset, min_score_t + self.pad_offset - ) - active[region, k] = label - start = curr_timestamps[min_score_div_idx] - curr_scores = curr_scores[min_score_div_idx + 1 :] - curr_timestamps = curr_timestamps[min_score_div_idx + 1 :] - # switching from active to inactive - elif y < self.offset: - region = Segment(start - self.pad_onset, t + self.pad_offset) - active[region, k] = label - start = t - is_active = False - curr_scores = [] - curr_timestamps = [] - curr_scores.append(y) - curr_timestamps.append(t) - # currently inactive - else: - # switching from inactive to active - if y > self.onset: - start = t - is_active = True - - # if active at the end, add final region - if is_active: - region = Segment(start - self.pad_onset, t + self.pad_offset) - active[region, k] = label - - return active - - def __call__(self, scores: SlidingWindowFeature) -> Annotation: - """Binarize detection scores. - - Args: - scores (SlidingWindowFeature): Detection scores. - - Returns: - active (Annotation): Binarized scores. - """ - active = self.__get_active_regions(scores) - # because of padding, some active regions might be overlapping: merge them. - # also: fill same speaker gaps shorter than min_duration_off - if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0: - if self.max_duration < float("inf"): - raise NotImplementedError("This would break current max_duration param") - active = active.support(collar=self.min_duration_off) - - # remove tracks shorter than min_duration_on - if self.min_duration_on > 0: - for segment, track in list(active.itertracks()): - if segment.duration < self.min_duration_on: - del active[segment, track] - - return active - - -def merge_chunks( - segments, - chunk_length, - onset: float = 0.5, - offset: Optional[float] = None, - edge_padding: float = 0.1, -): - """ - Merge operation described in whisper-x paper - """ - curr_end = 0 - merged_segments = [] - seg_idxs = [] - speaker_idxs = [] - - assert chunk_length > 0 - binarize = BinarizeVadScores(max_duration=chunk_length, onset=onset, offset=offset) - segments = binarize(segments) - segments_list = [] - for speech_turn in segments.get_timeline(): - segments_list.append( - SegmentX( - max(0.0, speech_turn.start - edge_padding), - speech_turn.end + edge_padding, - "UNKNOWN", - ) - ) # 100ms edge padding to account for edge errors - - if len(segments_list) == 0: - print("No active speech found in audio") - return [] - - # Make sur the starting point is the start of the segment. - curr_start = segments_list[0].start - - for idx, seg in enumerate(segments_list): - # if any segment start timing is less than previous segment end timing, - # reset the edge padding. Similarly for end timing. - if idx > 0: - if seg.start < segments_list[idx - 1].end: - seg.start += edge_padding - if idx < len(segments_list) - 1: - if seg.end > segments_list[idx + 1].start: - seg.end -= edge_padding - - if seg.end - curr_start > chunk_length and curr_end - curr_start > 0: - merged_segments.append( - { - "start": curr_start, - "end": curr_end, - "segments": seg_idxs, - } - ) - curr_start = seg.start - seg_idxs = [] - speaker_idxs = [] - curr_end = seg.end - seg_idxs.append((seg.start, seg.end)) - speaker_idxs.append(seg.speaker) - # add final - merged_segments.append( - { - "start": curr_start, - "end": curr_end, - "segments": seg_idxs, - } - ) - return merged_segments diff --git a/requirements.txt b/requirements.txt index e0a3afba..b1497ab4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,5 @@ +av>=11.0,<13 ctranslate2>=4.0,<5 huggingface_hub>=0.13 tokenizers>=0.13,<1 -onnxruntime>=1.14,<2 -pyannote-audio>=3.1.1 -torch>=2.1.1 -torchaudio>=2.1.2 -tqdm \ No newline at end of file +onnxruntime>=1.14,<2 diff --git a/tests/conftest.py b/tests/conftest.py index 0c0f4248..1a1ee1d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,8 +11,3 @@ def data_dir(): @pytest.fixture def jfk_path(data_dir): return os.path.join(data_dir, "jfk.flac") - - -@pytest.fixture -def physcisworks_path(data_dir): - return os.path.join(data_dir, "physicsworks.wav") diff --git a/tests/data/physicsworks.wav b/tests/data/physicsworks.wav deleted file mode 100644 index 885b6c1c..00000000 Binary files a/tests/data/physicsworks.wav and /dev/null differ diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 96eb68c3..7fa27b11 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -1,6 +1,6 @@ import os -from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio +from faster_whisper import WhisperModel, decode_audio from faster_whisper.tokenizer import Tokenizer from faster_whisper.transcribe import get_suppressed_tokens @@ -39,50 +39,6 @@ def test_transcribe(jfk_path): assert segment.text == "".join(word.word for word in segment.words) assert segment.start == segment.words[0].start assert segment.end == segment.words[-1].end - batched_model = BatchedInferencePipeline(model=model, use_vad_model=False) - result, info = batched_model.transcribe(jfk_path, word_timestamps=True) - assert info.language == "en" - assert info.language_probability > 0.7 - segments = [] - for segment in result: - segments.append( - {"start": segment.start, "end": segment.end, "text": segment.text} - ) - - assert len(segments) == 1 - assert segment.text == ( - " And so my fellow Americans ask not what your country can do for you, " - "ask what you can do for your country." - ) - - -def test_batched_transcribe(physcisworks_path): - model = WhisperModel("tiny") - batched_model = BatchedInferencePipeline(model=model) - result, info = batched_model.transcribe(physcisworks_path, batch_size=16) - assert info.language == "en" - assert info.language_probability > 0.7 - segments = [] - for segment in result: - segments.append( - {"start": segment.start, "end": segment.end, "text": segment.text} - ) - # number of near 30 sec segments - assert len(segments) == 8 - - result, info = batched_model.transcribe( - physcisworks_path, - batch_size=16, - without_timestamps=False, - word_timestamps=True, - ) - segments = [] - for segment in result: - assert segment.words is not None - segments.append( - {"start": segment.start, "end": segment.end, "text": segment.text} - ) - assert len(segments) > 8 def test_prefix_with_timestamps(jfk_path): @@ -145,13 +101,6 @@ def test_stereo_diarization(data_dir): assert transcription == "The horizon seems extremely distant." -def test_multisegment_lang_id(physcisworks_path): - model = WhisperModel("tiny") - language_info = model.detect_language_multi_segment(physcisworks_path) - assert language_info["language_code"] == "en" - assert language_info["language_confidence"] > 0.8 - - def test_suppressed_tokens_minus_1(): model = WhisperModel("tiny.en")