Commit c52ce3a

Changing lang tag, removing norm_order, setting to pre-LN.
kauterry committed Sep 11, 2023
1 parent 0eb2328 commit c52ce3a
Showing 3 changed files with 13 additions and 32 deletions.
22 changes: 9 additions & 13 deletions src/fairseq2/models/mbart/builder.py
@@ -74,15 +74,12 @@ class mBartConfig:
     pos_encoder_type: Literal["sinusoidal", "learned"]
     """The type of position encoder."""
 
-    frontend_layernorm: bool
-    """Whether to add the layernorm in the encoder, decoder frontend."""
+    layer_norm_embed: bool
+    """Adds a layernorm to the embedding in the Transformer encoder."""
 
     dropout_p: float
     """The dropout probability in Transformer layers."""
 
-    norm_order: TransformerNormOrder
-    """The Layer Normalization order."""
-
     def update_vocabulary(self, info: VocabularyInfo) -> None:
         """Update vocabulary configuration from ``info``."""
         self.vocabulary_size, self.pad_idx = info.size, info.pad_idx
@@ -107,9 +104,8 @@ def _base() -> mBartConfig:
         num_decoder_attn_heads=16,
         ffn_inner_dim=4096,
         pos_encoder_type="learned",
-        frontend_layernorm=True,
+        layer_norm_embed=True,
         dropout_p=0.1,
-        norm_order=TransformerNormOrder.POST,
     )


@@ -190,7 +186,7 @@ def build_frontend(self, embed: Embedding) -> TransformerFrontend:
         return TransformerEmbeddingFrontend(
             embed,
             pos_encoder,
-            layer_norm=self.config.frontend_layernorm,
+            layer_norm=self.config.layer_norm_embed,
             dropout_p=self.config.dropout_p,
             device=self.device,
             dtype=self.dtype,
@@ -204,7 +200,7 @@ def build_encoder(self) -> TransformerEncoder:
 
         return StandardTransformerEncoder(
             layers,
-            norm_order=self.config.norm_order,
+            norm_order=TransformerNormOrder.PRE,
             device=self.device,
             dtype=self.dtype,
         )
@@ -217,7 +213,7 @@ def build_decoder(self) -> TransformerDecoder:
 
         return StandardTransformerDecoder(
             layers,
-            norm_order=self.config.norm_order,
+            norm_order=TransformerNormOrder.PRE,
             device=self.device,
             dtype=self.dtype,
         )
@@ -232,7 +228,7 @@ def build_encoder_layer(self) -> TransformerEncoderLayer:
             self_attn,
             ffn,
             dropout_p=self.config.dropout_p,
-            norm_order=self.config.norm_order,
+            norm_order=TransformerNormOrder.PRE,
             device=self.device,
             dtype=self.dtype,
         )
@@ -250,7 +246,7 @@ def build_decoder_layer(self) -> TransformerDecoderLayer:
             encoder_decoder_attn,
             ffn,
             dropout_p=self.config.dropout_p,
-            norm_order=self.config.norm_order,
+            norm_order=TransformerNormOrder.PRE,
             device=self.device,
             dtype=self.dtype,
         )
@@ -272,7 +268,7 @@ def build_ffn(self) -> FeedForwardNetwork:
         return StandardFeedForwardNetwork(
             self.config.model_dim,
             self.config.ffn_inner_dim,
-            norm_order=self.config.norm_order,
+            norm_order=TransformerNormOrder.PRE,
             device=self.device,
             dtype=self.dtype,
         )
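Note on the `TransformerNormOrder.PRE` setting above: pre-LN applies Layer Normalization before each sub-layer and adds the residual to the un-normalized input, whereas the previous post-LN setting normalized after the residual sum. Below is a minimal sketch of the pre-LN ordering in plain PyTorch; the class name, arguments, and dimensions are illustrative assumptions, not fairseq2 APIs.

# Illustration only (not fairseq2 code): pre-LN vs. post-LN sub-layer ordering.
import torch
import torch.nn as nn


class PreLNEncoderLayerSketch(nn.Module):
    """Pre-LN: normalize *before* each sub-layer; the residual bypasses the norm."""

    def __init__(self, model_dim: int, num_heads: int, ffn_inner_dim: int) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(model_dim, num_heads, batch_first=True)
        self.self_attn_norm = nn.LayerNorm(model_dim)
        self.ffn = nn.Sequential(
            nn.Linear(model_dim, ffn_inner_dim),
            nn.ReLU(),
            nn.Linear(ffn_inner_dim, model_dim),
        )
        self.ffn_norm = nn.LayerNorm(model_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-LN: residual = x + SubLayer(LayerNorm(x)).
        h = self.self_attn_norm(x)
        x = x + self.self_attn(h, h, h, need_weights=False)[0]
        x = x + self.ffn(self.ffn_norm(x))
        # A post-LN layer (the removed POST setting) would instead compute
        # LayerNorm(x + SubLayer(x)) after each residual sum.
        return x


layer = PreLNEncoderLayerSketch(model_dim=1024, num_heads=16, ffn_inner_dim=4096)
out = layer(torch.randn(2, 7, 1024))  # (batch, seq, dim) -> same shape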
15 changes: 0 additions & 15 deletions src/fairseq2/models/mbart/loader.py
@@ -6,8 +6,6 @@
 
 from typing import Any, Dict, Mapping, Union, final
 
-import torch
-
 from fairseq2.assets import (
     AssetCard,
     AssetDownloadManager,
@@ -46,24 +44,11 @@ def _upgrade_checkpoint(
 
         embeds = state_dict["final_proj.weight"]
 
-        # fairseq had a bug that accidentally introduced a dummy token in the
-        # embedding table of NLLB-100. We just discard it.
-        if embeds.size(0) == 256103:  # means NLLB-100
-            embeds = embeds[:-1]
-
-            state_dict["final_proj.weight"] = embeds
-
         # fairseq checkpoints have duplicate embedding weights. Ensure that we
         # use a single embedding table in fairseq2.
         state_dict["encoder_frontend.embed.weight"] = embeds
         state_dict["decoder_frontend.embed.weight"] = embeds
 
-        # The embedding positions of the control symbols in fairseq's dict do
-        # not match the SentencePiece model of the tokenizer.
-        with torch.inference_mode():
-            # (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS)
-            embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]
-
         return checkpoint
 
     @staticmethod
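The surviving part of `_upgrade_checkpoint` above keeps a single embedding table shared by the two frontends and the output projection. A tiny sketch of that tying, using a made-up checkpoint dict whose shape is chosen only for illustration:

import torch

# Hypothetical checkpoint: a 16-token vocabulary with 4-dimensional embeddings.
state_dict = {"final_proj.weight": torch.randn(16, 4)}

embeds = state_dict["final_proj.weight"]

# As in the retained loader code: point both frontends at the same tensor
# instead of keeping fairseq's duplicated copies.
state_dict["encoder_frontend.embed.weight"] = embeds
state_dict["decoder_frontend.embed.weight"] = embeds

assert state_dict["encoder_frontend.embed.weight"] is state_dict["decoder_frontend.embed.weight"]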
8 changes: 4 additions & 4 deletions src/fairseq2/models/mbart/tokenizer.py
@@ -38,8 +38,8 @@ def __init__(
         :param default_lang:
             The fall-back language if no language is specified.
         """
-        # Each language is represented by a `[lang_XX]` control symbol.
-        control_symbols = [f"[{lang}_XX]" for lang in langs]
+        # Each language is represented by a `[lang]` control symbol.
+        control_symbols = [f"[{lang}]" for lang in langs]
 
         control_symbols.append("<mask>")
 
@@ -91,9 +91,9 @@ def create_encoder(
 
         if mode is None or mode == "source":
             prefix_tokens = ["<s>"]
-            suffix_tokens = ["</s>", f"[{lang}_XX]"]
+            suffix_tokens = ["</s>", f"[{lang}]"]
         elif mode == "target":
-            prefix_tokens = [f"[{lang}_XX]", "<s>"]
+            prefix_tokens = [f"[{lang}]", "<s>"]
             suffix_tokens = ["</s>"]
         else:
             raise ValueError(
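To see what the new control symbols do to an encoded sequence, here is a standalone sketch of the source/target framing above. The language codes `en_XX` and `de_DE` are placeholders, not taken from the commit; the actual codes come from the model's asset card.

# Standalone sketch of the prefix/suffix framing in create_encoder above.
from typing import List, Optional


def frame(tokens: List[str], lang: str, mode: Optional[str] = None) -> List[str]:
    if mode is None or mode == "source":
        return ["<s>", *tokens, "</s>", f"[{lang}]"]
    if mode == "target":
        return [f"[{lang}]", "<s>", *tokens, "</s>"]
    raise ValueError(f"unknown mode: {mode}")


print(frame(["▁Hello", "▁world"], "en_XX"))
# ['<s>', '▁Hello', '▁world', '</s>', '[en_XX]']
print(frame(["▁Hallo", "▁Welt"], "de_DE", mode="target"))
# ['[de_DE]', '<s>', '▁Hallo', '▁Welt', '</s>']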
