diff --git a/faster_whisper/audio.py b/faster_whisper/audio.py
index 7ae68d40..1f1970aa 100644
--- a/faster_whisper/audio.py
+++ b/faster_whisper/audio.py
@@ -1,7 +1,20 @@
+"""We use the PyAV library to decode the audio: https://github.com/PyAV-Org/PyAV
+
+The advantage of PyAV is that it bundles the FFmpeg libraries, so there are no additional
+system dependencies. FFmpeg does not need to be installed on the system.
+
+However, the API is quite low-level, so we need to manipulate audio frames directly.
+"""
+
+import gc
+import io
+import itertools
+
 from typing import BinaryIO, Union
 
+import av
+import numpy as np
 import torch
-import torchaudio
 
 
 def decode_audio(
@@ -17,22 +30,83 @@ def decode_audio(
       split_stereo: Return separate left and right channels.
 
     Returns:
-      A float32 Torch Tensor.
+      A float32 Numpy array.
 
       If `split_stereo` is enabled, the function returns a 2-tuple with the
       separated left and right channels.
     """
+    resampler = av.audio.resampler.AudioResampler(
+        format="s16",
+        layout="mono" if not split_stereo else "stereo",
+        rate=sampling_rate,
+    )
 
-    waveform, audio_sf = torchaudio.load(input_file)  # waveform: channels X T
+    raw_buffer = io.BytesIO()
+    dtype = None
+
+    with av.open(input_file, mode="r", metadata_errors="ignore") as container:
+        frames = container.decode(audio=0)
+        frames = _ignore_invalid_frames(frames)
+        frames = _group_frames(frames, 500000)
+        frames = _resample_frames(frames, resampler)
+
+        for frame in frames:
+            array = frame.to_ndarray()
+            dtype = array.dtype
+            raw_buffer.write(array)
+
+    # It appears that some objects related to the resampler are not freed
+    # unless the garbage collector is manually run.
+    # https://github.com/SYSTRAN/faster-whisper/issues/390
+    # note that this slows down loading the audio a little bit
+    # if that is a concern, please use ffmpeg directly as in here:
+    # https://github.com/openai/whisper/blob/25639fc/whisper/audio.py#L25-L62
+    del resampler
+    gc.collect()
+
+    audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)
+
+    # Convert s16 back to f32.
+    audio = audio.astype(np.float32) / 32768.0
 
-    if audio_sf != sampling_rate:
-        waveform = torchaudio.functional.resample(
-            waveform, orig_freq=audio_sf, new_freq=sampling_rate
-        )
     if split_stereo:
-        return waveform[0], waveform[1]
+        left_channel = audio[0::2]
+        right_channel = audio[1::2]
+        return torch.from_numpy(left_channel), torch.from_numpy(right_channel)
+
+    return torch.from_numpy(audio)
+
+
+def _ignore_invalid_frames(frames):
+    iterator = iter(frames)
+
+    while True:
+        try:
+            yield next(iterator)
+        except StopIteration:
+            break
+        except av.error.InvalidDataError:
+            continue
+
+
+def _group_frames(frames, num_samples=None):
+    fifo = av.audio.fifo.AudioFifo()
+
+    for frame in frames:
+        frame.pts = None  # Ignore timestamp check.
+        fifo.write(frame)
+
+        if num_samples is not None and fifo.samples >= num_samples:
+            yield fifo.read()
+
+    if fifo.samples > 0:
+        yield fifo.read()
+
 
-    return waveform.mean(0)
+def _resample_frames(frames, resampler):
+    # Add None to flush the resampler.
+    for frame in itertools.chain(frames, [None]):
+        yield from resampler.resample(frame)
 
 
 def pad_or_trim(array, length: int, *, axis: int = -1):
diff --git a/requirements.txt b/requirements.txt
index e0a3afba..699d1609 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,5 +4,5 @@ tokenizers>=0.13,<1
 onnxruntime>=1.14,<2
 pyannote-audio>=3.1.1
 torch>=2.1.1
-torchaudio>=2.1.2
+av>=11
 tqdm
\ No newline at end of file
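
For review purposes, a minimal sketch of how the reworked `decode_audio` can be exercised after this change (the `sample.wav` path is hypothetical; any file FFmpeg can decode should work):

```python
import torch

from faster_whisper.audio import decode_audio

# Mono path: decode and resample to 16 kHz; the new code returns a float32 torch.Tensor.
audio = decode_audio("sample.wav", sampling_rate=16000)
assert audio.dtype == torch.float32

# Stereo path: the left and right channels are returned separately.
left, right = decode_audio("sample.wav", sampling_rate=16000, split_stereo=True)
```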