diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 7611237a..9b32053d 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -9,6 +9,7 @@ from inspect import signature from math import ceil from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union +from dataclasses import dataclass, field import ctranslate2 import numpy as np @@ -51,76 +52,92 @@ class Segment(NamedTuple): temperature: Optional[float] = 1.0 -# Added additional parameters for multilingual videos and fixes below -class TranscriptionOptions(NamedTuple): - beam_size: int - best_of: int - patience: float - length_penalty: float - repetition_penalty: float - no_repeat_ngram_size: int - log_prob_threshold: Optional[float] - log_prob_low_threshold: Optional[float] - no_speech_threshold: Optional[float] - compression_ratio_threshold: Optional[float] - condition_on_previous_text: bool - prompt_reset_on_temperature: float - temperatures: List[float] - initial_prompt: Optional[Union[str, Iterable[int]]] - prefix: Optional[str] - suppress_blank: bool - suppress_tokens: Optional[List[int]] - without_timestamps: bool - max_initial_timestamp: float - word_timestamps: bool - prepend_punctuations: str - append_punctuations: str - multilingual: bool - output_language: Optional[str] - max_new_tokens: Optional[int] - clip_timestamps: Union[str, List[float]] - hallucination_silence_threshold: Optional[float] - hotwords: Optional[str] - - -class TranscriptionInfo(NamedTuple): - language: str - language_probability: float - duration: float - duration_after_vad: float - all_language_probs: Optional[List[Tuple[str, float]]] - transcription_options: TranscriptionOptions - vad_options: VadOptions - - -# The code below is originally from HF pipeline and is used in whisper-x -# (https://github.com/m-bain/whisperX) and adapted for faster_whisper +@dataclass +class TranscriptionConfig: + # Core parameters + beam_size: int = 5 + best_of: int = 5 + patience: float = 1.0 + length_penalty: float = 1.0 + repetition_penalty: float = 1.0 + no_repeat_ngram_size: int = 0 + + # Threshold parameters + log_prob_threshold: Optional[float] = -1.0 + log_prob_low_threshold: Optional[float] = None + no_speech_threshold: Optional[float] = 0.6 + compression_ratio_threshold: Optional[float] = 2.4 + language_threshold: float = 0.7 + speech_percentage_threshold: float = 0.02 + + # Temperature settings + prompt_reset_on_temperature: float = 0.5 + temperatures: List[float] = field( + default_factory=lambda: [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] + ) + + # Text processing + initial_prompt: Optional[Union[str, Iterable[int]]] = None + prefix: Optional[str] = None + suppress_blank: bool = True + suppress_tokens: Optional[List[int]] = field(default_factory=lambda: [-1]) + without_timestamps: bool = True + word_timestamps: bool = False + prepend_punctuations: str = "\"'"¿([{-" + append_punctuations: str = "\"'.。,,!!??::")]}、" + + # Language and processing options + multilingual: bool = False + output_language: Optional[str] = None + vad_filter: bool = False + vad_parameters: Optional[Union[dict, VadOptions]] = None + language_detection_segments: int = 4 + vad_min_silence_duration: int = 2500 + + # Size and timing parameters + max_new_tokens: Optional[int] = None + max_initial_timestamp: float = 0.0 + chunk_length: Optional[int] = None + clip_timestamps: Union[str, List[float]] = "0" + + # Advanced options + condition_on_previous_text: bool = False + hallucination_silence_threshold: Optional[float] = None + hotwords: 
Optional[str] = None + + def __post_init__(self): + if self.beam_size < 1: + raise ValueError("beam_size must be at least 1") + if not isinstance(self.temperatures, (list, tuple)): + raise TypeError("temperatures must be a list or tuple of floats") + if any(not isinstance(t, (int, float)) for t in self.temperatures): + raise TypeError("all temperatures must be numeric") class BatchedInferencePipeline: - """ - Huggingface Pipeline wrapper for WhisperModel. - Copyright (c) 2022, Max Bain - All rights reserved. - Modified by Mobius Labs GmbH - """ def __init__( self, model, - options: Optional[NamedTuple] = None, + config: Optional[TranscriptionConfig] = None, tokenizer=None, language: Optional[str] = None, ): self.model: WhisperModel = model self.tokenizer = tokenizer - self.options = options + self.config = config or TranscriptionConfig() self.preset_language = language self.last_speech_timestamp = 0.0 - def forward(self, features, chunks_metadata, **forward_params): + def forward( + self, + features: torch.Tensor, + chunks_metadata: List[dict], + ) -> Tuple[ctranslate2.StorageView, List[dict]]: encoder_output, outputs = self.model.generate_segment_batched( - features, self.tokenizer, forward_params + features, + self.tokenizer, + self.config ) segmented_outputs = [] @@ -129,46 +146,45 @@ def forward(self, features, chunks_metadata, **forward_params): duration = chunk_metadata["end_time"] - chunk_metadata["start_time"] segment_size = int(ceil(duration) * self.model.frames_per_second) segment_sizes.append(segment_size) - ( - subsegments, - seek, - single_timestamp_ending, - ) = self.model._split_segments_by_timestamps( - tokenizer=self.tokenizer, - tokens=output["tokens"], - time_offset=chunk_metadata["start_time"], - segment_size=segment_size, - segment_duration=duration, - seek=0, - ) - segmented_outputs.append( - [ - dict( - text=self.tokenizer.decode(subsegment["tokens"]), - avg_logprob=output["avg_logprob"], - no_speech_prob=output["no_speech_prob"], - tokens=subsegment["tokens"], - start=subsegment["start"], - end=subsegment["end"], - compression_ratio=get_compression_ratio( - self.tokenizer.decode(subsegment["tokens"]) - ), - ) - for subsegment in subsegments - ] + + subsegments, seek, single_timestamp_ending = ( + self.model._split_segments_by_timestamps( + tokenizer=self.tokenizer, + tokens=output["tokens"], + time_offset=chunk_metadata["start_time"], + segment_size=segment_size, + segment_duration=duration, + seek=0, + ) ) - if forward_params["word_timestamps"]: + + segmented_outputs.append([ + dict( + text=self.tokenizer.decode(subsegment["tokens"]), + avg_logprob=output["avg_logprob"], + no_speech_prob=output["no_speech_prob"], + tokens=subsegment["tokens"], + start=subsegment["start"], + end=subsegment["end"], + compression_ratio=get_compression_ratio( + self.tokenizer.decode(subsegment["tokens"]) + ), + ) + for subsegment in subsegments + ]) + + if self.config.word_timestamps: self.last_speech_timestamp = self.model.add_word_timestamps( segmented_outputs, self.tokenizer, encoder_output, segment_sizes, - forward_params["prepend_punctuations"], - forward_params["append_punctuations"], + self.config.prepend_punctuations, + self.config.append_punctuations, self.last_speech_timestamp, ) - return segmented_outputs + return encoder_output, segmented_outputs def get_language_and_tokenizer( self, audio, task: Optional[str] = None, language: Optional[str] = None @@ -207,132 +223,10 @@ def get_language_and_tokenizer( def transcribe( self, audio: Union[str, BinaryIO, torch.Tensor, 
np.ndarray], - language: Optional[str] = None, - task: str = None, log_progress: bool = False, - beam_size: int = 5, - best_of: int = 5, - patience: float = 1, - length_penalty: float = 1, - repetition_penalty: float = 1, - no_repeat_ngram_size: int = 0, - temperature: Union[float, List[float], Tuple[float, ...]] = [ - 0.0, - 0.2, - 0.4, - 0.6, - 0.8, - 1.0, - ], - compression_ratio_threshold: Optional[float] = 2.4, - log_prob_threshold: Optional[float] = -1.0, - log_prob_low_threshold: Optional[float] = None, - no_speech_threshold: Optional[float] = 0.6, - initial_prompt: Optional[Union[str, Iterable[int]]] = None, - prefix: Optional[str] = None, - suppress_blank: bool = True, - suppress_tokens: Optional[List[int]] = [-1], - without_timestamps: bool = True, - word_timestamps: bool = False, - prepend_punctuations: str = "\"'“¿([{-", - append_punctuations: str = "\"'.。,,!!??::”)]}、", - vad_filter: bool = True, - vad_parameters: Optional[Union[dict, VadOptions]] = None, - max_new_tokens: Optional[int] = None, - chunk_length: Optional[int] = None, - clip_timestamps: Optional[List[dict]] = None, batch_size: int = 16, - hotwords: Optional[str] = None, - ) -> Tuple[Iterable[Segment], TranscriptionInfo]: - """transcribe audio in chunks in batched fashion and return with language info. - - Arguments: - audio: Path to the input file (or a file-like object), or the audio waveform. - language: The language spoken in the audio. It should be a language code such - as "en" or "fr". If not set, the language will be detected in the first 30 seconds - of audio. - task: Task to execute (transcribe or translate). - log_progress: whether to show progress bar or not. - beam_size: Beam size to use for decoding. - best_of: Number of candidates when sampling with non-zero temperature. - patience: Beam search patience factor. - length_penalty: Exponential length penalty constant. - repetition_penalty: Penalty applied to the score of previously generated tokens - (set > 1 to penalize). - no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable). - temperature: Temperature for sampling. It can be a tuple of temperatures, - which will be successively used upon failures according to either - `compression_ratio_threshold` or `log_prob_threshold`. - compression_ratio_threshold: If the gzip compression ratio is above this value, - treat as failed. - log_prob_threshold: If the average log probability over sampled tokens is - below this value, treat as failed. - log_prob_low_threshold: This parameter alone is sufficient to skip an output text, - whereas log_prob_threshold also looks for appropriate no_speech_threshold value. - This value should be less than log_prob_threshold. - no_speech_threshold: If the no_speech probability is higher than this value AND - the average log probability over sampled tokens is below `log_prob_threshold`, - consider the segment as silent. - initial_prompt: Optional text string or iterable of token ids to provide as a - prompt for the first window. - prefix: Optional text to provide as a prefix for the first window. - suppress_blank: Suppress blank outputs at the beginning of the sampling. - suppress_tokens: List of token IDs to suppress. -1 will suppress a default set - of symbols as defined in `tokenizer.non_speech_tokens()`. - without_timestamps: Only sample text tokens. - word_timestamps: Extract word-level timestamps using the cross-attention pattern - and dynamic time warping, and include the timestamps for each word in each segment. - Set as False. 
- prepend_punctuations: If word_timestamps is True, merge these punctuation symbols - with the next word - append_punctuations: If word_timestamps is True, merge these punctuation symbols - with the previous word - vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio - without speech. This step is using the Silero VAD model - https://github.com/snakers4/silero-vad. - vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available - parameters and default values in the class `VadOptions`). - max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set, - the maximum will be set by the default max_length. - chunk_length: The length of audio segments. If it is not None, it will overwrite the - default chunk_length of the FeatureExtractor. - clip_timestamps: Optionally provide list of dictionaries each containing "start" and - "end" keys that specify the start and end of the voiced region within - `chunk_length` boundary. vad_filter will be ignored if clip_timestamps is used. - batch_size: the maximum number of parallel requests to model for decoding. - hotwords: - Hotwords/hint phrases to the model. Has no effect if prefix is not None. - - Static params: (Fixed for batched version) - max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0. - multilingual: If True, perform transcription on multilingual videos. Set as False. - output_language: Valid only if multilingual is set to True. - Specifies the string representing the output language. One of - 'en' (English) or 'hybrid' (code-switched transcription). set as None. - condition_on_previous_text: If True, the previous output of the model is provided - as a prompt for the next window; disabling may make the text inconsistent across - windows, but the model becomes less prone to getting stuck in a failure loop, - such as repetition looping or timestamps going out of sync. Set as False - prompt_reset_on_temperature: Resets prompt if temperature is above this value. - Arg has effect only if condition_on_previous_text is True. Set at 0.5 - #TODO: support "hallucination_silence_threshold" when "word_timestamps=True" - hallucination_silence_threshold: Optional[float] - When word_timestamps is True, skip silent periods longer than this threshold - (in seconds) when a possible hallucination is detected. set as None. - - unused: - language_detection_threshold: If the maximum probability of the language tokens is - higher than this value, the language is detected. - language_detection_segments: Number of segments to consider for the language detection. 
- - - Returns: - A tuple with: - - - a generator over transcribed segments - - an instance of TranscriptionInfo - """ - + ) -> Tuple[Iterable[Segment], 'TranscriptionInfo']: + options = self.config sampling_rate = self.model.feature_extractor.sampling_rate if isinstance(audio, np.ndarray): @@ -341,26 +235,26 @@ def transcribe( audio = decode_audio(audio, sampling_rate=sampling_rate) duration = audio.shape[0] / sampling_rate - chunk_length = chunk_length or self.model.feature_extractor.chunk_length - # if no segment split is provided, use vad_model and generate segments + chunk_length = options.chunk_length or self.model.feature_extractor.chunk_length + clip_timestamps = options.clip_timestamps if not clip_timestamps: - if vad_filter: - if vad_parameters is None: + vad_parameters = None + if options.vad_filter: + if options.vad_parameters is None: vad_parameters = VadOptions( max_speech_duration_s=chunk_length, min_silence_duration_ms=160, ) - elif isinstance(vad_parameters, dict): - if "max_speech_duration_s" in vad_parameters.keys(): - vad_parameters.pop("max_speech_duration_s") + elif isinstance(options.vad_parameters, dict): + if "max_speech_duration_s" in options.vad_parameters.keys(): + options.vad_parameters.pop("max_speech_duration_s") vad_parameters = VadOptions( - **vad_parameters, max_speech_duration_s=chunk_length + **options.vad_parameters, max_speech_duration_s=chunk_length ) active_segments = get_speech_timestamps(audio, vad_parameters) clip_timestamps = merge_segments(active_segments, vad_parameters) - # run the audio if it is less than 30 sec even without clip_timestamps elif duration < chunk_length: clip_timestamps = [{"start": 0, "end": audio.shape[0]}] else: @@ -369,68 +263,36 @@ def transcribe( "Set 'vad_filter' to True or provide 'clip_timestamps'." ) if self.model.model.is_multilingual: - language = language or self.preset_language - elif language != "en": - if language is not None: + language = options.output_language or self.preset_language + elif options.output_language != "en": + if options.output_language is not None: self.model.logger.warning( - f"English-only model is used, but {language} language is" + f"English-only model is used, but {options.output_language} language is" " chosen, setting language to 'en'." 
) language = "en" + else: + language = "en" ( language, language_probability, task, all_language_probs, - ) = self.get_language_and_tokenizer(audio, task, language) + ) = self.get_language_and_tokenizer(audio, None, language) duration_after_vad = ( sum((segment["end"] - segment["start"]) for segment in clip_timestamps) / sampling_rate ) - # batched options: see the difference with default options in WhisperModel - batched_options = TranscriptionOptions( - beam_size=beam_size, - best_of=best_of, - patience=patience, - length_penalty=length_penalty, - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, - log_prob_threshold=log_prob_threshold, - log_prob_low_threshold=log_prob_low_threshold, - no_speech_threshold=no_speech_threshold, - compression_ratio_threshold=compression_ratio_threshold, - temperatures=( - temperature if isinstance(temperature, (list, tuple)) else [temperature] - ), - initial_prompt=initial_prompt, - prefix=prefix, - suppress_blank=suppress_blank, - suppress_tokens=get_suppressed_tokens(self.tokenizer, suppress_tokens), - prepend_punctuations=prepend_punctuations, - append_punctuations=append_punctuations, - max_new_tokens=max_new_tokens, - hotwords=hotwords, - word_timestamps=word_timestamps, - hallucination_silence_threshold=None, - condition_on_previous_text=False, - clip_timestamps="0", - prompt_reset_on_temperature=0.5, - multilingual=False, - output_language=None, - without_timestamps=without_timestamps, - max_initial_timestamp=0.0, - ) - info = TranscriptionInfo( language=language, language_probability=language_probability, duration=duration, duration_after_vad=duration_after_vad, - transcription_options=batched_options, - vad_options=None, + transcription_options=options, + vad_options=options.vad_parameters, all_language_probs=all_language_probs, ) @@ -455,29 +317,37 @@ def transcribe( features, chunks_metadata, batch_size, - batched_options, + options, log_progress, ) return segments, info def _batched_segments_generator( - self, features, chunks_metadata, batch_size, options, log_progress - ): + self, + features: torch.Tensor, + chunks_metadata: List[dict], + batch_size: int, + options: TranscriptionConfig, + log_progress: bool + ) -> Iterator[Segment]: + if len(features) == 0: + return + pbar = tqdm(total=len(features), disable=not log_progress, position=0) seg_idx = 0 + for i in range(0, len(features), batch_size): - results = self.forward( + encoder_output, segmented_outputs = self.forward( features[i : i + batch_size], chunks_metadata[i : i + batch_size], - **options._asdict(), ) - for result in results: - for segment in result: + for batch_segments in segmented_outputs: + for segment in batch_segments: seg_idx += 1 yield Segment( - seek=int(result[-1]["end"] * self.model.frames_per_second), + seek=int(batch_segments[-1]["end"] * self.model.frames_per_second), id=seg_idx, text=segment["text"], start=round(segment["start"], 3), @@ -485,23 +355,31 @@ def _batched_segments_generator( words=( None if not options.word_timestamps - else [Word(**word) for word in segment["words"]] + else [Word(**word) for word in segment.get("words", [])] ), tokens=segment["tokens"], avg_logprob=segment["avg_logprob"], no_speech_prob=segment["no_speech_prob"], compression_ratio=segment["compression_ratio"], ) - pbar.update(1) pbar.close() - # revert the tokenizer if multilingual inference is enabled if self.preset_language is None: self.tokenizer = None self.last_speech_timestamp = 0.0 +class TranscriptionInfo(NamedTuple): + language: str + 
language_probability: float + duration: float + duration_after_vad: float + all_language_probs: Optional[List[Tuple[str, float]]] + transcription_options: TranscriptionConfig + vad_options: Optional[VadOptions] + + class WhisperModel: def __init__( self, @@ -514,39 +392,11 @@ def __init__( download_root: Optional[str] = None, local_files_only: bool = False, files: dict = None, + config: Optional[TranscriptionConfig] = None, **model_kwargs, ): - """Initializes the Whisper model. - - Args: - model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en, - small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1, - large-v2, large-v3, large, distil-large-v2 or distil-large-v3), a path to a - converted model directory, or a CTranslate2-converted Whisper model ID from the HF Hub. - When a size or a model ID is configured, the converted model is downloaded - from the Hugging Face Hub. - device: Device to use for computation ("cpu", "cuda", "auto"). - device_index: Device ID to use. - The model can also be loaded on multiple GPUs by passing a list of IDs - (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel - when transcribe() is called from multiple Python threads (see also num_workers). - compute_type: Type to use for computation. - See https://opennmt.net/CTranslate2/quantization.html. - cpu_threads: Number of threads to use when running on CPU (4 by default). - A non zero value overrides the OMP_NUM_THREADS environment variable. - num_workers: When transcribe() is called from multiple Python threads, - having multiple workers enables true parallelism when running the model - (concurrent calls to self.model.generate() will run in parallel). - This can improve the global throughput at the cost of increased memory usage. - download_root: Directory where the models should be saved. If not set, the models - are saved in the standard Hugging Face cache directory. - local_files_only: If True, avoid downloading the file and return the path to the - local cached file if it exists. - files: Load model files from the memory. This argument is a dictionary mapping file names - to file contents as file-like or bytes objects. If this is set, model_path acts as an - identifier for this model. 
- """ self.logger = get_logger() + self.config = config or TranscriptionConfig() tokenizer_bytes, preprocessor_bytes = None, None if files: @@ -562,7 +412,6 @@ def __init__( cache_dir=download_root, ) self.device = device - # set the random seed to make sure consistency across runs ctranslate2.set_random_seed(42) self.model = ctranslate2.models.Whisper( model_path, @@ -603,7 +452,6 @@ def __init__( @property def supported_languages(self) -> List[str]: - """The languages supported by the model.""" return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"] def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict: @@ -627,131 +475,9 @@ def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict: def transcribe( self, audio: Union[str, BinaryIO, torch.Tensor, np.ndarray], - language: Optional[str] = None, - task: str = "transcribe", - beam_size: int = 5, - best_of: int = 5, - patience: float = 1, - length_penalty: float = 1, - repetition_penalty: float = 1, - no_repeat_ngram_size: int = 0, - temperature: Union[float, List[float], Tuple[float, ...]] = [ - 0.0, - 0.2, - 0.4, - 0.6, - 0.8, - 1.0, - ], - compression_ratio_threshold: Optional[float] = 2.4, - log_prob_threshold: Optional[float] = -1.0, - log_prob_low_threshold: Optional[float] = None, - no_speech_threshold: Optional[float] = 0.6, - condition_on_previous_text: bool = True, - prompt_reset_on_temperature: float = 0.5, - initial_prompt: Optional[Union[str, Iterable[int]]] = None, - prefix: Optional[str] = None, - suppress_blank: bool = True, - suppress_tokens: Optional[List[int]] = [-1], - without_timestamps: bool = False, - max_initial_timestamp: float = 1.0, - word_timestamps: bool = False, - prepend_punctuations: str = "\"'“¿([{-", - append_punctuations: str = "\"'.。,,!!??::”)]}、", - multilingual: bool = False, - output_language: Optional[str] = None, - vad_filter: bool = False, - vad_parameters: Optional[Union[dict, VadOptions]] = None, - max_new_tokens: Optional[int] = None, - chunk_length: Optional[int] = None, - clip_timestamps: Union[str, List[float]] = "0", - hallucination_silence_threshold: Optional[float] = None, - hotwords: Optional[str] = None, - language_detection_threshold: Optional[float] = None, - language_detection_segments: int = 1, + config: Optional[TranscriptionConfig] = None, ) -> Tuple[Iterable[Segment], TranscriptionInfo]: - """Transcribes an input file. - - Arguments: - audio: Path to the input file (or a file-like object), or the audio waveform. - language: The language spoken in the audio. It should be a language code such - as "en" or "fr". If not set, the language will be detected in the first 30 seconds - of audio. - task: Task to execute (transcribe or translate). - beam_size: Beam size to use for decoding. - best_of: Number of candidates when sampling with non-zero temperature. - patience: Beam search patience factor. - length_penalty: Exponential length penalty constant. - repetition_penalty: Penalty applied to the score of previously generated tokens - (set > 1 to penalize). - no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable). - temperature: Temperature for sampling. It can be a tuple of temperatures, - which will be successively used upon failures according to either - `compression_ratio_threshold` or `log_prob_threshold`. - compression_ratio_threshold: If the gzip compression ratio is above this value, - treat as failed. 
- log_prob_threshold: If the average log probability over sampled tokens is - below this value, treat as failed. - log_prob_low_threshold: This parameter alone is sufficient to skip an output text, - wheras log_prob_threshold also looks for appropriate no_speech_threshold value. - This value should be less than log_prob_threshold. - no_speech_threshold: If the no_speech probability is higher than this value AND - the average log probability over sampled tokens is below `log_prob_threshold`, - consider the segment as silent. - condition_on_previous_text: If True, the previous output of the model is provided - as a prompt for the next window; disabling may make the text inconsistent across - windows, but the model becomes less prone to getting stuck in a failure loop, - such as repetition looping or timestamps going out of sync. - prompt_reset_on_temperature: Resets prompt if temperature is above this value. - Arg has effect only if condition_on_previous_text is True. - initial_prompt: Optional text string or iterable of token ids to provide as a - prompt for the first window. - prefix: Optional text to provide as a prefix for the first window. - suppress_blank: Suppress blank outputs at the beginning of the sampling. - suppress_tokens: List of token IDs to suppress. -1 will suppress a default set - of symbols as defined in `tokenizer.non_speech_tokens()`. - without_timestamps: Only sample text tokens. - max_initial_timestamp: The initial timestamp cannot be later than this. - word_timestamps: Extract word-level timestamps using the cross-attention pattern - and dynamic time warping, and include the timestamps for each word in each segment. - prepend_punctuations: If word_timestamps is True, merge these punctuation symbols - with the next word - append_punctuations: If word_timestamps is True, merge these punctuation symbols - with the previous word - multilingual: If True, perform transcription on multilingual videos - and return the transcript based - on the 'output_language' flag. - output_language: Valid only if multilingual is set to True. - Specifies the string representing the output language. One of - 'en' (English) or 'hybrid' (code-switched transcription). - vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio - without speech. This step is using the Silero VAD model - https://github.com/snakers4/silero-vad. - vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available - parameters and default values in the class `VadOptions`). - max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set, - the maximum will be set by the default max_length. - chunk_length: The length of audio segments. If it is not None, it will overwrite the - default chunk_length of the FeatureExtractor. - clip_timestamps: - Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to - process. The last end timestamp defaults to the end of the file. - vad_filter will be ignored if clip_timestamps is used. - hallucination_silence_threshold: - When word_timestamps is True, skip silent periods longer than this threshold - (in seconds) when a possible hallucination is detected - hotwords: - Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None. - language_detection_threshold: If the maximum probability of the language tokens is higher - than this value, the language is detected. - language_detection_segments: Number of segments to consider for the language detection. 
- Returns: - A tuple with: - - - a generator over transcribed segments - - an instance of TranscriptionInfo - """ - + config = config or self.config sampling_rate = self.feature_extractor.sampling_rate if isinstance(audio, np.ndarray): @@ -766,11 +492,12 @@ def transcribe( "Processing audio with duration %s", format_timestamp(duration) ) - if vad_filter and clip_timestamps == "0": - if vad_parameters is None: + if config.vad_filter and config.clip_timestamps == "0": + vad_parameters = None + if config.vad_parameters is None: vad_parameters = VadOptions() - elif isinstance(vad_parameters, dict): - vad_parameters = VadOptions(**vad_parameters) + elif isinstance(config.vad_parameters, dict): + vad_parameters = VadOptions(**config.vad_parameters) speech_chunks = get_speech_timestamps(audio, vad_parameters) audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks) audio = torch.cat(audio_chunks, dim=0) @@ -799,81 +526,23 @@ def transcribe( to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 features = self.feature_extractor( - audio, chunk_length=chunk_length, to_cpu=to_cpu + audio, chunk_length=config.chunk_length, to_cpu=to_cpu ) encoder_output = None all_language_probs = None - # setting output_language for multilingual videos - if multilingual: - if output_language is None: - output_language = "en" - elif output_language not in ["en", "hybrid"]: - raise ValueError("Output language needs to be one of 'en'/'hybrid'.") + language = config.output_language + language_probability = 1 - # detecting the language if not provided if language is None: if not self.model.is_multilingual: language = "en" language_probability = 1 else: - if ( - language_detection_segments is None - or language_detection_segments < 1 - ): - language_detection_segments = 1 - start_timestamp = ( - float(clip_timestamps.split(",")[0]) - if isinstance(clip_timestamps, str) - else clip_timestamps[0] - ) - content_frames = ( - features.shape[-1] - self.feature_extractor.nb_max_frames - ) - seek = ( - int(start_timestamp * self.frames_per_second) - if start_timestamp * self.frames_per_second < content_frames - else 0 + language, language_probability, all_language_probs = self.detect_language( + audio ) - end_frames = min( - seek - + self.feature_extractor.nb_max_frames - * language_detection_segments, - content_frames, - ) - detected_language_info = {} - while seek <= end_frames: - segment = features[ - :, seek : seek + self.feature_extractor.nb_max_frames - ] - encoder_output = self.encode(segment) - # results is a list of tuple[str, float] with language names and - # probabilities. - results = self.model.detect_language(encoder_output)[0] - # Parse language names to strip out markers - all_language_probs = [ - (token[2:-2], prob) for (token, prob) in results - ] - # Get top language token and probability - language, language_probability = all_language_probs[0] - if ( - language_detection_threshold is None - or language_probability > language_detection_threshold - ): - break - detected_language_info.setdefault(language, []).append( - language_probability - ) - seek += segment.shape[-1] - else: - # If no language detected for all segments, the majority vote of the highest - # projected languages for all segments is used to determine the language. 
- language = max( - detected_language_info, - key=lambda lang: len(detected_language_info[lang]), - ) - language_probability = max(detected_language_info[language]) self.logger.info( "Detected language '%s' with probability %.2f", @@ -888,53 +557,14 @@ def transcribe( ) language = "en" - language_probability = 1 - tokenizer = Tokenizer( self.hf_tokenizer, self.model.is_multilingual, - task=task, + task="transcribe", language=language, ) - options = TranscriptionOptions( - beam_size=beam_size, - best_of=best_of, - patience=patience, - length_penalty=length_penalty, - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, - log_prob_threshold=log_prob_threshold, - log_prob_low_threshold=log_prob_low_threshold, - no_speech_threshold=no_speech_threshold, - compression_ratio_threshold=compression_ratio_threshold, - condition_on_previous_text=condition_on_previous_text, - prompt_reset_on_temperature=prompt_reset_on_temperature, - temperatures=( - temperature if isinstance(temperature, (list, tuple)) else [temperature] - ), - initial_prompt=initial_prompt, - prefix=prefix, - suppress_blank=suppress_blank, - suppress_tokens=( - get_suppressed_tokens(tokenizer, suppress_tokens) - if suppress_tokens - else suppress_tokens - ), - without_timestamps=without_timestamps, - max_initial_timestamp=max_initial_timestamp, - word_timestamps=word_timestamps, - prepend_punctuations=prepend_punctuations, - append_punctuations=append_punctuations, - multilingual=multilingual, - output_language=output_language, - max_new_tokens=max_new_tokens, - clip_timestamps=clip_timestamps, - hallucination_silence_threshold=hallucination_silence_threshold, - hotwords=hotwords, - ) - - segments = self.generate_segments(features, tokenizer, options, encoder_output) + segments = self.generate_segments(features, tokenizer, config, encoder_output) if speech_chunks: segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate) @@ -944,114 +574,39 @@ def transcribe( language_probability=language_probability, duration=duration, duration_after_vad=duration_after_vad, - transcription_options=options, - vad_options=vad_parameters, + transcription_options=config, + vad_options=config.vad_parameters, all_language_probs=all_language_probs, ) return segments, info - def _split_segments_by_timestamps( - self, - tokenizer: Tokenizer, - tokens: List[int], - time_offset: float, - segment_size: int, - segment_duration: float, - seek: int, - ) -> List[List[int]]: - current_segments = [] - single_timestamp_ending = ( - len(tokens) >= 2 and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1] - ) - - consecutive_timestamps = [ - i - for i in range(len(tokens)) - if i > 0 - and tokens[i] >= tokenizer.timestamp_begin - and tokens[i - 1] >= tokenizer.timestamp_begin - ] - - if len(consecutive_timestamps) > 0: - slices = list(consecutive_timestamps) - if single_timestamp_ending: - slices.append(len(tokens)) - - last_slice = 0 - for current_slice in slices: - sliced_tokens = tokens[last_slice:current_slice] - start_timestamp_position = sliced_tokens[0] - tokenizer.timestamp_begin - end_timestamp_position = sliced_tokens[-1] - tokenizer.timestamp_begin - start_time = ( - time_offset + start_timestamp_position * self.time_precision - ) - end_time = time_offset + end_timestamp_position * self.time_precision - - current_segments.append( - dict( - seek=seek, - start=start_time, - end=end_time, - tokens=sliced_tokens, - ) - ) - last_slice = current_slice - - if single_timestamp_ending: - # single timestamp at the end 
means no speech after the last timestamp. - seek += segment_size - else: - # otherwise, ignore the unfinished segment and seek to the last timestamp - last_timestamp_position = ( - tokens[last_slice - 1] - tokenizer.timestamp_begin - ) - seek += last_timestamp_position * self.input_stride - - else: - duration = segment_duration - timestamps = [ - token for token in tokens if token >= tokenizer.timestamp_begin - ] - if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin: - last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin - duration = last_timestamp_position * self.time_precision - - current_segments.append( - dict( - seek=seek, - start=time_offset, - end=time_offset + duration, - tokens=tokens, - ) - ) + def encode(self, features: torch.Tensor) -> ctranslate2.StorageView: + to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 - seek += segment_size + if features.ndim == 2: + features = features.unsqueeze(0) + features = get_ctranslate2_storage(features) - return current_segments, seek, single_timestamp_ending + return self.model.encode(features, to_cpu=to_cpu) def generate_segments( self, features: torch.Tensor, tokenizer: Tokenizer, - options: TranscriptionOptions, + config: TranscriptionConfig, encoder_output: Optional[ctranslate2.StorageView] = None, ) -> Iterable[Segment]: content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames - content_duration = float(content_frames * self.feature_extractor.time_per_frame) - - if isinstance(options.clip_timestamps, str): - options = options._replace( - clip_timestamps=[ - float(ts) - for ts in ( - options.clip_timestamps.split(",") - if options.clip_timestamps - else [] - ) - ] - ) + content_duration = float( + content_frames * self.feature_extractor.time_per_frame + ) + + if isinstance(config.clip_timestamps, str): + config.clip_timestamps = [ + float(ts) for ts in config.clip_timestamps.split(",") if ts + ] seek_points: List[int] = [ - round(ts * self.frames_per_second) for ts in options.clip_timestamps + round(ts * self.frames_per_second) for ts in config.clip_timestamps ] if len(seek_points) == 0: seek_points.append(0) @@ -1061,27 +616,21 @@ def generate_segments( zip(seek_points[::2], seek_points[1::2]) ) - punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、" - idx = 0 clip_idx = 0 seek = seek_clips[clip_idx][0] all_tokens = [] prompt_reset_since = 0 - if options.initial_prompt is not None: - if isinstance(options.initial_prompt, str): - initial_prompt = " " + options.initial_prompt.strip() + if config.initial_prompt is not None: + if isinstance(config.initial_prompt, str): + initial_prompt = " " + config.initial_prompt.strip() initial_prompt_tokens = tokenizer.encode(initial_prompt) all_tokens.extend(initial_prompt_tokens) else: - all_tokens.extend(options.initial_prompt) + all_tokens.extend(config.initial_prompt) last_speech_timestamp = 0.0 - # NOTE: This loop is obscurely flattened to make the diff readable. - # A later commit should turn this into a simpler nested loop. 
- # for seek_clip_start, seek_clip_end in seek_clips: - # while seek < seek_clip_end while clip_idx < len(seek_clips): seek_clip_start, seek_clip_end = seek_clips[clip_idx] if seek_clip_end > content_frames: @@ -1117,30 +666,12 @@ def generate_segments( if encoder_output is None: encoder_output = self.encode(segment) - # Perform language detection at every segment to update task based on output language, - # if the language is english, task is transcribe, - # else the task is translate to english (default) - # or transcribe if 'output_language' is 'hybrid'. - if options.multilingual: - results = self.model.detect_language(encoder_output) - language_token, language_probability = results[0][0] - language = language_token[2:-2] - if options.output_language == "en" and language != "en": - task = "translate" - else: - task = "transcribe" - - # Update tokenizer based on task and language - tokenizer.task = tokenizer.tokenizer.token_to_id(f"<|{task}|>") - tokenizer.language = tokenizer.tokenizer.token_to_id(language_token) - tokenizer.language_code = language - # Update prompt based on task and language prompt = self.get_prompt( tokenizer, previous_tokens, - without_timestamps=options.without_timestamps, - prefix=options.prefix if seek == 0 else None, - hotwords=options.hotwords, + without_timestamps=config.without_timestamps, + prefix=config.prefix if seek == 0 else None, + hotwords=config.hotwords, ) if seek > 0 or encoder_output is None: @@ -1151,70 +682,39 @@ def generate_segments( avg_logprob, temperature, compression_ratio, - ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options) + ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, config) - if options.no_speech_threshold is not None: - # no voice activity check - should_skip = result.no_speech_prob > options.no_speech_threshold + if config.no_speech_threshold is not None: + should_skip = result.no_speech_prob > config.no_speech_threshold if ( - options.log_prob_threshold is not None - and avg_logprob > options.log_prob_threshold + config.log_prob_threshold is not None + and avg_logprob > config.log_prob_threshold ): - # don't skip if the logprob is high enough, despite the no_speech_prob should_skip = False if should_skip: self.logger.debug( "No speech threshold is met (%f > %f)", result.no_speech_prob, - options.no_speech_threshold, + config.no_speech_threshold, ) - # Skip if the logprob is very low (below the threshold value), - # despite no_speech_prob being low (ex: Too ambiguous outputs) - if options.log_prob_low_threshold: - if avg_logprob < options.log_prob_low_threshold: + if config.log_prob_low_threshold: + if avg_logprob < config.log_prob_low_threshold: should_skip = True self.logger.debug( "log prob low threshold is met (%f > %f)", avg_logprob, - options.log_prob_low_threshold, + config.log_prob_low_threshold, ) if should_skip: - # fast-forward to the next segment boundary seek += segment_size continue tokens = result.sequences_ids[0] - previous_seek = seek - - # anomalous words are very long/short/improbable - def word_anomaly_score(word: dict) -> float: - probability = word.get("probability", 0.0) - duration = word["end"] - word["start"] - score = 0.0 - if probability < 0.15: - score += 1.0 - if duration < 0.133: - score += (0.133 - duration) * 15 - if duration > 2.0: - score += duration - 2.0 - return score - - def is_segment_anomaly(segment: Optional[dict]) -> bool: - if segment is None or not segment["words"]: - return False - words = [w for w in segment["words"] if w["word"] not in 
punctuation] - words = words[:8] - score = sum(word_anomaly_score(w) for w in words) - return score >= 3 or score + 0.01 >= len(words) - - def next_words_segment(segments: List[dict]) -> Optional[dict]: - return next((s for s in segments if s["words"]), None) - ( current_segments, seek, @@ -1228,14 +728,14 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: seek=seek, ) - if options.word_timestamps: - self.add_word_timestamps( + if config.word_timestamps: + last_speech_timestamp = self.add_word_timestamps( [current_segments], tokenizer, encoder_output, segment_size, - options.prepend_punctuations, - options.append_punctuations, + config.prepend_punctuations, + config.append_punctuations, last_speech_timestamp=last_speech_timestamp, ) if not single_timestamp_ending: @@ -1243,57 +743,6 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: if last_word_end is not None and last_word_end > time_offset: seek = round(last_word_end * self.frames_per_second) - # skip silence before possible hallucinations - if options.hallucination_silence_threshold is not None: - threshold = options.hallucination_silence_threshold - - # if first segment might be a hallucination, skip leading silence - first_segment = next_words_segment(current_segments) - if first_segment is not None and is_segment_anomaly(first_segment): - gap = first_segment["start"] - time_offset - if gap > threshold: - seek = previous_seek + round(gap * self.frames_per_second) - continue - - # skip silence before any possible hallucination that is surrounded - # by silence or more hallucinations - hal_last_end = last_speech_timestamp - for si in range(len(current_segments)): - segment = current_segments[si] - if not segment["words"]: - continue - if is_segment_anomaly(segment): - next_segment = next_words_segment( - current_segments[si + 1 :] - ) - if next_segment is not None: - hal_next_start = next_segment["words"][0]["start"] - else: - hal_next_start = time_offset + segment_duration - silence_before = ( - segment["start"] - hal_last_end > threshold - or segment["start"] < threshold - or segment["start"] - time_offset < 2.0 - ) - silence_after = ( - hal_next_start - segment["end"] > threshold - or is_segment_anomaly(next_segment) - or window_end_time - segment["end"] < 2.0 - ) - if silence_before and silence_after: - seek = round( - max(time_offset + 1, segment["start"]) - * self.frames_per_second - ) - if content_duration - segment["end"] < threshold: - seek = content_frames - current_segments[si:] = [] - break - hal_last_end = segment["end"] - - last_word_end = get_end(current_segments) - if last_word_end is not None: - last_speech_timestamp = last_word_end for segment in current_segments: tokens = segment["tokens"] text = tokenizer.decode(tokens) @@ -1316,100 +765,88 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: compression_ratio=compression_ratio, no_speech_prob=result.no_speech_prob, words=( - [Word(**word) for word in segment["words"]] - if options.word_timestamps + [Word(**word) for word in segment.get("words", [])] + if config.word_timestamps else None ), ) if ( - not options.condition_on_previous_text - or temperature > options.prompt_reset_on_temperature + not config.condition_on_previous_text + or temperature > config.prompt_reset_on_temperature ): - if options.condition_on_previous_text: + if config.condition_on_previous_text: self.logger.debug( "Reset prompt. 
prompt_reset_on_temperature threshold is met %f > %f", temperature, - options.prompt_reset_on_temperature, + config.prompt_reset_on_temperature, ) prompt_reset_since = len(all_tokens) - def encode(self, features: torch.Tensor) -> ctranslate2.StorageView: - # When the model is running on multiple GPUs, the encoder output should be moved - # to the CPU since we don't know which GPU will handle the next job. - to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 - - if features.ndim == 2: - features = features.unsqueeze(0) - features = get_ctranslate2_storage(features) - - return self.model.encode(features, to_cpu=to_cpu) - def generate_with_fallback( self, encoder_output: ctranslate2.StorageView, prompt: List[int], tokenizer: Tokenizer, - options: TranscriptionOptions, + config: TranscriptionConfig, ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]: decode_result = None all_results = [] below_cr_threshold_results = [] max_initial_timestamp_index = int( - round(options.max_initial_timestamp / self.time_precision) + round(config.max_initial_timestamp / self.time_precision) ) - if options.max_new_tokens is not None: - max_length = len(prompt) + options.max_new_tokens + if config.max_new_tokens is not None: + max_length = len(prompt) + config.max_new_tokens else: max_length = self.max_length if max_length > self.max_length: raise ValueError( - f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` " + f"The length of the prompt is {len(prompt)}, and the max_new_tokens " f"{max_length - len(prompt)}. Thus, the combined length of the prompt " - f"and `max_new_tokens` is: {max_length}. This exceeds the " - f"`max_length` of the Whisper model: {self.max_length}. " + f"and max_new_tokens is: {max_length}. This exceeds the " + f"max_length of the Whisper model: {self.max_length}. " "You should either reduce the length of your prompt, or " - "reduce the value of `max_new_tokens`, " + "reduce the value of max_new_tokens, " f"so that their combined length is less that {self.max_length}." ) - for temperature in options.temperatures: + for temperature in config.temperatures: if temperature > 0: kwargs = { "beam_size": 1, - "num_hypotheses": options.best_of, + "num_hypotheses": config.best_of, "sampling_topk": 0, "sampling_temperature": temperature, } else: kwargs = { - "beam_size": options.beam_size, - "patience": options.patience, + "beam_size": config.beam_size, + "patience": config.patience, } result = self.model.generate( encoder_output, [prompt], - length_penalty=options.length_penalty, - repetition_penalty=options.repetition_penalty, - no_repeat_ngram_size=options.no_repeat_ngram_size, + length_penalty=config.length_penalty, + repetition_penalty=config.repetition_penalty, + no_repeat_ngram_size=config.no_repeat_ngram_size, max_length=max_length, return_scores=True, return_no_speech_prob=True, - suppress_blank=options.suppress_blank, - suppress_tokens=options.suppress_tokens, + suppress_blank=config.suppress_blank, + suppress_tokens=get_suppressed_tokens(tokenizer, config.suppress_tokens), max_initial_timestamp_index=max_initial_timestamp_index, **kwargs, )[0] tokens = result.sequences_ids[0] - # Recover the average log prob from the returned score. 
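            # Illustrative note, not part of this patch: with return_scores=True, CTranslate2
            # returns a length-normalized score, so the cumulative log probability is recovered
            # by undoing the normalization before averaging over the generated tokens plus the
            # end token. The numbers below are hypothetical, chosen only to show the arithmetic:
            #   score = -0.25, seq_len = 20, length_penalty = 1.0
            #   cum_logprob = -0.25 * 20**1.0 = -5.0
            #   avg_logprob = -5.0 / (20 + 1) ≈ -0.238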
seq_len = len(tokens) - cum_logprob = result.scores[0] * (seq_len**options.length_penalty) + cum_logprob = result.scores[0] * (seq_len ** config.length_penalty) avg_logprob = cum_logprob / (seq_len + 1) text = tokenizer.decode(tokens).strip() @@ -1425,48 +862,46 @@ def generate_with_fallback( needs_fallback = False - if options.compression_ratio_threshold is not None: - if compression_ratio > options.compression_ratio_threshold: - needs_fallback = True # too repetitive + if config.compression_ratio_threshold is not None: + if compression_ratio > config.compression_ratio_threshold: + needs_fallback = True self.logger.debug( "Compression ratio threshold is not met with temperature %.1f (%f > %f)", temperature, compression_ratio, - options.compression_ratio_threshold, + config.compression_ratio_threshold, ) else: below_cr_threshold_results.append(decode_result) if ( - options.log_prob_threshold is not None - and avg_logprob < options.log_prob_threshold + config.log_prob_threshold is not None + and avg_logprob < config.log_prob_threshold ): - needs_fallback = True # average log probability is too low + needs_fallback = True self.logger.debug( "Log probability threshold is not met with temperature %.1f (%f < %f)", temperature, avg_logprob, - options.log_prob_threshold, + config.log_prob_threshold, ) if ( - options.no_speech_threshold is not None - and result.no_speech_prob > options.no_speech_threshold - and options.log_prob_threshold is not None - and avg_logprob < options.log_prob_threshold + config.no_speech_threshold is not None + and result.no_speech_prob > config.no_speech_threshold + and config.log_prob_threshold is not None + and avg_logprob < config.log_prob_threshold ): - needs_fallback = False # silence + needs_fallback = False if not needs_fallback: break else: - # all failed, select the result with the highest average log probability decode_result = max( below_cr_threshold_results or all_results, key=lambda x: x[1] ) - # to pass final temperature for prompt_reset_on_temperature decode_result = ( decode_result[0], decode_result[1], @@ -1511,18 +946,95 @@ def get_prompt( return prompt + def _split_segments_by_timestamps( + self, + tokenizer: Tokenizer, + tokens: List[int], + time_offset: float, + segment_size: int, + segment_duration: float, + seek: int, + ) -> Tuple[List[dict], int, bool]: + current_segments = [] + single_timestamp_ending = ( + len(tokens) >= 2 and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1] + ) + + consecutive_timestamps = [ + i + for i in range(len(tokens)) + if i > 0 + and tokens[i] >= tokenizer.timestamp_begin + and tokens[i - 1] >= tokenizer.timestamp_begin + ] + + if len(consecutive_timestamps) > 0: + slices = list(consecutive_timestamps) + if single_timestamp_ending: + slices.append(len(tokens)) + + last_slice = 0 + for current_slice in slices: + sliced_tokens = tokens[last_slice:current_slice] + start_timestamp_position = sliced_tokens[0] - tokenizer.timestamp_begin + end_timestamp_position = sliced_tokens[-1] - tokenizer.timestamp_begin + start_time = ( + time_offset + start_timestamp_position * self.time_precision + ) + end_time = time_offset + end_timestamp_position * self.time_precision + + current_segments.append( + dict( + seek=seek, + start=start_time, + end=end_time, + tokens=sliced_tokens, + ) + ) + last_slice = current_slice + + if single_timestamp_ending: + seek += segment_size + else: + last_timestamp_position = ( + tokens[last_slice - 1] - tokenizer.timestamp_begin + ) + seek += last_timestamp_position * self.input_stride + + else: + 
duration = segment_duration + timestamps = [ + token for token in tokens if token >= tokenizer.timestamp_begin + ] + if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin: + last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin + duration = last_timestamp_position * self.time_precision + + current_segments.append( + dict( + seek=seek, + start=time_offset, + end=time_offset + duration, + tokens=tokens, + ) + ) + + seek += segment_size + + return current_segments, seek, single_timestamp_ending + def add_word_timestamps( self, - segments: List[dict], + segments: List[List[dict]], tokenizer: Tokenizer, encoder_output: ctranslate2.StorageView, - num_frames: int, + num_frames: Union[int, List[int]], prepend_punctuations: str, append_punctuations: str, last_speech_timestamp: float, ) -> float: if len(segments) == 0: - return + return last_speech_timestamp text_tokens = [] text_tokens_per_segment = [] @@ -1549,12 +1061,8 @@ def add_word_timestamps( median_duration = min(0.7, float(median_duration)) max_duration = median_duration * 2 - # hack: truncate long words at sentence boundaries. - # a better segmentation algorithm based on VAD should be able to replace this. if len(word_durations) > 0: sentence_end_marks = ".。!!??" - # ensure words at sentence boundaries - # are not longer than twice the median word duration. for i in range(1, len(alignment)): if alignment[i]["end"] - alignment[i]["start"] > max_duration: if alignment[i]["word"] in sentence_end_marks: @@ -1591,11 +1099,7 @@ def add_word_timestamps( saved_tokens += len(timing["tokens"]) word_index += 1 - # hack: truncate long words at segment boundaries. - # a better segmentation algorithm based on VAD should be able to replace this. if len(words) > 0: - # ensure the first and second word after a pause is not longer than - # twice the median word duration. if words[0][ "end" ] - last_speech_timestamp > median_duration * 4 and ( @@ -1615,7 +1119,6 @@ def add_word_timestamps( words[0]["end"] = words[1]["start"] = boundary words[0]["start"] = max(0, words[0]["end"] - max_duration) - # prefer the segment-level start timestamp if the first word is too long. if ( subsegment["start"] < words[0]["end"] and subsegment["start"] - 0.5 > words[0]["start"] @@ -1627,7 +1130,6 @@ def add_word_timestamps( else: subsegment["start"] = words[0]["start"] - # prefer the segment-level end timestamp if the last word is too long. 
if ( subsegment["end"] > words[-1]["start"] and subsegment["end"] + 0.5 < words[-1]["end"] @@ -1645,11 +1147,11 @@ def add_word_timestamps( def find_alignment( self, tokenizer: Tokenizer, - text_tokens: List[int], + text_tokens: List[List[int]], encoder_output: ctranslate2.StorageView, - num_frames: int, + num_frames: Union[int, List[int]], median_filter_width: int = 7, - ) -> List[dict]: + ) -> List[List[dict]]: if len(text_tokens) == 0: return [] @@ -1661,21 +1163,14 @@ def find_alignment( median_filter_width=median_filter_width, ) return_list = [] - for result, text_token in zip(results, text_tokens): + for result, tokens in zip(results, text_tokens): text_token_probs = result.text_token_probs alignments = result.alignments text_indices = np.array([pair[0] for pair in alignments]) time_indices = np.array([pair[1] for pair in alignments]) - words, word_tokens = tokenizer.split_to_word_tokens( - text_token + [tokenizer.eot] - ) + words, word_tokens = tokenizer.split_to_word_tokens(tokens + [tokenizer.eot]) if len(word_tokens) <= 1: - # return on eot only - # >>> np.pad([], (1, 0)) - # array([0.]) - # This results in crashes when we lookup jump_times with float, like - # IndexError: arrays used as indices must be of integer (or boolean) type return [] word_boundaries = np.pad( np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0) @@ -1698,12 +1193,12 @@ def find_alignment( [ dict( word=word, - tokens=tokens, + tokens=wtokens, start=start, end=end, probability=probability, ) - for word, tokens, start, end, probability in zip( + for word, wtokens, start, end, probability in zip( words, word_tokens, start_times, end_times, word_probabilities ) ] @@ -1714,22 +1209,25 @@ def generate_segment_batched( self, features: torch.Tensor, tokenizer: Tokenizer, - options: dict, + options: TranscriptionConfig, ): batch_size = features.shape[0] all_tokens = [] prompt_reset_since = 0 - if options["initial_prompt"] is not None: - initial_prompt = " " + options["initial_prompt"].strip() - initial_prompt_tokens = tokenizer.encode(initial_prompt) - all_tokens.extend(initial_prompt_tokens) + if options.initial_prompt is not None: + if isinstance(options.initial_prompt, str): + initial_prompt = " " + options.initial_prompt.strip() + initial_prompt_tokens = tokenizer.encode(initial_prompt) + all_tokens.extend(initial_prompt_tokens) + else: + all_tokens.extend(options.initial_prompt) previous_tokens = all_tokens[prompt_reset_since:] prompt = self.get_prompt( tokenizer, previous_tokens, - without_timestamps=options["without_timestamps"], - prefix=options["prefix"], + without_timestamps=options.without_timestamps, + prefix=options.prefix, ) encoder_output = self.encode(features) @@ -1737,31 +1235,32 @@ def generate_segment_batched( result = self.model.generate( encoder_output, [prompt] * batch_size, - beam_size=options["beam_size"], - patience=options["patience"], - length_penalty=options["length_penalty"], + beam_size=options.beam_size, + patience=options.patience, + length_penalty=options.length_penalty, max_length=self.max_length, - suppress_blank=options["suppress_blank"], - suppress_tokens=options["suppress_tokens"], + suppress_blank=options.suppress_blank, + suppress_tokens=get_suppressed_tokens(tokenizer, options.suppress_tokens), return_scores=True, return_no_speech_prob=True, ) output = [] for res in result: - output.append({}) - # return scores seq_len = len(res.sequences_ids[0]) - cum_logprob = res.scores[0] * (seq_len ** options["length_penalty"]) - output[-1]["avg_logprob"] = cum_logprob / (seq_len + 1) 
-
-            # return no speech prob
-            output[-1]["no_speech_prob"] = res.no_speech_prob
-            output[-1]["tokens"] = res.sequences_ids[0]
+            cum_logprob = res.scores[0] * (seq_len ** options.length_penalty)
+            avg_logprob = cum_logprob / (seq_len + 1)
+            output.append({
+                "avg_logprob": avg_logprob,
+                "no_speech_prob": res.no_speech_prob,
+                "tokens": res.sequences_ids[0],
+            })
 
         return encoder_output, output
 
-    def detect_language(self, audio: torch.Tensor):
+    def detect_language(
+        self, audio: torch.Tensor
+    ) -> Tuple[str, float, Optional[List[Tuple[str, float]]]]:
         to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
         segment = self.feature_extractor(audio, padding=True, to_cpu=to_cpu)[
             :, : self.feature_extractor.nb_max_frames
@@ -1776,193 +1275,121 @@ def detect_language(self, audio: torch.Tensor):
         all_language_probs = [(token[2:-2], prob) for (token, prob) in results[0]]
 
         return language, language_probability, all_language_probs
 
-    def detect_language_multi_segment(
-        self, audio: Union[str, BinaryIO, torch.Tensor], params: Optional[dict] = None
-    ):
-        """
-        Detect language based on N highly-confident segments of a language.
-        """
-        # The threshold is used to decide if the audio is silence or not.
-        # The default is 0.02 (2.0%) i.e, if more than 2.0% of the audio is silent,
-        # the audio is considered as silence.
-        if not params:
-            params = {
-                "multilingual": False,
-                "speech_percentage_threshold": 0.02,
-                "language_detection_segments": 4,
-                "vad_filter": True,
-                "vad_min_silence_duration": 2500,
-                "language_threshold": 0.7,
-            }
+    def detect_language_multi_segment(self, audio: Union[str, BinaryIO, torch.Tensor]):
+        """Detect the major language from several highly-confident audio segments."""
+        config = self.config
 
-        if params.get("multilingual", False):
-            logging.warning(
-                "lang_id is not supported for multilingual audios, detecting the major language."
+        if config.multilingual:
+            self.logger.warning(
+                "Per-segment language detection is not supported for multilingual audio; "
+                "detecting the major language only."
             )
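The score handling in the generate_segment_batched hunk above follows the sequential path: CTranslate2 returns a length-normalized score, so the code multiplies it back by seq_len ** length_penalty to recover the cumulative log-probability and divides by seq_len + 1 (the extra token plausibly accounting for the end-of-transcript token) to get the per-token average later compared against log_prob_threshold. A minimal standalone sketch of that arithmetic; the helper name is illustrative and not part of this diff:

def average_logprob(score: float, seq_len: int, length_penalty: float = 1.0) -> float:
    # Undo CTranslate2's length normalization, then average over the tokens.
    cum_logprob = score * (seq_len ** length_penalty)
    return cum_logprob / (seq_len + 1)

# Example: a 10-token sequence scored -0.25 gives -0.25 * 10 / 11 ≈ -0.227.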
-        speech_percentage_threshold = params.get("speech_percentage_threshold", 0.02)
-        language_threshold = params.get("language_threshold", 0.7)
-        num_detection_segments = params.get("language_detection_segments", 4)
-        vad_filter_enabled = params.get("vad_filter", True)
-        vad_params = dict(
-            min_silence_duration_ms=params.get("vad_min_silence_duration", 2500)
-        )
+        speech_percentage_threshold = config.speech_percentage_threshold
+        language_threshold = config.language_threshold
+        num_detection_segments = config.language_detection_segments
+        vad_filter_enabled = config.vad_filter
+        vad_params = {
+            "min_silence_duration_ms": config.vad_min_silence_duration
+        }
 
         if vad_filter_enabled:
-            vad_params = VadOptions(**vad_params)
+            vad_options = VadOptions(**vad_params)
+        else:
+            vad_options = None
 
-        # decode audio if it is not decoded already
         sampling_rate = self.feature_extractor.sampling_rate
         if not isinstance(audio, torch.Tensor):
-            audio: torch.Tensor = decode_audio(audio, sampling_rate=sampling_rate)
+            audio = decode_audio(audio, sampling_rate=sampling_rate)
 
-        # calculate duration of audio as number of seconds
-        # audio.shape[0] is the number of samples in the audio
-        # sampling_rate is the number of samples per second
-        # if we divide the number of samples by the number of samples per second,
-        # we get the duration in seconds
         duration = audio.shape[0] / sampling_rate
 
-        # Check if vad is enabled, and collect voiced segments
         if vad_filter_enabled:
-            # get chunks of audio that contain speech
-            speech_chunks = get_speech_timestamps(audio, vad_params)
-            # merge chunks of audio that contain speech into a single array
-            audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
+            speech_chunks = get_speech_timestamps(audio, vad_options)
+            audio_chunks, _ = collect_chunks(audio, speech_chunks)
             audio = torch.cat(audio_chunks, dim=0)
 
-            # calculate new duration of audio without silence
             duration_vad = audio.shape[0] / sampling_rate
 
-            logging.debug(
-                f"Lang ID: VAD filter removed {duration - duration_vad} sec of audio"
+            self.logger.debug(
+                f"Language detection: VAD filter removed {duration - duration_vad:.2f} sec of audio"
            )
 
-            # if the audio after VAD is less than 2% of the original audio, consider it as silence
             if duration_vad / duration < speech_percentage_threshold:
                 return {"language_code": None, "language_confidence": 1.0}
 
-            # update duration to be the duration after VAD
             duration = duration_vad
 
-        # if the duration of the audio is less than 1 second, consider it as silence
         if duration < 1.0:
             return {"language_code": None, "language_confidence": 1.0}
 
-        # number of feature frames in 30 seconds of audio is 3000
         nb_max_frames = self.feature_extractor.nb_max_frames
 
-        # extract features from audio with padding (default)
         to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
         features = self.feature_extractor(audio, to_cpu=to_cpu)
 
-        # number of segments in the audio
         num_segments = features.shape[-1] // nb_max_frames
-        # more number of segments than possible with the duration of file
         if num_detection_segments > num_segments:
-            logging.warning(
-                f"Lang ID: Can not have more segments, setting {num_segments} segments."
+            self.logger.warning(
+                "Language detection: more segments requested than available; "
+                f"using {num_segments} segments."
) num_detection_segments = num_segments - # create a list of indices to randomly select segments from - indices = list(range(num_detection_segments)) - - # fix seed to get deterministic results + indices = list(range(num_segments)) random.seed(0) random.shuffle(indices) + indices = indices[:num_detection_segments] detected_languages = [] all_language_probabilities = defaultdict(list) confident_language_probabilities = defaultdict(list) num_confident_segments_per_language = defaultdict(int) - # Iterate over the randomly selected indices of the segments. - # - # For each segment, extract features and detect language. - # - # If the language is confident, add it to the list of confident segments for that language. - # - # If the number of confident segments for a language - # is greater than or equal to the number of detection segments, - # return the language and the average probability of the language. - # - # If we are unable to get sufficient number of confident predcitions, - # return the most frequently detected language with maximum probability. - # - # We need to get sufficient number of confident predictions per language, not in total. - for i in indices: segment_features = features[:, i * nb_max_frames : (i + 1) * nb_max_frames] try: encoder_output = self.encode(segment_features) results = self.model.detect_language(encoder_output)[0] - - except ValueError as e: # or RuntimeError - logging.error(f"Inference error:{e}") - - # results is the list of classes (languages) and their probabilities (descending), - # for eg: [('<|de|>', 0.482177734375),('<|en|>', 0.283447265625),...] - - # take top language token and probability - # and parse language token to strip out markers - # for eg: '<|de|>' -> 'de' + except ValueError as e: + self.logger.error(f"Inference error: {e}") + continue language_token = results[0][0] language = language_token[2:-2] - language_probability = results[0][1] detected_languages.append(language) all_language_probabilities[language].append(language_probability) - # only consider if the language prediction is confident if language_probability > language_threshold: num_confident_segments_per_language[language] += 1 - - # Add language and probability to the list of languages when it is confident confident_language_probabilities[language].append(language_probability) - # return the language when sufficient number of confident segments is achieved if ( num_confident_segments_per_language[language] >= num_detection_segments ): - # Considering the average probability of only confident segments - mean = sum(confident_language_probabilities[language]) / len( + mean_confidence = sum(confident_language_probabilities[language]) / len( confident_language_probabilities[language] ) return { "language_code": language, - "language_confidence": mean, + "language_confidence": mean_confidence, } - # if we are unable to get sufficient number of confident predictions, - # return the most frequently detected language. - # if there is a tie, return the one with maximum average probability. 
         counter = Counter(detected_languages)
 
-        # Define the key function to select frequent language with attached probabilities
-        def key_func(language):
-            # Calculate the frequency of the language
-            frequency = counter[language]
-
-            # Calculate the average probability of the language
-            prob_avg = sum(all_language_probabilities[language]) / len(
-                all_language_probabilities[language]
+        def key_func(lang):
+            # Rank languages by detection frequency, then by mean probability.
+            frequency = counter[lang]
+            prob_avg = sum(all_language_probabilities[lang]) / len(
+                all_language_probabilities[lang]
             )
             return frequency, prob_avg
 
         if detected_languages:
-            # Use the key function to find the language with maximum frequency and probability
             max_language = max(detected_languages, key=key_func)
             max_probability = sum(all_language_probabilities[max_language]) / len(
                 all_language_probabilities[max_language]
            )
 
-            # Do additional checks for silence for non-confident case
-            # calculate RMS amplitude and DC offset
             dc_offset = audio.mean()
             audio_minus_dc_offset = audio - dc_offset
             is_silent = (
@@ -1978,10 +1405,40 @@ def key_func(language):
                 "language_confidence": max_probability,
             }
 
-        # Language is not detected for any segment and none of prev conditions met
         return {"language_code": None, "language_confidence": 1.0}
 
 
+def get_ctranslate2_storage(segment: torch.Tensor) -> ctranslate2.StorageView:
+    segment = segment.contiguous()
+    # CPU tensors do not implement __array_interface__, so convert them to numpy first
+    # (https://github.com/pytorch/pytorch/issues/51156).
+    segment = ctranslate2.StorageView.from_array(
+        segment if segment.is_cuda else segment.numpy()
+    )
+    return segment
+
+
+def get_compression_ratio(text: str) -> float:
+    text_bytes = text.encode("utf-8")
+    return len(text_bytes) / len(zlib.compress(text_bytes))
+
+
+def get_suppressed_tokens(
+    tokenizer: Tokenizer,
+    suppress_tokens: Optional[List[int]],
+) -> Tuple[int, ...]:
+    if suppress_tokens is None:
+        suppress_tokens = []
+    elif -1 in suppress_tokens:
+        suppress_tokens = [t for t in suppress_tokens if t >= 0]
+        suppress_tokens.extend(tokenizer.non_speech_tokens)
+    else:
+        # Copy so the caller's list (e.g. the config default) is not mutated.
+        suppress_tokens = list(suppress_tokens)
+    suppress_tokens.extend(
+        [
+            tokenizer.transcribe,
+            tokenizer.translate,
+            tokenizer.sot,
+            tokenizer.sot_prev,
+            tokenizer.sot_lm,
+        ]
+    )
+    return tuple(sorted(set(suppress_tokens)))
+
+
 def restore_speech_timestamps(
     segments: Iterable[Segment],
     speech_chunks: List[dict],
@@ -1993,7 +1450,6 @@ def restore_speech_timestamps(
         if segment.words:
             words = []
             for word in segment.words:
-                # Ensure the word start and end times are resolved to the same chunk.
middle = (word.start + word.end) / 2 chunk_index = ts_map.get_chunk_index(middle) word = word._replace( @@ -2016,55 +1472,13 @@ def restore_speech_timestamps( yield segment - -def get_ctranslate2_storage(segment: torch.Tensor) -> ctranslate2.StorageView: - segment = segment.contiguous() - segment = ctranslate2.StorageView.from_array( - segment if segment.is_cuda else segment.numpy() - ) # torch cpu tensors don't implement __array_interface__ - # https://github.com/pytorch/pytorch/issues/51156 - return segment - - -def get_compression_ratio(text: str) -> float: - text_bytes = text.encode("utf-8") - return len(text_bytes) / len(zlib.compress(text_bytes)) - - -def get_suppressed_tokens( - tokenizer: Tokenizer, - suppress_tokens: Tuple[int], -) -> Optional[List[int]]: - if -1 in suppress_tokens: - suppress_tokens = [t for t in suppress_tokens if t >= 0] - suppress_tokens.extend(tokenizer.non_speech_tokens) - elif suppress_tokens is None or len(suppress_tokens) == 0: - suppress_tokens = [] # interpret empty string as an empty list - else: - assert isinstance(suppress_tokens, list), "suppress_tokens must be a list" - - suppress_tokens.extend( - [ - tokenizer.transcribe, - tokenizer.translate, - tokenizer.sot, - tokenizer.sot_prev, - tokenizer.sot_lm, - ] - ) - - return tuple(sorted(set(suppress_tokens))) - - def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None: - # merge prepended punctuations i = len(alignment) - 2 j = len(alignment) - 1 while i >= 0: previous = alignment[i] following = alignment[j] if previous["word"].startswith(" ") and previous["word"].strip() in prepended: - # prepend it to the following word following["word"] = previous["word"] + following["word"] if "tokens" in alignment[0].keys(): following["tokens"] = previous["tokens"] + following["tokens"] @@ -2075,14 +1489,12 @@ def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> j = i i -= 1 - # merge appended punctuations i = 0 j = 1 while j < len(alignment): previous = alignment[i] following = alignment[j] if not previous["word"].endswith(" ") and following["word"] in appended: - # append it to the previous word previous["word"] = previous["word"] + following["word"] if "tokens" in alignment[0].keys(): previous["tokens"] = previous["tokens"] + following["tokens"]
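For reference, the rewritten detect_language_multi_segment keeps two tallies per language: every prediction's probability, and only those above language_threshold. It returns a language as soon as it collects language_detection_segments confident predictions, and otherwise falls back to the most frequently detected language, breaking ties by mean probability. A self-contained sketch of that voting rule over precomputed (language, probability) pairs; the function name and inputs are illustrative, since the real method runs the encoder per segment:

from collections import Counter, defaultdict
from typing import List, Optional, Tuple


def vote_language(
    predictions: List[Tuple[str, float]],
    threshold: float = 0.7,
    required: int = 4,
) -> Tuple[Optional[str], float]:
    all_probs = defaultdict(list)
    confident = defaultdict(list)
    for lang, prob in predictions:
        all_probs[lang].append(prob)
        if prob > threshold:
            confident[lang].append(prob)
            if len(confident[lang]) >= required:
                # Early exit: enough confident segments for one language.
                return lang, sum(confident[lang]) / len(confident[lang])
    if not all_probs:
        return None, 1.0
    counter = Counter(lang for lang, _ in predictions)

    def mean_prob(lang: str) -> float:
        return sum(all_probs[lang]) / len(all_probs[lang])

    best = max(all_probs, key=lambda lang: (counter[lang], mean_prob(lang)))
    return best, mean_prob(best)

With the defaults shown in the removed params dict (language_threshold 0.7, language_detection_segments 4), a language wins outright only after four segments above 0.7.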
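Two of the relocated helpers are worth a note. get_suppressed_tokens expands a -1 entry into the tokenizer's non-speech tokens, always adds the task and SOT control tokens, and returns a sorted tuple. get_compression_ratio is the usual repetition signal: text that compresses well, i.e. whose ratio exceeds compression_ratio_threshold, is likely looping or hallucinated. A quick illustration of the latter; the printed values are indicative, not exact:

import zlib


def compression_ratio(text: str) -> float:
    data = text.encode("utf-8")
    return len(data) / len(zlib.compress(data))


# Normal speech-like text stays close to 1 (short strings can even fall below
# it because of zlib's header overhead); heavily repeated output climbs far
# above the 2.4 default threshold.
print(compression_ratio("so we decided to measure the latency of every request"))
print(compression_ratio("thank you. " * 40))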
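merge_punctuations itself is unchanged apart from the dropped comments: a backward pass glues standalone leading punctuation (a word that is only whitespace plus a character from prepended) onto the word that follows, and a forward pass glues trailing punctuation from appended onto the word before it, merging the token lists the same way. A simplified single-pass sketch on plain strings, assuming small illustrative punctuation sets rather than the real defaults:

def merge_words(words, prepended='"¿([{-', appended='"!?.,:)]}'):
    # Simplified: operates on strings only; the real function mutates
    # alignment dicts (word/tokens) in place.
    merged = []
    pending = ""  # leading punctuation waiting for the next word
    for word in words:
        if word.startswith(" ") and word.strip() in prepended:
            pending += word
        elif merged and not merged[-1].endswith(" ") and word in appended:
            merged[-1] += word
        else:
            merged.append(pending + word)
            pending = ""
    if pending:
        merged.append(pending)
    return merged


print(merge_words([' "', " Hello", " world", "!", ' "']))
# [' " Hello', ' world!', ' "'] -- the opening quote attaches forward, "!" attaches
# back, and the dangling quote at the end has nothing left to attach to.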