diff --git a/python/ctranslate2/converters/opennmt_py.py b/python/ctranslate2/converters/opennmt_py.py
index ccd9ad417..32305e86a 100644
--- a/python/ctranslate2/converters/opennmt_py.py
+++ b/python/ctranslate2/converters/opennmt_py.py
@@ -24,6 +24,8 @@ def check_opt(opt, num_source_embeddings):
     activation_fn = getattr(opt, "pos_ffn_activation_fn", "relu")
     feat_merge = getattr(opt, "feat_merge", "concat")
     self_attn_type = getattr(opt, "self_attn_type", "scaled-dot")
+    if self_attn_type == "scaled-dot-flash":
+        self_attn_type = "scaled-dot"
 
     check = utils.ConfigurationChecker()
     check(
@@ -60,8 +62,20 @@ def _get_model_spec_seq2seq(
 ):
     """Creates a model specification from the model options."""
     with_relative_position = getattr(opt, "max_relative_positions", 0) > 0
+    with_rotary = getattr(opt, "max_relative_positions", 0) == -1
+    with_alibi = getattr(opt, "max_relative_positions", 0) == -2
     activation_fn = getattr(opt, "pos_ffn_activation_fn", "relu")
+    num_heads = getattr(opt, "heads", 8)
+    num_kv = getattr(opt, "num_kv", 0)
+    if num_kv == num_heads or num_kv == 0:
+        num_kv = None
+    rotary_dim = 0 if with_rotary else None
+    rotary_interleave = getattr(opt, "rotary_interleave", True)
+    ffn_glu = (activation_fn == "silu") or (activation_fn == "gated-gelu")
+    sliding_window = getattr(opt, "sliding_window", 0)
     feat_merge = getattr(opt, "feat_merge", "concat")
+    layer_norm = getattr(opt, "layer_norm", "standard")
 
     # Return the first head of the last layer unless the model was trained with alignments.
     if getattr(opt, "lambda_align", 0) == 0:
@@ -71,20 +85,26 @@ def _get_model_spec_seq2seq(
         alignment_layer = opt.alignment_layer
         alignment_heads = opt.alignment_heads
 
-    num_heads = getattr(opt, "heads", 8)
-
     model_spec = transformer_spec.TransformerSpec.from_config(
         (opt.enc_layers, opt.dec_layers),
         num_heads,
-        with_relative_position=with_relative_position,
         activation=_SUPPORTED_ACTIVATIONS[activation_fn],
+        ffn_glu=ffn_glu,
+        with_relative_position=with_relative_position,
+        alibi=with_alibi,
+        rms_norm=layer_norm == "rms",
+        rotary_dim=rotary_dim,
+        rotary_interleave=rotary_interleave,
+        multi_query_attention=getattr(opt, "multiquery", False),
+        num_heads_kv=num_kv,
+        sliding_window=sliding_window,
         alignment_layer=alignment_layer,
         alignment_heads=alignment_heads,
         num_source_embeddings=num_source_embeddings,
         embeddings_merge=_SUPPORTED_FEATURES_MERGE[feat_merge],
-        multi_query_attention=getattr(opt, "multiquery", False),
     )
 
+    model_spec.config.layer_norm_epsilon = getattr(opt, "norm_eps", 1e-6)
     model_spec.config.decoder_start_token = getattr(opt, "decoder_start_token", "")
 
     set_transformer_spec(model_spec, variables)
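Before moving to the spec changes, the new option mapping in the converter can be exercised in isolation. A minimal sketch, assuming a hypothetical `opt` namespace whose attribute names mirror the `getattr()` calls in `_get_model_spec_seq2seq` above:

```python
from types import SimpleNamespace

# Hypothetical OpenNMT-py options (not a real checkpoint); the attribute
# names follow the getattr() calls in the converter hunk above.
opt = SimpleNamespace(
    max_relative_positions=-1,  # > 0: relative positions, -1: rotary, -2: ALiBi
    pos_ffn_activation_fn="silu",
    heads=16,
    num_kv=4,                   # 4 KV heads shared by 16 query heads (GQA)
    rotary_interleave=True,
    sliding_window=0,
    layer_norm="rms",
    multiquery=False,
)

with_relative_position = opt.max_relative_positions > 0  # False
with_rotary = opt.max_relative_positions == -1           # True
with_alibi = opt.max_relative_positions == -2            # False

num_kv = opt.num_kv
if num_kv == opt.heads or num_kv == 0:
    num_kv = None  # plain multi-head attention: no separate KV head count

rotary_dim = 0 if with_rotary else None                        # 0 means "all head dimensions"
ffn_glu = opt.pos_ffn_activation_fn in ("silu", "gated-gelu")  # True
```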
"""Initializes a Transformer encoder specification. @@ -43,9 +56,30 @@ def __init__( ffn_glu: Use gated linear units in the FFN layers as described in https://arxiv.org/abs/2002.05202. rms_norm: Use the root mean square layer normalization. - multi_query_attention: Use multi-query attention. + alibi: Use attention with linear biases. + alibi_use_positive_positions: Use positive positions in the ALiBi definition. + scale_alibi: Apply the dot product scale factor to ALiBi. + rotary_dim: Apply rotary embeddings to these first N dimensions. If 0, rotary + embeddings are applied to all dimensions. + rotary_interleave: Interleave the head dimensions when rotary embeddings are applied. + Otherwise the head dimensions are sliced in half. + rotary_scaling_type: Type of RoPE scaling. + rotary_scaling_factor: Factor used in the RoPE scaling. + rotary_base: The base period of the rotary embeddings. + parallel_residual: Use parallel residual connections in each layer block, as used + by the GPT-J and GPT-NeoX models. + shared_layer_norm: When using parallel residual, share the input and post + attention layer norms. + multi_query_attention: Use multi-query attention (alias for num_heads_kv=1). + num_heads_kv: Number of attention heads for the key and value. + sliding_window: Max sequence length to retain in KV Cache. """ - self.multi_query_attention = multi_query_attention + if multi_query_attention: + if num_heads_kv is not None and num_heads_kv != 1: + raise ValueError( + "Enabling multi_query_attention implies num_heads_kv=1" + ) + num_heads_kv = 1 self.num_heads = np.dtype("int16").type(num_heads) self.pre_norm = pre_norm self.activation = np.dtype("int8").type(activation) @@ -54,7 +88,17 @@ def __init__( common_spec.EmbeddingsSpec() for _ in range(num_source_embeddings) ] self.scale_embeddings = True - if not relative_position and not relative_attention_bias: + self.alibi = alibi + self.alibi_use_positive_positions = alibi_use_positive_positions + self.scale_alibi = scale_alibi + if sliding_window is not None: + self.sliding_window = np.dtype("int32").type(sliding_window) + if ( + not relative_position + and not relative_attention_bias + and not alibi + and rotary_dim is None + ): self.position_encodings = PositionEncoderSpec() if pre_norm and not no_final_norm: self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm) @@ -66,10 +110,22 @@ def __init__( relative_attention_bias=relative_attention_bias, ffn_glu=ffn_glu, rms_norm=rms_norm, - num_heads_kv=1 if multi_query_attention else None, + rotary_dim=rotary_dim, + rotary_interleave=rotary_interleave, + rotary_scaling_type=rotary_scaling_type, + rotary_scaling_factor=rotary_scaling_factor, + rotary_base=rotary_base, + parallel_residual=parallel_residual, + shared_layer_norm=shared_layer_norm, + num_heads_kv=num_heads_kv, + head_dim=head_dim, + sliding_window=sliding_window, ) for _ in range(num_layers) ] + self.multi_query_attention = multi_query_attention or ( + num_heads_kv != num_heads + ) class TransformerDecoderSpec(model_spec.LayerSpec): @@ -224,7 +280,15 @@ def __init__( relative_attention_bias=False, ffn_glu=False, rms_norm=False, + rotary_dim=None, + rotary_interleave=True, + rotary_scaling_type=None, + rotary_scaling_factor=1, + rotary_base=10000, + parallel_residual=False, + shared_layer_norm=False, num_heads_kv=None, + head_dim=None, sliding_window=None, ): self.self_attention = attention_spec.MultiHeadAttentionSpec( @@ -232,7 +296,13 @@ def __init__( relative_position=relative_position, 
@@ -54,7 +88,17 @@ def __init__(
             common_spec.EmbeddingsSpec() for _ in range(num_source_embeddings)
         ]
         self.scale_embeddings = True
-        if not relative_position and not relative_attention_bias:
+        self.alibi = alibi
+        self.alibi_use_positive_positions = alibi_use_positive_positions
+        self.scale_alibi = scale_alibi
+        if sliding_window is not None:
+            self.sliding_window = np.dtype("int32").type(sliding_window)
+        if (
+            not relative_position
+            and not relative_attention_bias
+            and not alibi
+            and rotary_dim is None
+        ):
             self.position_encodings = PositionEncoderSpec()
         if pre_norm and not no_final_norm:
             self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
@@ -66,10 +110,22 @@ def __init__(
                 relative_attention_bias=relative_attention_bias,
                 ffn_glu=ffn_glu,
                 rms_norm=rms_norm,
-                num_heads_kv=1 if multi_query_attention else None,
+                rotary_dim=rotary_dim,
+                rotary_interleave=rotary_interleave,
+                rotary_scaling_type=rotary_scaling_type,
+                rotary_scaling_factor=rotary_scaling_factor,
+                rotary_base=rotary_base,
+                parallel_residual=parallel_residual,
+                shared_layer_norm=shared_layer_norm,
+                num_heads_kv=num_heads_kv,
+                head_dim=head_dim,
+                sliding_window=sliding_window,
             )
             for _ in range(num_layers)
         ]
+        self.multi_query_attention = multi_query_attention or (
+            num_heads_kv != num_heads
+        )
 
 
 class TransformerDecoderSpec(model_spec.LayerSpec):
@@ -224,7 +280,15 @@ def __init__(
         relative_attention_bias=False,
         ffn_glu=False,
         rms_norm=False,
+        rotary_dim=None,
+        rotary_interleave=True,
+        rotary_scaling_type=None,
+        rotary_scaling_factor=1,
+        rotary_base=10000,
+        parallel_residual=False,
+        shared_layer_norm=False,
         num_heads_kv=None,
+        head_dim=None,
         sliding_window=None,
     ):
         self.self_attention = attention_spec.MultiHeadAttentionSpec(
@@ -232,7 +296,13 @@ def __init__(
             relative_position=relative_position,
             relative_attention_bias=relative_attention_bias,
             rms_norm=rms_norm,
+            rotary_dim=rotary_dim,
+            rotary_interleave=rotary_interleave,
+            rotary_scaling_type=rotary_scaling_type,
+            rotary_scaling_factor=rotary_scaling_factor,
+            rotary_base=rotary_base,
             num_heads_kv=num_heads_kv,
+            head_dim=head_dim,
             sliding_window=sliding_window,
         )
         self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)
@@ -364,7 +434,20 @@ def from_config(
         relative_attention_bias: bool = False,
         ffn_glu: bool = False,
         rms_norm: bool = False,
+        alibi: bool = False,
+        alibi_use_positive_positions: bool = False,
+        scale_alibi: bool = False,
+        rotary_dim: Optional[int] = None,
+        rotary_interleave: bool = True,
+        rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
+        rotary_scaling_factor: float = 1,
+        rotary_base: float = 10000,
+        parallel_residual: bool = False,
+        shared_layer_norm: bool = False,
         multi_query_attention: bool = False,
+        num_heads_kv: Optional[int] = None,
+        head_dim: Optional[int] = None,
+        sliding_window: Optional[int] = None,
     ):
         """Creates a Transformer model specification.
 
@@ -408,7 +491,20 @@ def from_config(
             relative_attention_bias=relative_attention_bias,
             ffn_glu=ffn_glu,
             rms_norm=rms_norm,
+            alibi=alibi,
+            alibi_use_positive_positions=alibi_use_positive_positions,
+            scale_alibi=scale_alibi,
+            rotary_dim=rotary_dim,
+            rotary_interleave=rotary_interleave,
+            rotary_scaling_type=rotary_scaling_type,
+            rotary_scaling_factor=rotary_scaling_factor,
+            rotary_base=rotary_base,
+            parallel_residual=parallel_residual,
+            shared_layer_norm=shared_layer_norm,
             multi_query_attention=multi_query_attention,
+            num_heads_kv=num_heads_kv,
+            head_dim=head_dim,
+            sliding_window=sliding_window,
         )
 
         decoder = TransformerDecoderSpec(
@@ -424,7 +520,20 @@ def from_config(
             alignment_heads=alignment_heads,
             ffn_glu=ffn_glu,
             rms_norm=rms_norm,
+            alibi=alibi,
+            alibi_use_positive_positions=alibi_use_positive_positions,
+            scale_alibi=scale_alibi,
+            rotary_dim=rotary_dim,
+            rotary_interleave=rotary_interleave,
+            rotary_scaling_type=rotary_scaling_type,
+            rotary_scaling_factor=rotary_scaling_factor,
+            rotary_base=rotary_base,
+            parallel_residual=parallel_residual,
+            shared_layer_norm=shared_layer_norm,
             multi_query_attention=multi_query_attention,
+            num_heads_kv=num_heads_kv,
+            head_dim=head_dim,
+            sliding_window=sliding_window,
         )
 
         return cls(encoder, decoder)
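Taken together, the converter can now describe seq2seq models that use rotary embeddings, ALiBi, RMSNorm, gated FFNs, and grouped-query attention. As a rough usage sketch (not part of the patch; the layer counts, head counts, and the `common_spec.Activation.SWISH` value are illustrative assumptions), the extended `from_config()` could be called directly with the keyword arguments added above:

```python
from ctranslate2.specs import common_spec, transformer_spec

# Hypothetical configuration: 6+6 layers, 16 query heads sharing 4 KV heads (GQA),
# rotary embeddings on all head dimensions, RMSNorm, and a SwiGLU-style FFN.
spec = transformer_spec.TransformerSpec.from_config(
    (6, 6),                                   # (encoder layers, decoder layers)
    16,                                       # attention heads
    activation=common_spec.Activation.SWISH,  # assumed mapping for "silu" (see _SUPPORTED_ACTIVATIONS)
    ffn_glu=True,
    rms_norm=True,
    rotary_dim=0,                             # 0 => rotate every head dimension
    rotary_interleave=True,
    num_heads_kv=4,                           # grouped-query attention
)

# Note: passing multi_query_attention=True together with num_heads_kv other than 1
# raises a ValueError in the specs above, since MQA is the num_heads_kv=1 case.
```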