Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

In-memory audio input mode #65

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions whisper_s2t/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def __init__(self,
max_speech_len=29.0,
max_text_token_len=MAX_TEXT_TOKEN_LENGTH,
without_timestamps=True,
- speech_segmenter_options={}):
+ speech_segmenter_options={},
+ file_io=True):

# Configure Params
self.device = device
Expand All @@ -73,6 +74,7 @@ def __init__(self,
tokenizer = NoneTokenizer()

self.tokenizer = tokenizer
+ self.file_io = file_io

self._init_dependables()

Expand All @@ -96,7 +98,8 @@ def _init_dependables(self):
max_speech_len=self.max_speech_len,
max_initial_prompt_len=self.max_initial_prompt_len,
use_dynamic_time_axis=self.use_dynamic_time_axis,
- merge_chunks=self.merge_chunks
+ merge_chunks=self.merge_chunks,
+ file_io=self.file_io
)

def update_params(self, params={}):
Expand Down Expand Up @@ -181,4 +184,4 @@ def transcribe_with_vad(self, audio_files, lang_codes=None, tasks=None, initial_

pbar.update(pbar.total-pbar_pos)

- return responses
+ return responses
11 changes: 6 additions & 5 deletions whisper_s2t/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,15 @@ def __getitem__(self, item):

return audio, prompt, initial_prompt_tokens, seq_len


class WhisperDataLoader:
def __init__(self, device, tokenizer, speech_segmenter,
dta_padding=3.0,
without_timestamps=True,
max_speech_len=29.0,
max_initial_prompt_len=223,
merge_chunks=True,
- use_dynamic_time_axis=False):
+ use_dynamic_time_axis=False,
+ file_io=True):

self.device = device
self.tokenizer = tokenizer
Expand All @@ -140,6 +140,7 @@ def __init__(self, device, tokenizer, speech_segmenter,
self.max_initial_prompt_len = max_initial_prompt_len
self.use_dynamic_time_axis = use_dynamic_time_axis
self.merge_chunks = merge_chunks
+ self.audio_gen = audio_batch_generator if file_io else lambda x: x

def data_collate_fn(self, batch):
if self.use_dynamic_time_axis:
Expand Down Expand Up @@ -209,7 +210,7 @@ def get_data_loader_with_vad(self, audio_files, lang_codes, tasks, initial_promp

segmented_audio_signal = []
pbar_update_len = {}
- for file_id, (audio_signal, lang, task, initial_prompt) in enumerate(zip(audio_batch_generator(audio_files), lang_codes, tasks, initial_prompts)):
+ for file_id, (audio_signal, lang, task, initial_prompt) in enumerate(zip(self.audio_gen(audio_files), lang_codes, tasks, initial_prompts)):
start_ends, audio_signal = self.speech_segmenter(audio_signal=audio_signal)
new_segmented_audio_signal = self.get_segmented_audio_signal(start_ends, audio_signal, file_id, lang, task, initial_prompt)
pbar_update_len[file_id] = 1/len(new_segmented_audio_signal)
Expand Down Expand Up @@ -243,7 +244,7 @@ def get_data_loader(self, audio_files, lang_codes, tasks, initial_prompts, batch

segmented_audio_signal = []
pbar_update_len = {}
- for file_id, (audio_signal, lang, task, initial_prompt) in enumerate(zip(audio_batch_generator(audio_files), lang_codes, tasks, initial_prompts)):
+ for file_id, (audio_signal, lang, task, initial_prompt) in enumerate(zip(self.audio_gen(audio_files), lang_codes, tasks, initial_prompts)):
start_ends, audio_signal = self.basic_segmenter(audio_signal=audio_signal)
new_segmented_audio_signal = self.get_segmented_audio_signal(start_ends, audio_signal, file_id, lang, task, initial_prompt)
pbar_update_len[file_id] = 1/len(new_segmented_audio_signal)
Expand All @@ -268,4 +269,4 @@ def __call__(self, audio_files, lang_codes, tasks, initial_prompts, batch_size=1
if use_vad:
return self.get_data_loader_with_vad(audio_files, lang_codes, tasks, initial_prompts, batch_size=batch_size)
else:
- return self.get_data_loader(audio_files, lang_codes, tasks, initial_prompts, batch_size=batch_size)
+ return self.get_data_loader(audio_files, lang_codes, tasks, initial_prompts, batch_size=batch_size)