From 5159595b188fa7995ade948a82dd9cf969848cc4 Mon Sep 17 00:00:00 2001 From: vince62s Date: Mon, 29 Apr 2024 18:36:53 +0200 Subject: [PATCH] fix test --- python/ctranslate2/converters/opennmt_py.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/ctranslate2/converters/opennmt_py.py b/python/ctranslate2/converters/opennmt_py.py index b7f45c97b..32305e86a 100644 --- a/python/ctranslate2/converters/opennmt_py.py +++ b/python/ctranslate2/converters/opennmt_py.py @@ -71,10 +71,11 @@ def _get_model_spec_seq2seq( num_kv = None rotary_dim = 0 if with_rotary else None rotary_interleave = getattr(opt, "rotary_interleave", True) - ffn_glu = activation_fn == "silu" + ffn_glu = (activation_fn == "silu") or (activation_fn == "gated-gelu") sliding_window = getattr(opt, "sliding_window", 0) feat_merge = getattr(opt, "feat_merge", "concat") + layer_norm = getattr(opt, "layer_norm", "standard") # Return the first head of the last layer unless the model was trained with alignments. if getattr(opt, "lambda_align", 0) == 0: @@ -91,7 +92,7 @@ def _get_model_spec_seq2seq( ffn_glu=ffn_glu, with_relative_position=with_relative_position, alibi=with_alibi, - rms_norm=opt.layer_norm == "rms", + rms_norm=layer_norm == "rms", rotary_dim=rotary_dim, rotary_interleave=rotary_interleave, multi_query_attention=getattr(opt, "multiquery", False),