diff --git a/faster_whisper/tokenizer.py b/faster_whisper/tokenizer.py
index 3bf76a5f..cc208829 100644
--- a/faster_whisper/tokenizer.py
+++ b/faster_whisper/tokenizer.py
@@ -6,6 +6,10 @@
 import tokenizers
 
 
+class TokenizationError(Exception):
+    pass
+
+
 class Tokenizer:
     """Simple wrapper around a tokenizers.Tokenizer."""
 
@@ -87,23 +91,44 @@ def encode(self, text: str) -> List[int]:
         return self.tokenizer.encode(text, add_special_tokens=False).ids
 
     def decode(self, tokens: List[int]) -> str:
-        text_tokens = [token for token in tokens if token < self.eot]
-        return self.tokenizer.decode(text_tokens)
+        try:
+            text_tokens = [token for token in tokens if token < self.eot]
+            if not text_tokens:
+                return ""
+            if any(not isinstance(t, int) or t < 0 for t in text_tokens):
+                raise ValueError("Invalid token values detected")
+            return self.tokenizer.decode(text_tokens)
+        except Exception as e:
+            raise TokenizationError(f"Failed to decode tokens: {e}") from e
 
     def decode_with_timestamps(self, tokens: List[int]) -> str:
-        outputs = [[]]
-
-        for token in tokens:
-            if token >= self.timestamp_begin:
-                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
-                outputs.append(timestamp)
-                outputs.append([])
-            else:
-                outputs[-1].append(token)
-
-        return "".join(
-            [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
-        )
+        try:
+            if not tokens:
+                raise ValueError("Empty token sequence")
+            if any(not isinstance(t, int) or t < 0 for t in tokens):
+                raise ValueError("Invalid token values detected")
+
+            outputs = [[]]
+            for token in tokens:
+                if token >= self.timestamp_begin:
+                    timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
+                    outputs.append(timestamp)
+                    outputs.append([])
+                else:
+                    outputs[-1].append(token)
+
+            decoded = [
+                s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs
+            ]
+
+            if not any(decoded) and not any(isinstance(s, str) for s in outputs):
+                return ""
+
+            return "".join(decoded)
+        except Exception as e:
+            raise TokenizationError(
+                f"Failed to decode tokens with timestamps: {e}"
+            ) from e
 
     @cached_property
     def non_speech_tokens(self) -> Tuple[int]:
@@ -205,10 +230,7 @@ def split_tokens_on_spaces(
         return words, word_tokens
 
 
-_TASKS = (
-    "transcribe",
-    "translate",
-)
+_TASKS = ("transcribe", "translate")
 
 _LANGUAGE_CODES = (
     "af",
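
A minimal usage sketch of the new error path, assuming a constructed
faster_whisper Tokenizer bound to a hypothetical tokenizer variable
(construction elided; the token values are illustrative only):

    from faster_whisper.tokenizer import TokenizationError

    try:
        # A negative token id trips the new validation in decode()
        # and is re-raised as TokenizationError.
        text = tokenizer.decode([-1, 42])
    except TokenizationError as exc:
        print(f"decode failed: {exc}")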