diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index e4cf9043..c351c0a6 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -578,10 +578,9 @@ def generate_with_fallback( tokenizer: Tokenizer, options: TranscriptionOptions, ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]: - result = None - avg_logprob = None - final_temperature = None - compression_ratio = None + decode_result = None + all_results = [] + below_cr_threshold_results = [] max_initial_timestamp_index = int( round(options.max_initial_timestamp / self.time_precision) @@ -601,7 +600,6 @@ def generate_with_fallback( "patience": options.patience, } - final_temperature = temperature result = self.model.generate( encoder_output, [prompt], @@ -625,20 +623,28 @@ def generate_with_fallback( text = tokenizer.decode(tokens).strip() compression_ratio = get_compression_ratio(text) + decode_result = ( + result, + avg_logprob, + temperature, + compression_ratio, + ) + all_results.append(decode_result) + needs_fallback = False - if ( - options.compression_ratio_threshold is not None - and compression_ratio > options.compression_ratio_threshold - ): - needs_fallback = True # too repetitive + if options.compression_ratio_threshold is not None: + if compression_ratio > options.compression_ratio_threshold: + needs_fallback = True # too repetitive - self.logger.debug( - "Compression ratio threshold is not met with temperature %.1f (%f > %f)", - temperature, - compression_ratio, - options.compression_ratio_threshold, - ) + self.logger.debug( + "Compression ratio threshold is not met with temperature %.1f (%f > %f)", + temperature, + compression_ratio, + options.compression_ratio_threshold, + ) + else: + below_cr_threshold_results.append(decode_result) if ( options.log_prob_threshold is not None @@ -661,8 +667,13 @@ def generate_with_fallback( if not needs_fallback: break + else: + # all failed, select the result with the highest average log probability + decode_result = max( + below_cr_threshold_results or all_results, key=lambda x: x[1] + ) - return result, avg_logprob, final_temperature, compression_ratio + return decode_result def get_prompt( self,