diff --git a/include/ctranslate2/layers/wav2vec2.h b/include/ctranslate2/layers/wav2vec2.h
index 4c25c941a..29dea9783 100644
--- a/include/ctranslate2/layers/wav2vec2.h
+++ b/include/ctranslate2/layers/wav2vec2.h
@@ -5,6 +5,52 @@
 namespace ctranslate2 {
   namespace layers {
 
+    class Wav2Vec2LayerNormConvLayer : public Layer {
+    public:
+      Wav2Vec2LayerNormConvLayer(const models::Model& model,
+                                 const std::string& scope,
+                                 dim_t stride,
+                                 dim_t padding);
+
+      void operator()(const StorageView& input, StorageView& output) const;
+
+      DataType output_type() const override {
+        return _conv.output_type();
+      }
+
+      dim_t output_size() const override {
+        return _conv.output_size();
+      }
+
+    private:
+      dim_t _stride;
+      dim_t _padding;
+      const Conv1D _conv;
+      const LayerNorm _output_norm;
+      const ops::Transpose _transpose;
+      const ops::GELU _gelu;
+    };
+
+    class Wav2Vec2PosConvLayer : public Layer {
+    public:
+      Wav2Vec2PosConvLayer(const models::Model& model, const std::string& scope);
+
+      void operator()(const StorageView& input, StorageView& output) const;
+
+      DataType output_type() const override {
+        return _conv.output_type();
+      }
+
+      dim_t output_size() const override {
+        return _conv.output_size();
+      }
+
+    private:
+      const Conv1D _conv;
+      const ops::Transpose _transpose;
+      const ops::GELU _gelu;
+    };
+
     class Wav2Vec2Encoder : public Layer {
     public:
       Wav2Vec2Encoder(const models::Model& model, const std::string& scope);
@@ -35,12 +81,17 @@ namespace ctranslate2 {
       }
 
     private:
+      const Wav2Vec2LayerNormConvLayer _feat_layer0;
+      const std::vector<std::unique_ptr<const Wav2Vec2LayerNormConvLayer>> _feat_layers;
+      const LayerNorm _fp_norm;
+      const Dense _fp_ff;
+      const Wav2Vec2PosConvLayer _pos_conv_embed;
+      const ops::Transpose _transpose;
       const ops::GELU _gelu;
-      // wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported
-      //const ops::Transpose _transpose;
       const dim_t _num_heads;
       const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
       const LayerNorm _output_norm;
+      const Dense _lm_head;
     };
 
   }
diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index a6985b9d1..d98c65860 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -992,9 +992,8 @@ def architecture_name(self):
         return "Wav2Vec2ForCTC"
 
     def get_model_spec(self, model):
-        # Wav2Vec2 encoder Wav2Vec2PositionalConvEmbedding conv1d has groups 16
-        # that doesn't look available here so we make Wav2Vec2 encoder layers only
         spec = wav2vec2_spec.Wav2Vec2Spec(
+            model.wav2vec2.config.num_feat_extract_layers,
             model.wav2vec2.encoder.config.num_hidden_layers,
             model.wav2vec2.encoder.config.num_attention_heads,
         )
@@ -1007,9 +1006,7 @@ def get_model_spec(self, model):
             layer.fc1 = layer.feed_forward.intermediate_dense
             layer.fc2 = layer.feed_forward.output_dense
 
-        self.set_encoder(spec.encoder, model.wav2vec2.encoder)
-        self.set_linear(spec.lm_head, model.lm_head)
-        # only for Wav2Vec2Spec.get_vocabulary_size()
+        self.set_encoder(spec.encoder, model, model.wav2vec2.config)
         return spec
 
     def set_config(self, config, model, tokenizer):
@@ -1021,8 +1018,42 @@ def get_vocabulary(self, model, tokenizer):
     def set_vocabulary(self, spec, tokens):
         spec.register_vocabulary(tokens)
 
-    def set_encoder(self, spec, encoder):
-        super().set_encoder(spec, encoder)
+    def set_feature_extractor(self, spec, feature_extractor):
+        spec.feat_layer0.conv.weight = feature_extractor.conv_layers[0].conv.weight
+        spec.feat_layer0.conv.bias = feature_extractor.conv_layers[0].conv.bias
+        self.set_layer_norm(
+            spec.feat_layer0.layer_norm, feature_extractor.conv_layers[0].layer_norm
+        )
+        for spec_layer, module_layer in zip(
+            spec.feat_layer, feature_extractor.conv_layers[1:]
+        ):
+            spec_layer.conv.weight = module_layer.conv.weight
+            spec_layer.conv.bias = module_layer.conv.bias
+            self.set_layer_norm(spec_layer.layer_norm, module_layer.layer_norm)
+
+    def set_feature_projection(self, spec, feature_projection):
+        self.set_layer_norm(spec.fp_layer_norm, feature_projection.layer_norm)
+        self.set_linear(spec.fp_projection, feature_projection.projection)
+
+    def set_pos_conv_embed(self, spec, encoder, config):
+        # Force the conv parameters to be set, since some transformers versions leave
+        # garbage values in them; they may be float16, so cast to float32 before loading.
+        encoder.pos_conv_embed.conv.weight.data = (
+            encoder.pos_conv_embed.conv.weight.data.float()
+        )
+        encoder.pos_conv_embed.conv.bias.data = encoder.pos_conv_embed.conv.bias.float()
+        for param in encoder.pos_conv_embed.parameters():
+            param.data = param.data.float()
+        encoder.pos_conv_embed(torch.randn((1, 1, config.hidden_size)))
+        spec.pos_conv_embed.conv.weight = encoder.pos_conv_embed.conv.weight
+        spec.pos_conv_embed.conv.bias = encoder.pos_conv_embed.conv.bias
+
+    def set_encoder(self, spec, model, config):
+        self.set_feature_extractor(spec, model.wav2vec2.feature_extractor)
+        self.set_feature_projection(spec, model.wav2vec2.feature_projection)
+        self.set_pos_conv_embed(spec, model.wav2vec2.encoder, config)
+        super().set_encoder(spec, model.wav2vec2.encoder)
+        self.set_linear(spec.lm_head, model.lm_head)
 
     def set_common_layers(self, spec, module):
         self.set_layer_norm(spec.layer_norm, module.layer_norm)
diff --git a/python/ctranslate2/specs/wav2vec2_spec.py b/python/ctranslate2/specs/wav2vec2_spec.py
index 78b2ffa84..7b9b9cfe4 100644
--- a/python/ctranslate2/specs/wav2vec2_spec.py
+++ b/python/ctranslate2/specs/wav2vec2_spec.py
@@ -13,10 +13,9 @@ def __init__(self):
 
 
 class Wav2Vec2Spec(model_spec.LanguageModelSpec):
-    def __init__(self, num_layers, num_heads):
+    def __init__(self, feat_layers, num_layers, num_heads):
         super().__init__()
-        self.encoder = Wav2Vec2EncoderSpec(num_layers, num_heads)
-        self.lm_head = common_spec.LinearSpec()
+        self.encoder = Wav2Vec2EncoderSpec(feat_layers, num_layers, num_heads)
 
     @property
     def name(self):
@@ -30,14 +29,30 @@ def get_default_config(self):
         return Wav2Vec2Config()
 
     def get_vocabulary_size(self):
-        return self.lm_head.weight.shape[0]
+        return self.encoder.lm_head.weight.shape[0]
+
+
+class Wav2Vec2LayerNormConvLayer(model_spec.LayerSpec):
+    def __init__(self):
+        self.conv = common_spec.Conv1DSpec()
+        self.layer_norm = common_spec.LayerNormSpec()
+
+
+class Wav2Vec2PosEmbedConvLayer(model_spec.LayerSpec):
+    def __init__(self):
+        self.conv = common_spec.Conv1DSpec()
 
 
 class Wav2Vec2EncoderSpec(model_spec.LayerSpec):
-    def __init__(self, num_layers, num_heads):
+    def __init__(self, feat_layers, num_layers, num_heads):
         self.num_heads = np.dtype("int16").type(num_heads)
-        # wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported
+        self.feat_layer0 = Wav2Vec2LayerNormConvLayer()
+        self.feat_layer = [Wav2Vec2LayerNormConvLayer() for _ in range(feat_layers - 1)]
+        self.fp_layer_norm = common_spec.LayerNormSpec()
+        self.fp_projection = common_spec.LinearSpec()
+        self.pos_conv_embed = Wav2Vec2PosEmbedConvLayer()
         self.layer_norm = common_spec.LayerNormSpec()
         self.layer = [
             transformer_spec.TransformerEncoderLayerSpec() for _ in range(num_layers)
         ]
+        self.lm_head = common_spec.LinearSpec()
diff --git a/python/tests/test_transformers.py b/python/tests/test_transformers.py
index f27bd6ca2..3c35445fa 100644
--- a/python/tests/test_transformers.py
+++ b/python/tests/test_transformers.py
@@ -979,24 +979,16 @@ def test_transformers_wav2vec2(
     )
     output_dir = str(tmp_dir.join("ctranslate2_model"))
    output_dir = converter.convert(output_dir)
-    # 24 x Wav2Vec2EncoderLayerStableLayerNorm converted & saved
-    w2v2_model = transformers.Wav2Vec2ForCTC.from_pretrained(model_name)
-    del w2v2_model.wav2vec2.encoder.layers
-    del w2v2_model.wav2vec2.encoder.layer_norm
-    w2v2_model.save_pretrained(output_dir + "/wav2vec2_partial.bin")
 
     w2v2_processor = transformers.Wav2Vec2Processor.from_pretrained(model_name)
-    torch.save(w2v2_processor, output_dir + "/wav2vec2_processor.bin")
+    w2v2_processor.save_pretrained(output_dir + "/wav2vec2_processor")
+    processor = transformers.AutoProcessor.from_pretrained(
+        output_dir + "/wav2vec2_processor"
+    )
 
     device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
     cpu_threads = int(os.environ.get("OMP_NUM_THREADS", 0))
-    w2v2_model = transformers.Wav2Vec2ForCTC.from_pretrained(
-        output_dir + "/wav2vec2_partial.bin"
-    ).to(device)
-    del w2v2_model.wav2vec2.encoder.layers
-    del w2v2_model.wav2vec2.encoder.layer_norm
-    w2v2_processor = torch.load(output_dir + "/wav2vec2_processor.bin")
-    ct2_w2v2_model = ctranslate2.models.Wav2Vec2(
+    model = ctranslate2.models.Wav2Vec2(
        output_dir,
        device=device,
        device_index=[0],
@@ -1008,73 +1000,26 @@
     speech_array = np.load(
         os.path.join(test_utils.get_data_dir(), "audio", "mr_quilter.npy")
     )
-    input_values = w2v2_processor(
+    input_values = processor(
         speech_array,
         padding=True,
         return_tensors="pt",
         sampling_rate=16000,
     ).input_values
-    with torch.no_grad():
-        extract_features = w2v2_model.wav2vec2.feature_extractor(
-            input_values.to(w2v2_model.device)
-        ).transpose(1, 2)
-        hidden_states, extract_features = w2v2_model.wav2vec2.feature_projection(
-            extract_features
-        )
-        position_embeddings = w2v2_model.wav2vec2.encoder.pos_conv_embed(
-            hidden_states
-        )
-        hidden_states = position_embeddings + hidden_states
-        # hidden_states = w2v2_model.encoder.dropout(hidden_states)
-        # Dropout(p=0.0, inplace=False) bypassed
-
-    if ct2_w2v2_model.device == "cuda":
-        hidden_states = hidden_states.cpu()
-    else:
-        hidden_states.numpy()
-
-    hidden_states = np.ascontiguousarray(hidden_states)
+    hidden_states = np.ascontiguousarray(input_values.unsqueeze(0))
     hidden_states = ctranslate2.StorageView.from_array(hidden_states)
-    to_cpu = (
-        ct2_w2v2_model.device == "cuda" and len(ct2_w2v2_model.device_index) > 1
-    )
-    ct2_output = ct2_w2v2_model.encode(
-        hidden_states,
-        to_cpu=to_cpu,
-    )  # 24 x Wav2Vec2EncoderLayerStableLayerNorm processed
-    if ct2_w2v2_model.device == "cuda":
-        hidden_states = torch.as_tensor(
-            ct2_output,
-            device=ct2_w2v2_model.device,
-        )
+    to_cpu = model.device == "cuda" and len(model.device_index) > 1
+    output = model.encode(hidden_states, to_cpu=to_cpu)
+    if model.device == "cuda":
+        logits = torch.as_tensor(output, device=model.device)[0]
     else:
-        hidden_states = torch.as_tensor(
-            np.array(ct2_output),
-            dtype=torch.float32,
-            device=ct2_w2v2_model.device,
-        )
-
-    encoder_outputs = transformers.modeling_outputs.BaseModelOutput(
-        last_hidden_state=hidden_states,
-        hidden_states=None,
-        attentions=None,
-    )
-    hidden_states = encoder_outputs[0]
-    outputs = transformers.modeling_outputs.Wav2Vec2BaseModelOutput(
-        last_hidden_state=hidden_states,
-        extract_features=extract_features,
-        hidden_states=encoder_outputs.hidden_states,
-        attentions=encoder_outputs.attentions,
-    )
-    hidden_states = outputs[0]
-    # hidden_states = w2v2_model.dropout(hidden_states)
-    # Dropout(p=0.0, inplace=False) bypassed
-
-    with torch.no_grad():
-        logits = w2v2_model.lm_head(hidden_states.to(torch.float32))[0]
+        logits = torch.as_tensor(
+            np.array(output), dtype=torch.float32, device=model.device
+        )[0]
 
     predicted_ids = torch.argmax(logits, dim=-1)
-    transcription = w2v2_processor.decode(predicted_ids, output_word_offsets=True)
+    transcription = processor.decode(predicted_ids, output_word_offsets=True)
+    transcription = transcription[0].replace(processor.tokenizer.unk_token, "")
 
-    assert transcription[0] == expected_transcription[0]
+    assert transcription == expected_transcription[0]
 
diff --git a/src/layers/wav2vec2.cc b/src/layers/wav2vec2.cc
index 237c77fad..defbf0d84 100644
--- a/src/layers/wav2vec2.cc
+++ b/src/layers/wav2vec2.cc
@@ -3,14 +3,66 @@
 namespace ctranslate2 {
   namespace layers {
 
+
+    Wav2Vec2LayerNormConvLayer::Wav2Vec2LayerNormConvLayer(const models::Model& model,
+                                                           const std::string& scope,
+                                                           dim_t stride,
+                                                           dim_t padding)
+      : _stride(stride)
+      , _padding(padding)
+      , _conv(model, scope + "/conv", _stride, _padding)
+      , _output_norm(model, scope + "/layer_norm")
+      , _transpose({0, 2, 1}) {
+    }
+
+    void Wav2Vec2LayerNormConvLayer::operator()(const StorageView& input, StorageView& output) const {
+      PROFILE("Wav2Vec2LayerNormConvLayer");
+
+      StorageView buffer(input.dtype(), input.device());
+      buffer = input;
+      _conv(buffer, output);
+      _transpose(output, buffer);
+      _output_norm(buffer, output);
+      _transpose(output, buffer);
+      _gelu(buffer, output);
+    }
+
+    Wav2Vec2PosConvLayer::Wav2Vec2PosConvLayer(const models::Model& model, const std::string& scope)
+      : _conv(model, scope + "/conv", /*stride=*/1, /*padding=*/64, /*dilation=*/1, /*groups=*/16)
+      , _transpose({0, 2, 1}) {
+    }
+
+    void Wav2Vec2PosConvLayer::operator()(const StorageView& input, StorageView& output) const {
+      PROFILE("Wav2Vec2PosConvLayer");
+
+      StorageView buffer(input.dtype(), input.device());
+      StorageView buffer2(input.dtype(), input.device());
+      _transpose(input, buffer);
+      _conv(buffer, buffer2);
+      ops::Split(2, {buffer.dim(2), 1})(buffer2, buffer, output);
+      _gelu(buffer, buffer);
+      _transpose(buffer, buffer2);
+      ops::Add()(input, buffer2, output);
+    }
+
     Wav2Vec2Encoder::Wav2Vec2Encoder(const models::Model& model, const std::string& scope)
-      : _num_heads(model.get_attribute_with_default<int32_t>(scope + "/num_heads", 8))
+      : _feat_layer0(model, scope + "/feat_layer0", /*stride=*/5, /*padding=*/0)
+      , _feat_layers(build_layers_list<const Wav2Vec2LayerNormConvLayer>(model,
+                                                                         scope + "/feat_layer",
+                                                                         /*stride=*/2,
+                                                                         /*padding=*/0))
+      , _fp_norm(model, scope + "/fp_layer_norm")
+      , _fp_ff(model, scope + "/fp_projection", nullptr, true)
+      , _pos_conv_embed(model, scope + "/pos_conv_embed")
+      , _transpose({0, 2, 1})
+      , _num_heads(model.get_attribute_with_default<int32_t>(scope + "/num_heads", 8))
       , _layers(build_layers_list<const TransformerEncoderLayer>(model,
                                                                  scope + "/layer",
                                                                  _num_heads,
                                                                  /*pre_norm=*/true,
                                                                  ops::ActivationType::GELU))
       , _output_norm(model, scope + "/layer_norm")
+      , _lm_head(model, scope + "/lm_head", nullptr, true)
     {
     }
 
@@ -18,40 +70,37 @@ namespace ctranslate2 {
       PROFILE("Wav2Vec2Encoder");
 
       // SAD in front-end handles the input length
-      //const dim_t expected_depth = 1024;
-      //const dim_t expected_time = 406;
-
       if (features.rank() != 3)
         throw std::invalid_argument("Expected input features to have 3 dimensions, but got "
                                     + std::to_string(features.rank())
                                     + " dimension(s) instead");
-      /* //may need to limit the input lenght
-      if (features.dim(1) != expected_depth || features.dim(2) != expected_time)
-        throw std::invalid_argument("Invalid input features shape: expected an input with shape ("
-                                    + std::to_string(features.dim(0))
-                                    + ", "
-                                    + std::to_string(expected_depth)
-                                    + ", "
-                                    + std::to_string(expected_time)
-                                    + "), but got an input with shape ("
-                                    + std::to_string(features.dim(0))
-                                    + ", "
-                                    + std::to_string(features.dim(1))
-                                    + ", "
-                                    + std::to_string(features.dim(2))
-                                    + ") instead;; _conv1.output_size() "
-                                    + std::to_string(_conv1.output_size()));
-        //+ ") instead");
-      */
-
-      StorageView input(output_type(), features.device());
-      input = features;
+
+      // Wav2Vec2FeatureExtractor------------------------------------
+      StorageView feat_buffer(features.dtype(), features.device());
+      StorageView feat_buffer2(features.dtype(), features.device());
+      feat_buffer = features;
+      _feat_layer0(feat_buffer, output);
+      feat_buffer = std::move(output);
+      for (dim_t l = 0; l < _feat_layers.size(); l++) {
+        (*_feat_layers[l])(feat_buffer, output);
+        if (l < _feat_layers.size() - 1) {
+          feat_buffer = std::move(output);
+        }
+      }
+      _transpose(output, feat_buffer);
+      // Wav2Vec2FeatureProjection-----------------------------------
+      _fp_norm(feat_buffer, output);
+      _fp_ff(output, feat_buffer);
+      // Wav2Vec2PositionalConvEmbedding-----------------------------
+      _pos_conv_embed(feat_buffer, feat_buffer2);
+      // Wav2Vec2EncoderLayerStableLayerNorm-------------------------
       for (const auto& layer : _layers) {
-        (*layer)(input, nullptr, output);
-        input = std::move(output);
+        (*layer)(feat_buffer2, nullptr, feat_buffer);
+        feat_buffer2 = std::move(feat_buffer);
       }
+      _output_norm(feat_buffer2, feat_buffer);
 
-      _output_norm(input, output);
+      _lm_head(feat_buffer, output);
     }
 
   }
diff --git a/src/models/wav2vec2.cc b/src/models/wav2vec2.cc
index 79a7a40d4..7309f6eb6 100644
--- a/src/models/wav2vec2.cc
+++ b/src/models/wav2vec2.cc
@@ -35,8 +35,7 @@ namespace ctranslate2 {
     }
 
     bool Wav2Vec2Model::is_quantizable(const std::string& variable_name) const {
-      return (Model::is_quantizable(variable_name)
-              && variable_name.find("conv") == std::string::npos);
+      return Model::is_quantizable(variable_name);
     }
 
     bool Wav2Vec2Model::is_linear_weight(const std::string& variable_name) const {
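For reference, a minimal usage sketch of a model converted with this change: the encoder now consumes the raw waveform and returns CTC logits directly. The checkpoint name, output path and audio file below are illustrative; the calls mirror the updated test.

    import numpy as np
    import torch
    import transformers
    import ctranslate2

    model_name = "facebook/wav2vec2-large-robust-ft-swbd-300h"  # illustrative checkpoint

    # Convert the full Wav2Vec2ForCTC model: feature extractor, feature projection,
    # positional conv embedding, transformer layers and lm_head are all in the spec now.
    converter = ctranslate2.converters.TransformersConverter(model_name)
    output_dir = converter.convert("wav2vec2_ct2")  # illustrative output directory

    processor = transformers.Wav2Vec2Processor.from_pretrained(model_name)
    model = ctranslate2.models.Wav2Vec2(output_dir, device="cpu")

    speech_array = np.load("mr_quilter.npy")  # illustrative 16 kHz mono waveform
    input_values = processor(
        speech_array, padding=True, return_tensors="pt", sampling_rate=16000
    ).input_values

    # The encoder expects a (batch, 1, samples) float array and returns CTC logits.
    features = ctranslate2.StorageView.from_array(
        np.ascontiguousarray(input_values.unsqueeze(0))
    )
    logits = torch.as_tensor(np.array(model.encode(features)), dtype=torch.float32)[0]
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids, output_word_offsets=True)
    print(transcription[0].replace(processor.tokenizer.unk_token, ""))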