Skip to content

Commit

Permalink
Refactor to avoid variable repetition
Browse files Browse the repository at this point in the history
  • Loading branch information
hoonlight committed Jul 19, 2023
1 parent ca28641 commit 30e0efe
Showing 1 changed file with 21 additions and 38 deletions.
59 changes: 21 additions & 38 deletions faster_whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,9 @@ def generate_with_fallback(
max_initial_timestamp_index = int(
round(options.max_initial_timestamp / self.time_precision)
)
results = {}

all_results = []
below_cr_threshold_results = []

for temperature in options.temperatures:
if temperature > 0:
Expand Down Expand Up @@ -626,27 +628,28 @@ def generate_with_fallback(
text = tokenizer.decode(tokens).strip()
compression_ratio = get_compression_ratio(text)

results[temperature] = (
decode_result = (
result,
avg_logprob,
final_temperature,
compression_ratio,
)
all_results.append(decode_result)

needs_fallback = False

if (
options.compression_ratio_threshold is not None
and compression_ratio > options.compression_ratio_threshold
):
needs_fallback = True # too repetitive
if options.compression_ratio_threshold is not None:
if compression_ratio > options.compression_ratio_threshold:
needs_fallback = True # too repetitive

self.logger.debug(
"Compression ratio threshold is not met with temperature %.1f (%f > %f)",
temperature,
compression_ratio,
options.compression_ratio_threshold,
)
self.logger.debug(
"Compression ratio threshold is not met with temperature %.1f (%f > %f)",
temperature,
compression_ratio,
options.compression_ratio_threshold,
)
else:
below_cr_threshold_results.append(decode_result)

if (
options.log_prob_threshold is not None
Expand All @@ -670,33 +673,13 @@ def generate_with_fallback(
if not needs_fallback:
break
else:
# all failed
filtered_results = {}
avg_logprob_index = 1
compression_ratio_index = 3

# filter out results with compression ratio below compression_ratio_threshold
for key, result_data in results.items():
if (
options.no_speech_threshold is not None
and result_data[compression_ratio_index]
< options.compression_ratio_threshold
):
filtered_results[key] = result_data

# select the optimal result based on the maximum avg_logprob
if filtered_results:
result, avg_logprob, final_temperature, compression_ratio = max(
filtered_results.values(),
key=lambda result_data: result_data[avg_logprob_index],
)
# all failed, select the result with the highest log prob
if below_cr_threshold_results:
decode_result = max(below_cr_threshold_results, key=lambda x: x[3])
else:
result, avg_logprob, final_temperature, compression_ratio = max(
results.values(),
key=lambda result_data: result_data[avg_logprob_index],
)
decode_result = max(all_results, key=lambda x: x[3])

return result, avg_logprob, final_temperature, compression_ratio
return decode_result

def get_prompt(
self,
Expand Down

0 comments on commit 30e0efe

Please sign in to comment.