diff --git a/src/fairseq2/models/mbart/builder.py b/src/fairseq2/models/mbart/builder.py
index 59c3028a9..30b13c89c 100644
--- a/src/fairseq2/models/mbart/builder.py
+++ b/src/fairseq2/models/mbart/builder.py
@@ -74,15 +74,12 @@ class mBartConfig:
     pos_encoder_type: Literal["sinusoidal", "learned"]
     """The type of position encoder."""

-    frontend_layernorm: bool
-    """Whether to add the layernorm in the encoder, decoder frontend."""
+    layer_norm_embed: bool
+    """Adds a layernorm to the embedding in the Transformer encoder."""

     dropout_p: float
     """The dropout probability in Transformer layers."""

-    norm_order: TransformerNormOrder
-    """The Layer Normalization order."""
-
     def update_vocabulary(self, info: VocabularyInfo) -> None:
         """Update vocabulary configuration from ``info``."""
         self.vocabulary_size, self.pad_idx = info.size, info.pad_idx
@@ -107,9 +104,8 @@ def _base() -> mBartConfig:
         num_decoder_attn_heads=16,
         ffn_inner_dim=4096,
         pos_encoder_type="learned",
-        frontend_layernorm=True,
+        layer_norm_embed=True,
         dropout_p=0.1,
-        norm_order=TransformerNormOrder.POST,
     )


@@ -190,7 +186,7 @@ def build_frontend(self, embed: Embedding) -> TransformerFrontend:
         return TransformerEmbeddingFrontend(
             embed,
             pos_encoder,
-            layer_norm=self.config.frontend_layernorm,
+            layer_norm=self.config.layer_norm_embed,
             dropout_p=self.config.dropout_p,
             device=self.device,
             dtype=self.dtype,
@@ -204,7 +200,7 @@ def build_encoder(self) -> TransformerEncoder:

         return StandardTransformerEncoder(
             layers,
-            norm_order=self.config.norm_order,
+            norm_order=TransformerNormOrder.PRE,
             device=self.device,
             dtype=self.dtype,
         )
@@ -217,7 +213,7 @@ def build_decoder(self) -> TransformerDecoder:

         return StandardTransformerDecoder(
             layers,
-            norm_order=self.config.norm_order,
+            norm_order=TransformerNormOrder.PRE,
             device=self.device,
             dtype=self.dtype,
         )
@@ -232,7 +228,7 @@ def build_encoder_layer(self) -> TransformerEncoderLayer:
             self_attn,
             ffn,
             dropout_p=self.config.dropout_p,
-            norm_order=self.config.norm_order,
+            norm_order=TransformerNormOrder.PRE,
             device=self.device,
             dtype=self.dtype,
         )
@@ -250,7 +246,7 @@ def build_decoder_layer(self) -> TransformerDecoderLayer:
             encoder_decoder_attn,
             ffn,
             dropout_p=self.config.dropout_p,
-            norm_order=self.config.norm_order,
+            norm_order=TransformerNormOrder.PRE,
             device=self.device,
             dtype=self.dtype,
         )
@@ -272,7 +268,7 @@ def build_ffn(self) -> FeedForwardNetwork:
         return StandardFeedForwardNetwork(
             self.config.model_dim,
             self.config.ffn_inner_dim,
-            norm_order=self.config.norm_order,
+            norm_order=TransformerNormOrder.PRE,
             device=self.device,
             dtype=self.dtype,
         )
diff --git a/src/fairseq2/models/mbart/loader.py b/src/fairseq2/models/mbart/loader.py
index 84cb94d3b..47653c065 100644
--- a/src/fairseq2/models/mbart/loader.py
+++ b/src/fairseq2/models/mbart/loader.py
@@ -6,8 +6,6 @@

 from typing import Any, Dict, Mapping, Union, final

-import torch
-
 from fairseq2.assets import (
     AssetCard,
     AssetDownloadManager,
@@ -46,24 +44,11 @@ def _upgrade_checkpoint(

         embeds = state_dict["final_proj.weight"]

-        # fairseq had a bug that accidentally introduced a dummy token in the
-        # embedding table of NLLB-100. We just discard it.
-        if embeds.size(0) == 256103:  # means NLLB-100
-            embeds = embeds[:-1]
-
-            state_dict["final_proj.weight"] = embeds
-
         # fairseq checkpoints have duplicate embedding weights. Ensure that we
         # use a single embedding table in fairseq2.
state_dict["encoder_frontend.embed.weight"] = embeds state_dict["decoder_frontend.embed.weight"] = embeds - # The embedding positions of the control symbols in fairseq's dict do - # not match the SentencePiece model of the tokenizer. - with torch.inference_mode(): - # (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS) - embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]] - return checkpoint @staticmethod diff --git a/src/fairseq2/models/mbart/tokenizer.py b/src/fairseq2/models/mbart/tokenizer.py index 1a6e184f1..ac6087e18 100644 --- a/src/fairseq2/models/mbart/tokenizer.py +++ b/src/fairseq2/models/mbart/tokenizer.py @@ -38,8 +38,8 @@ def __init__( :param default_lang: The fall-back language if no language is specified. """ - # Each language is represented by a `[lang_XX]` control symbol. - control_symbols = [f"[{lang}_XX]" for lang in langs] + # Each language is represented by a `[lang]` control symbol. + control_symbols = [f"[{lang}]" for lang in langs] control_symbols.append("") @@ -91,9 +91,9 @@ def create_encoder( if mode is None or mode == "source": prefix_tokens = [""] - suffix_tokens = ["", f"[{lang}_XX]"] + suffix_tokens = ["", f"[{lang}]"] elif mode == "target": - prefix_tokens = [f"[{lang}_XX]", ""] + prefix_tokens = [f"[{lang}]", ""] suffix_tokens = [""] else: raise ValueError(