diff --git a/onmt/tests/test_greedy_search.py b/onmt/tests/test_greedy_search.py index fd4291e506..d645740055 100644 --- a/onmt/tests/test_greedy_search.py +++ b/onmt/tests/test_greedy_search.py @@ -46,6 +46,7 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self): 2, 3, 1, + 1, batch_sz, GlobalScorerStub(), min_length, @@ -100,6 +101,7 @@ def test_returns_correct_scores_deterministic(self): 2, 3, 1, + 1, batch_sz, GlobalScorerStub(), 0, @@ -186,6 +188,7 @@ def test_returns_correct_scores_non_deterministic(self): 2, 3, 1, + 1, batch_sz, GlobalScorerStub(), 0, @@ -297,6 +300,7 @@ def test_returns_correct_scores_non_deterministic_beams(self): 2, 3, 1, + 1, batch_sz, GlobalScorerStub(), 0, @@ -374,7 +378,7 @@ def test_returns_correct_scores_non_deterministic_beams(self): samp.update_finished() self.assertEqual( - [score for score, _, _ in samp.hypotheses[batch_sz - 1][-1:]], + [score for score, _, _ in samp.hypotheses[batch_sz - 1][:1]], [valid_score_dist_2[0] / temp], ) @@ -419,6 +423,7 @@ def test_returns_correct_scores_non_deterministic_topp(self): 2, 3, 1, + 1, batch_sz, GlobalScorerStub(), 0, diff --git a/onmt/translate/greedy_search.py b/onmt/translate/greedy_search.py index 8a5707ffa8..0cc58e4e37 100644 --- a/onmt/translate/greedy_search.py +++ b/onmt/translate/greedy_search.py @@ -98,6 +98,8 @@ class GreedySearch(DecodeStrategy): eos (int): See base. unk (int): See base. start (int): See base. + n_best (int): Don't stop until at least this many beams have + reached EOS. batch_size (int): See base. global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance. min_length (int): See base. @@ -123,6 +125,7 @@ def __init__( eos, unk, start, + n_best, batch_size, global_scorer, min_length, @@ -157,6 +160,7 @@ def __init__( self.keep_topp = keep_topp self.topk_scores = None self.beam_size = beam_size + self.n_best = n_best def initialize( self, enc_out, src_len, src_map=None, device=None, target_prefix=None @@ -265,10 +269,14 @@ def update_finished(self): else [] ) self.hypotheses[b_orig].append((score, pred, attention)) + if len(self.hypotheses[b_orig]) >= 2: + self.hypotheses[b_orig] = sorted( + self.hypotheses[b_orig], key=lambda x: x[0], reverse=True + ) self.done = self.is_finished.all() if self.done: for b in range(self.batch_size): - best_hyp = sorted(self.hypotheses[b], key=lambda x: x[0], reverse=True) + best_hyp = self.hypotheses[b][: self.n_best] for score, pred, attn in best_hyp: self.scores[b].append(score) self.predictions[b].append(pred) diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py index be6903bcaf..99a835e1aa 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -810,6 +810,7 @@ def translate_batch(self, batch, attn_debug): eos=self._tgt_eos_idx, unk=self._tgt_unk_idx, start=self._tgt_start_with, + n_best=self.n_best, batch_size=len(batch["srclen"]), global_scorer=self.global_scorer, min_length=self.min_length, @@ -1009,6 +1010,7 @@ def translate_batch(self, batch, attn_debug): eos=self._tgt_eos_idx, unk=self._tgt_unk_idx, start=self._tgt_start_with, + n_best=self.n_best, batch_size=len(batch["srclen"]), global_scorer=self.global_scorer, min_length=self.min_length,