
[large-v3] Error during transcription: Invalid input features shape: expected an input with shape (3, 80, 3000), but got an input with shape (3, 128, 3000) instead #51

Open
twardoch opened this issue Mar 20, 2024 · 5 comments
Labels: bug

@twardoch

Cell 1

!apt install ffmpeg
!pip install whisper-s2t yt-dlp gradio pydantic ffmpeg-python

Cell 2

import logging
from pathlib import Path
import whisper_s2t

from google.colab import drive

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
class Config:
    model_identifier = "large-v3" # This causes a problem
    backend = "CTranslate2"
    output_format = "vtt"
    max_workers = 16
    path_root = "/content/drive"
    cwd = Path(path_root, "MyDrive/Colab Notebooks/YouTube Videos")


drive.mount(Config.path_root)

Cell 3

whisper_s2t_model = whisper_s2t.load_model(
    model_identifier=Config.model_identifier,
    backend=Config.backend,
    asr_options={"word_timestamps": True},
    # n_mels=128 # This doesn't matter
)

Cell 4

import asyncio
import os
import shutil
from concurrent.futures import ThreadPoolExecutor

import ffmpeg
import yt_dlp
from pydantic import BaseModel


# Pydantic model for VideoToTranscribe
class VideoToTranscribe(BaseModel):
    video_path: Path
    audio_path: Path
    metadata: dict | None = None
    lang_code: str = "en"
    initial_prompt: str | None = None
    vtt_path: Path


class VideoTranscriptor:
    def __init__(self, cwd: Path, whisper_s2t_model):
        self.cwd = cwd
        self.input_videos_dir = cwd / "input_videos"
        self.input_audios_dir = cwd / "input_audios"
        self.transcribed_dir = cwd / "transcribed"

        # Create directories if they don't exist
        self.input_videos_dir.mkdir(parents=True, exist_ok=True)
        self.input_audios_dir.mkdir(parents=True, exist_ok=True)
        self.transcribed_dir.mkdir(parents=True, exist_ok=True)
        self.whisper_s2t_model = whisper_s2t_model

    async def download_youtube_videos(self, url: str):
        ydl_opts = {
            "format": "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
            "outtmpl": str(self.input_videos_dir / "%(id)s.%(ext)s"),
        }

        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([url])

    async def transcribe_audio(
        self, audio_paths: list[Path], lang_codes: list[str], tasks: list[str]
    ):
        vtt_paths = [
            self.transcribed_dir / f"{audio_path.stem}.{Config.output_format}"
            for audio_path in audio_paths
        ]

        out = self.whisper_s2t_model.transcribe_with_vad(
            [str(audio_path) for audio_path in audio_paths],
            lang_codes=lang_codes,
            tasks=tasks,
            initial_prompts=[None] * len(audio_paths),
            batch_size=Config.max_workers,
        )

        whisper_s2t.write_outputs(
            out,
            format=Config.output_format,
            op_files=[str(vtt_path) for vtt_path in vtt_paths],
        )

        return vtt_paths

    async def process_videos(self, lang_code: str, output_lang_code: str):
        video_paths = list(self.input_videos_dir.glob("*.mp4"))

        def extract_audio(video_path: Path):
            audio_path = self.input_audios_dir / f"{video_path.stem}.wav"

            try:
                (
                    ffmpeg.input(str(video_path))
                    .output(str(audio_path), acodec="pcm_s16le", ar=16000, ac=1)
                    .overwrite_output()
                    .run(capture_stdout=True, capture_stderr=True)
                )
            except ffmpeg.Error as e:
                logger.error(
                    f"Error while extracting audio from {video_path}: {e.stderr.decode()}"
                )
                raise e

            return audio_path

        with ThreadPoolExecutor(max_workers=Config.max_workers) as executor:
            audio_extraction_tasks = [
                asyncio.get_event_loop().run_in_executor(
                    executor, extract_audio, video_path
                )
                for video_path in video_paths
            ]
            audio_paths = await asyncio.gather(*audio_extraction_tasks)

            task = "transcribe" if lang_code == output_lang_code else "translate"
            tasks = [task] * len(audio_paths)
            lang_codes = [output_lang_code] * len(audio_paths)

            vtt_paths = await self.transcribe_audio(audio_paths, lang_codes, tasks)

        videos_to_transcribe = [
            VideoToTranscribe(
                video_path=video_path, audio_path=audio_path, vtt_path=vtt_path
            )
            for video_path, audio_path, vtt_path in zip(
                video_paths, audio_paths, vtt_paths
            )
        ]

        return videos_to_transcribe

    async def cleanup(self, videos_to_transcribe: list[VideoToTranscribe]):
        for video in videos_to_transcribe:
            if video.vtt_path.exists():
                if video.video_path.exists():
                    shutil.move(str(video.video_path), str(self.transcribed_dir))
                    os.remove(str(video.audio_path))
                else:
                    shutil.move(str(video.audio_path), str(self.transcribed_dir))

    async def transcribe(self, youtube_url: str, lang_code: str, output_lang_code: str):
        if youtube_url:
            logger.info(f"Downloading YouTube video(s) from: {youtube_url}")
            await self.download_youtube_videos(youtube_url)

        logger.info("Processing videos...")
        videos_to_transcribe = await self.process_videos(lang_code, output_lang_code)

        logger.info("Cleaning up temporary files...")
        await self.cleanup(videos_to_transcribe)

        return f"Transcription completed. Files saved in {self.transcribed_dir}"

Cell 5

import gradio as gr

# Gradio UI


def launch_ui():
    cwd = Path(Config.cwd)
    transcriptor = VideoTranscriptor(cwd, whisper_s2t_model)

    async def transcribe_wrapper(
        youtube_url: str, lang_code: str, output_lang_code: str
    ):
        try:
            result = await transcriptor.transcribe(
                youtube_url, lang_code, output_lang_code
            )
            return result
        except Exception as e:
            logger.error(f"Error during transcription: {str(e)}")
            return f"An error occurred during transcription: {str(e)}"

    input_components = [
        gr.Textbox(
            label="YouTube URL (optional)",
            placeholder="Enter a YouTube video or playlist URL",
        ),
        gr.Textbox(label="Source Language Code", value="en"),
        gr.Textbox(label="Target Language Code", value="en"),
    ]

    iface = gr.Interface(
        fn=transcribe_wrapper,
        inputs=input_components,
        outputs="text",
        title="YouTube Video Transcriptor",
        description="Transcribe YouTube videos or local video files using Whisper",
        allow_flagging="never",
    )

    iface.launch(debug=False, share=True)


if __name__ == "__main__":
    launch_ui()

When I try to run this code with the large-v3 model identifier, I keep getting:

ERROR:__main__:Error during transcription: Invalid input features shape: expected an input with shape (3, 80, 3000), but got an input with shape (3, 128, 3000) instead

With large-v2, it works fine.

@AmgadHasan

Try uncommenting the n_mels line:

whisper_s2t_model = whisper_s2t.load_model(
    model_identifier=Config.model_identifier,
    backend=Config.backend,
    asr_options={"word_timestamps": True},
    # n_mels=128 # This doesn't matter
)

@twardoch (Author)

By "this doesn't work" I meant: it fails if the parameter is commented or uncommented.

@shashikg (Owner)

@twardoch this is a bug in the aligner model. By default, the tiny model is used for alignment, and it expects n_mels of size 80, while large-v3 expects n_mels of size 128. Since the same preprocessor is shared between the two models, you get this error.

I will fix this in the next release.
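
For reference, the shapes in the error message map directly onto the mel-bin counts: the input features have shape (batch, n_mels, frames), so the aligner expects (3, 80, 3000) while large-v3's preprocessor produces (3, 128, 3000). A minimal illustration of that difference, assuming a recent openai-whisper release (not part of this notebook; it is used here only because its log_mel_spectrogram accepts an n_mels argument since large-v3 support was added):

import numpy as np
import whisper  # openai-whisper, used only to illustrate the mel-bin difference

# pad a short buffer to Whisper's fixed 30 s / 16 kHz input window
audio = whisper.pad_or_trim(np.zeros(16000, dtype=np.float32))

mel_80 = whisper.log_mel_spectrogram(audio, n_mels=80)    # what tiny/large-v2 (and the aligner) expect
mel_128 = whisper.log_mel_spectrogram(audio, n_mels=128)  # what large-v3 expects

print(mel_80.shape, mel_128.shape)  # torch.Size([80, 3000]) torch.Size([128, 3000])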

Meanwhile, to use large-v3, disable word timestamps (which should fix your issue):

asr_options={"word_timestamps": False},
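
Applied to Cell 3 from this issue, the workaround is the same load_model call with only the asr_options changed:

whisper_s2t_model = whisper_s2t.load_model(
    model_identifier="large-v3",
    backend="CTranslate2",
    asr_options={"word_timestamps": False},  # skip the tiny aligner, so no 80-vs-128 mel mismatch
)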

@shashikg added the bug label on Mar 22, 2024
@twardoch (Author)

Thanks! I do want them word timestamps though ;)

@aleksandr-smechov commented Apr 4, 2024

@twardoch You can add a separate preprocessor with a fixed number of mel bins (n_mels), as shown in this commit.
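
The linked commit is the authoritative change; purely as a sketch of the idea (the class and variable names below are made up for illustration and are not whisper-s2t's actual API), a preprocessor pinned to a fixed mel-bin count per model could look like this, again leaning on openai-whisper's log_mel_spectrogram:

import numpy as np
import whisper  # openai-whisper, used only to sketch the idea


class FixedMelPreprocessor:
    """Illustrative feature extractor pinned to one mel-bin count."""

    def __init__(self, n_mels: int):
        self.n_mels = n_mels

    def __call__(self, audio: np.ndarray):
        audio = whisper.pad_or_trim(audio)
        return whisper.log_mel_spectrogram(audio, n_mels=self.n_mels)


# The main large-v3 model would get a 128-mel preprocessor...
main_preprocessor = FixedMelPreprocessor(n_mels=128)
# ...while the tiny alignment model keeps its own 80-mel preprocessor,
# so the two models no longer share a single feature extractor.
aligner_preprocessor = FixedMelPreprocessor(n_mels=80)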
