From c860004f78a151552d6c827665533ae7ba20ea02 Mon Sep 17 00:00:00 2001 From: archive-r Date: Mon, 17 Jul 2023 16:42:22 +0900 Subject: [PATCH 1/6] Resolve Inference Selection Bug --- faster_whisper/transcribe.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 8cb492de..38d29798 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -586,6 +586,7 @@ def generate_with_fallback( max_initial_timestamp_index = int( round(options.max_initial_timestamp / self.time_precision) ) + results = {} for temperature in options.temperatures: if temperature > 0: @@ -625,6 +626,13 @@ def generate_with_fallback( text = tokenizer.decode(tokens).strip() compression_ratio = get_compression_ratio(text) + results[temperature] = ( + result, + avg_logprob, + final_temperature, + compression_ratio, + ) + needs_fallback = False if ( @@ -661,6 +669,9 @@ def generate_with_fallback( if not needs_fallback: break + else: + # all failed + return max(results.values(), key=lambda r: r[1]) return result, avg_logprob, final_temperature, compression_ratio From 0e58b3f18a98cc1a711f5dc3aa928a1bacfaf371 Mon Sep 17 00:00:00 2001 From: archive-r Date: Mon, 17 Jul 2023 20:16:09 +0900 Subject: [PATCH 2/6] Refactor for better readability --- faster_whisper/transcribe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 38d29798..ea0f4c40 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -671,7 +671,9 @@ def generate_with_fallback( break else: # all failed - return max(results.values(), key=lambda r: r[1]) + result, avg_logprob, final_temperature, compression_ratio = max( + results.values(), key=lambda r: r[1] + ) return result, avg_logprob, final_temperature, compression_ratio From ca286411f4cdfae140809f8ee12c45f0713082cc Mon Sep 17 00:00:00 2001 From: archive-r Date: Wed, 19 Jul 2023 19:17:45 +0900 Subject: [PATCH 3/6] Filter out results with compression_ratio --- faster_whisper/transcribe.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 9d96d566..05405a63 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -671,9 +671,30 @@ def generate_with_fallback( break else: # all failed - result, avg_logprob, final_temperature, compression_ratio = max( - results.values(), key=lambda r: r[1] - ) + filtered_results = {} + avg_logprob_index = 1 + compression_ratio_index = 3 + + # filter out results with compression ratio below compression_ratio_threshold + for key, result_data in results.items(): + if ( + options.no_speech_threshold is not None + and result_data[compression_ratio_index] + < options.compression_ratio_threshold + ): + filtered_results[key] = result_data + + # select the optimal result based on the maximum avg_logprob + if filtered_results: + result, avg_logprob, final_temperature, compression_ratio = max( + filtered_results.values(), + key=lambda result_data: result_data[avg_logprob_index], + ) + else: + result, avg_logprob, final_temperature, compression_ratio = max( + results.values(), + key=lambda result_data: result_data[avg_logprob_index], + ) return result, avg_logprob, final_temperature, compression_ratio From 30e0efe8a18d5d8800c7765a5104bf73ee571d29 Mon Sep 17 00:00:00 2001 From: archive-r Date: Thu, 20 Jul 2023 00:32:04 +0900 Subject: [PATCH 4/6] Refactor to avoid variable repetition --- faster_whisper/transcribe.py | 59 +++++++++++++----------------------- 1 file changed, 21 insertions(+), 38 deletions(-) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 05405a63..ba9ca713 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -586,7 +586,9 @@ def generate_with_fallback( max_initial_timestamp_index = int( round(options.max_initial_timestamp / self.time_precision) ) - results = {} + + all_results = [] + below_cr_threshold_results = [] for temperature in options.temperatures: if temperature > 0: @@ -626,27 +628,28 @@ def generate_with_fallback( text = tokenizer.decode(tokens).strip() compression_ratio = get_compression_ratio(text) - results[temperature] = ( + decode_result = ( result, avg_logprob, final_temperature, compression_ratio, ) + all_results.append(decode_result) needs_fallback = False - if ( - options.compression_ratio_threshold is not None - and compression_ratio > options.compression_ratio_threshold - ): - needs_fallback = True # too repetitive + if options.compression_ratio_threshold is not None: + if compression_ratio > options.compression_ratio_threshold: + needs_fallback = True # too repetitive - self.logger.debug( - "Compression ratio threshold is not met with temperature %.1f (%f > %f)", - temperature, - compression_ratio, - options.compression_ratio_threshold, - ) + self.logger.debug( + "Compression ratio threshold is not met with temperature %.1f (%f > %f)", + temperature, + compression_ratio, + options.compression_ratio_threshold, + ) + else: + below_cr_threshold_results.append(decode_result) if ( options.log_prob_threshold is not None @@ -670,33 +673,13 @@ def generate_with_fallback( if not needs_fallback: break else: - # all failed - filtered_results = {} - avg_logprob_index = 1 - compression_ratio_index = 3 - - # filter out results with compression ratio below compression_ratio_threshold - for key, result_data in results.items(): - if ( - options.no_speech_threshold is not None - and result_data[compression_ratio_index] - < options.compression_ratio_threshold - ): - filtered_results[key] = result_data - - # select the optimal result based on the maximum avg_logprob - if filtered_results: - result, avg_logprob, final_temperature, compression_ratio = max( - filtered_results.values(), - key=lambda result_data: result_data[avg_logprob_index], - ) + # all failed, select the result with the highest log prob + if below_cr_threshold_results: + decode_result = max(below_cr_threshold_results, key=lambda x: x[3]) else: - result, avg_logprob, final_temperature, compression_ratio = max( - results.values(), - key=lambda result_data: result_data[avg_logprob_index], - ) + decode_result = max(all_results, key=lambda x: x[3]) - return result, avg_logprob, final_temperature, compression_ratio + return decode_result def get_prompt( self, From 19c867a21cfd54276574d120fff5a6502ab3e822 Mon Sep 17 00:00:00 2001 From: archive-r Date: Thu, 20 Jul 2023 01:23:14 +0900 Subject: [PATCH 5/6] Fix incorrect index and perform minor refactoring --- faster_whisper/transcribe.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index ba9ca713..1afa15ea 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -578,18 +578,14 @@ def generate_with_fallback( tokenizer: Tokenizer, options: TranscriptionOptions, ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]: - result = None - avg_logprob = None - final_temperature = None - compression_ratio = None + decode_result = None + all_results = [] + below_cr_threshold_results = [] max_initial_timestamp_index = int( round(options.max_initial_timestamp / self.time_precision) ) - all_results = [] - below_cr_threshold_results = [] - for temperature in options.temperatures: if temperature > 0: kwargs = { @@ -673,11 +669,10 @@ def generate_with_fallback( if not needs_fallback: break else: - # all failed, select the result with the highest log prob - if below_cr_threshold_results: - decode_result = max(below_cr_threshold_results, key=lambda x: x[3]) - else: - decode_result = max(all_results, key=lambda x: x[3]) + # all failed, select the result with the highest average log probability + decode_result = max( + below_cr_threshold_results or all_results, key=lambda x: x[1] + ) return decode_result From a578a187b56369714a3575905a5e99c559d28f11 Mon Sep 17 00:00:00 2001 From: archive-r Date: Thu, 20 Jul 2023 14:38:34 +0900 Subject: [PATCH 6/6] Remove final_temperature variable --- faster_whisper/transcribe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 1afa15ea..c351c0a6 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -600,7 +600,6 @@ def generate_with_fallback( "patience": options.patience, } - final_temperature = temperature result = self.model.generate( encoder_output, [prompt], @@ -627,7 +626,7 @@ def generate_with_fallback( decode_result = ( result, avg_logprob, - final_temperature, + temperature, compression_ratio, ) all_results.append(decode_result)