Skip to content

Commit

Permalink
Fix bucket refilling in _score methode of Inference class (#2557)
Browse files Browse the repository at this point in the history
* fixed score results overriding
* fixed bucket refilling in translator._score
  • Loading branch information
l-k-11235 authored Jan 30, 2024
1 parent 1c27987 commit 43c3300
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions onmt/translate/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,29 +576,35 @@ def _process_bucket(bucket_translations):

def _score(self, infer_iter):
self.with_scores = True
scored_bucket = {}
score_res = []
processed_bucket = {}
prev_bucket_idx = 0
for batch, bucket_idx in infer_iter:
if bucket_idx != prev_bucket_idx:
prev_bucket_idx += 1
score_res += [item for _, item in sorted(processed_bucket.items())]
processed_bucket = {}
batch_data = self.translate_batch(batch, attn_debug=False, scoring=True)
batch_gold_scores = batch_data["gold_score"].cpu().numpy().tolist()
batch_tgt_lengths = batch["tgtlen"].cpu().numpy().tolist()
batch_inds_in_bucket = batch["ind_in_bucket"]
if self.return_gold_log_probs:
batch_gold_log_probs = (
batch_data["gold_log_probs"].cpu().numpy().tolist()
)
else:
batch_gold_log_probs = None
batch_tgt_lengths = batch["tgtlen"].cpu().numpy().tolist()
batch_inds_in_bucket = batch["ind_in_bucket"]
for i, _score in enumerate(batch_gold_scores):
log_probs = (
batch_gold_log_probs[i] if self.return_gold_log_probs else None
)
scored_bucket[batch_inds_in_bucket[i]] = (
_score,
log_probs,
batch_gold_log_probs = [
None for i, _ in enumerate(batch_inds_in_bucket)
]
for i, ind in enumerate(batch_inds_in_bucket):
processed_bucket[ind] = [
batch_gold_scores[i],
batch_gold_log_probs[i],
batch_tgt_lengths[i],
)
score_results = [scored_bucket[i] for i in range(len(scored_bucket))]
return score_results
]
if processed_bucket:
score_res += [item for _, item in sorted(processed_bucket.items())]
return score_res

def _align_pad_prediction(self, predictions, bos, pad):
"""
Expand Down

0 comments on commit 43c3300

Please sign in to comment.