Wav2Vec2 upgrade with Conv1D options #1758

Merged
merged 15 commits on Aug 19, 2024
55 changes: 53 additions & 2 deletions include/ctranslate2/layers/wav2vec2.h
@@ -5,6 +5,52 @@
namespace ctranslate2 {
namespace layers {

class Wav2Vec2LayerNormConvLayer : public Layer {
public:
Wav2Vec2LayerNormConvLayer(const models::Model& model,
const std::string& scope,
dim_t stride,
dim_t padding);

void operator()(const StorageView& input, StorageView& output) const;

DataType output_type() const override {
return _conv.output_type();
}

dim_t output_size() const override {
return _conv.output_size();
}

private:
dim_t _stride;
dim_t _padding;
const Conv1D _conv;
const LayerNorm _output_norm;
const ops::Transpose _transpose;
const ops::GELU _gelu;
};

class Wav2Vec2PosConvLayer : public Layer {
public:
Wav2Vec2PosConvLayer(const models::Model& model, const std::string& scope);

void operator()(const StorageView& input, StorageView& output) const;

DataType output_type() const override {
return _conv.output_type();
}

dim_t output_size() const override {
return _conv.output_size();
}

private:
const Conv1D _conv;
const ops::Transpose _transpose;
const ops::GELU _gelu;
};

class Wav2Vec2Encoder : public Layer {
public:
Wav2Vec2Encoder(const models::Model& model, const std::string& scope);
@@ -35,12 +81,17 @@ namespace ctranslate2 {
}

private:
const Wav2Vec2LayerNormConvLayer _feat_layer0;
const std::vector<std::unique_ptr<const Wav2Vec2LayerNormConvLayer>> _feat_layers;
const LayerNorm _fp_norm;
const Dense _fp_ff;
const Wav2Vec2PosConvLayer _pos_conv_embed;
const ops::Transpose _transpose;
const ops::GELU _gelu;
// wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported
//const ops::Transpose _transpose;
const dim_t _num_heads;
const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
const LayerNorm _output_norm;
const Dense _lm_head;
};

}
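For orientation, here is the forward computation these new layers appear to implement, mirrored from the Hugging Face reference modules (Wav2Vec2LayerNormConvLayer and Wav2Vec2PositionalConvEmbedding): a strided Conv1D followed by a LayerNorm over the channel dimension and a GELU, plus a grouped positional Conv1D whose groups=16 setting is exactly what the removed comment said was previously unsupported. A minimal PyTorch sketch with illustrative sizes, not the literal C++ implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNormConvBlock(nn.Module):
    # Sketch of what Wav2Vec2LayerNormConvLayer presumably computes.
    def __init__(self, in_dim=512, out_dim=512, kernel=3, stride=2):
        super().__init__()
        self.conv = nn.Conv1d(in_dim, out_dim, kernel, stride=stride)
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x):            # x: (batch, channels, time)
        x = self.conv(x)
        x = x.transpose(-2, -1)      # LayerNorm runs over the channel dim
        x = self.norm(x)
        x = x.transpose(-2, -1)
        return F.gelu(x)

class PosConvBlock(nn.Module):
    # Sketch of what Wav2Vec2PosConvLayer presumably computes; groups=16
    # is the configuration the old comment flagged as unsupported.
    def __init__(self, hidden=1024, kernel=128, groups=16):
        super().__init__()
        self.conv = nn.Conv1d(hidden, hidden, kernel,
                              padding=kernel // 2, groups=groups)

    def forward(self, x):            # x: (batch, time, hidden)
        y = self.conv(x.transpose(1, 2))
        y = y[:, :, :-1]             # even kernel: drop one padded frame
        return F.gelu(y).transpose(1, 2)

A (batch, time, hidden) tensor keeps its shape through PosConvBlock, so the result can be added to the hidden states as a relative positional embedding.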
45 changes: 38 additions & 7 deletions python/ctranslate2/converters/transformers.py
@@ -992,9 +992,8 @@ def architecture_name(self):
return "Wav2Vec2ForCTC"

def get_model_spec(self, model):
# Wav2Vec2 encoder Wav2Vec2PositionalConvEmbedding conv1d has groups 16
# that doesn't look available here so we make Wav2Vec2 encoder layers only
spec = wav2vec2_spec.Wav2Vec2Spec(
model.wav2vec2.config.num_feat_extract_layers,
model.wav2vec2.encoder.config.num_hidden_layers,
model.wav2vec2.encoder.config.num_attention_heads,
)
@@ -1007,9 +1006,7 @@ def get_model_spec(self, model):
layer.fc1 = layer.feed_forward.intermediate_dense
layer.fc2 = layer.feed_forward.output_dense

self.set_encoder(spec.encoder, model.wav2vec2.encoder)
self.set_linear(spec.lm_head, model.lm_head)
# only for Wav2Vec2Spec.get_vocabulary_size()
self.set_encoder(spec.encoder, model, model.wav2vec2.config)
return spec

def set_config(self, config, model, tokenizer):
@@ -1021,8 +1018,42 @@ def get_vocabulary(self, model, tokenizer):
def set_vocabulary(self, spec, tokens):
spec.register_vocabulary(tokens)

def set_encoder(self, spec, encoder):
super().set_encoder(spec, encoder)
def set_feature_extractor(self, spec, feature_extractor):
spec.feat_layer0.conv.weight = feature_extractor.conv_layers[0].conv.weight
spec.feat_layer0.conv.bias = feature_extractor.conv_layers[0].conv.bias
self.set_layer_norm(
spec.feat_layer0.layer_norm, feature_extractor.conv_layers[0].layer_norm
)
for spec_layer, module_layer in zip(
spec.feat_layer, feature_extractor.conv_layers[1:]
):
spec_layer.conv.weight = module_layer.conv.weight
spec_layer.conv.bias = module_layer.conv.bias
self.set_layer_norm(spec_layer.layer_norm, module_layer.layer_norm)

def set_feature_projection(self, spec, feature_projection):
self.set_layer_norm(spec.fp_layer_norm, feature_projection.layer_norm)
self.set_linear(spec.fp_projection, feature_projection.projection)

def set_pos_conv_embed(self, spec, encoder, config):
# force the parameters to be set, because some transformers versions
# initialize them with garbage values; the conv parameters are float16,
# so cast them to float32 for loading
encoder.pos_conv_embed.conv.weight.data = (
encoder.pos_conv_embed.conv.weight.data.float()
)
encoder.pos_conv_embed.conv.bias.data = encoder.pos_conv_embed.conv.bias.float()
for param in encoder.pos_conv_embed.parameters():
param.data = param.data.float()
encoder.pos_conv_embed(torch.randn((1, 1, config.hidden_size)))
spec.pos_conv_embed.conv.weight = encoder.pos_conv_embed.conv.weight
spec.pos_conv_embed.conv.bias = encoder.pos_conv_embed.conv.bias

def set_encoder(self, spec, model, config):
self.set_feature_extractor(spec, model.wav2vec2.feature_extractor)
self.set_feature_projection(spec, model.wav2vec2.feature_projection)
self.set_pos_conv_embed(spec, model.wav2vec2.encoder, config)
super().set_encoder(spec, model.wav2vec2.encoder)
self.set_linear(spec.lm_head, model.lm_head)

def set_common_layers(self, spec, module):
self.set_layer_norm(spec.layer_norm, module.layer_norm)
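A note on the dummy forward in set_pos_conv_embed: in the Hugging Face implementation the positional convolution is wrapped in torch.nn.utils.weight_norm, so conv.weight is a derived tensor recomputed from weight_g and weight_v when the module runs; one forward pass refreshes it before it is copied into the spec. A minimal sketch of that behaviour (sizes are illustrative, not the real wav2vec2 dimensions):

import torch
import torch.nn as nn

conv = nn.Conv1d(16, 16, kernel_size=4, padding=2, groups=4)
conv = nn.utils.weight_norm(conv, name="weight", dim=2)  # dim=2 mirrors transformers

# weight_norm recomputes conv.weight from weight_g/weight_v in a forward
# pre-hook, so a single dummy forward guarantees the materialized weight
# is consistent before copying it out.
_ = conv(torch.randn(1, 16, 8))
weight = conv.weight.detach().float()  # safe to assign to the spec now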
27 changes: 21 additions & 6 deletions python/ctranslate2/specs/wav2vec2_spec.py
@@ -13,10 +13,9 @@ def __init__(self):


class Wav2Vec2Spec(model_spec.LanguageModelSpec):
def __init__(self, num_layers, num_heads):
def __init__(self, feat_layers, num_layers, num_heads):
super().__init__()
self.encoder = Wav2Vec2EncoderSpec(num_layers, num_heads)
self.lm_head = common_spec.LinearSpec()
self.encoder = Wav2Vec2EncoderSpec(feat_layers, num_layers, num_heads)

@property
def name(self):
@@ -30,14 +29,30 @@ def get_default_config(self):
return Wav2Vec2Config()

def get_vocabulary_size(self):
return self.lm_head.weight.shape[0]
return self.encoder.lm_head.weight.shape[0]


class Wav2Vec2LayerNormConvLayer(model_spec.LayerSpec):
def __init__(self):
self.conv = common_spec.Conv1DSpec()
self.layer_norm = common_spec.LayerNormSpec()


class Wav2Vec2PosEmbedConvLayer(model_spec.LayerSpec):
def __init__(self):
self.conv = common_spec.Conv1DSpec()


class Wav2Vec2EncoderSpec(model_spec.LayerSpec):
def __init__(self, num_layers, num_heads):
def __init__(self, feat_layers, num_layers, num_heads):
self.num_heads = np.dtype("int16").type(num_heads)
# wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported
self.feat_layer0 = Wav2Vec2LayerNormConvLayer()
self.feat_layer = [Wav2Vec2LayerNormConvLayer() for _ in range(feat_layers - 1)]
self.fp_layer_norm = common_spec.LayerNormSpec()
self.fp_projection = common_spec.LinearSpec()
self.pos_conv_embed = Wav2Vec2PosEmbedConvLayer()
self.layer_norm = common_spec.LayerNormSpec()
self.layer = [
transformer_spec.TransformerEncoderLayerSpec() for _ in range(num_layers)
]
self.lm_head = common_spec.LinearSpec()
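
For reference, converting a large wav2vec2 checkpoint would now presumably instantiate the spec along these lines; 7 feature-extractor layers, 24 encoder layers and 16 heads are the values used by facebook/wav2vec2-large-960h-lv60-self, so treat them as illustrative:

from ctranslate2.specs import wav2vec2_spec

spec = wav2vec2_spec.Wav2Vec2Spec(7, 24, 16)
# feat_layer0 holds the first conv layer, feat_layer the remaining six
assert len(spec.encoder.feat_layer) == 6
assert len(spec.encoder.layer) == 24
assert spec.encoder.num_heads == 16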
89 changes: 17 additions & 72 deletions python/tests/test_transformers.py
@@ -979,24 +979,16 @@ def test_transformers_wav2vec2(
)
output_dir = str(tmp_dir.join("ctranslate2_model"))
output_dir = converter.convert(output_dir)
# 24 x Wav2Vec2EncoderLayerStableLayerNorm converted & saved

w2v2_model = transformers.Wav2Vec2ForCTC.from_pretrained(model_name)
del w2v2_model.wav2vec2.encoder.layers
del w2v2_model.wav2vec2.encoder.layer_norm
w2v2_model.save_pretrained(output_dir + "/wav2vec2_partial.bin")
w2v2_processor = transformers.Wav2Vec2Processor.from_pretrained(model_name)
torch.save(w2v2_processor, output_dir + "/wav2vec2_processor.bin")
w2v2_processor.save_pretrained(output_dir + "/wav2vec2_processor")
processor = transformers.AutoProcessor.from_pretrained(
output_dir + "/wav2vec2_processor"
)

device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
cpu_threads = int(os.environ.get("OMP_NUM_THREADS", 0))
w2v2_model = transformers.Wav2Vec2ForCTC.from_pretrained(
output_dir + "/wav2vec2_partial.bin"
).to(device)
del w2v2_model.wav2vec2.encoder.layers
del w2v2_model.wav2vec2.encoder.layer_norm
w2v2_processor = torch.load(output_dir + "/wav2vec2_processor.bin")
ct2_w2v2_model = ctranslate2.models.Wav2Vec2(
model = ctranslate2.models.Wav2Vec2(
output_dir,
device=device,
device_index=[0],
@@ -1008,73 +1000,26 @@
speech_array = np.load(
os.path.join(test_utils.get_data_dir(), "audio", "mr_quilter.npy")
)
input_values = w2v2_processor(
input_values = processor(
speech_array,
padding=True,
return_tensors="pt",
sampling_rate=16000,
).input_values

with torch.no_grad():
extract_features = w2v2_model.wav2vec2.feature_extractor(
input_values.to(w2v2_model.device)
).transpose(1, 2)
hidden_states, extract_features = w2v2_model.wav2vec2.feature_projection(
extract_features
)
position_embeddings = w2v2_model.wav2vec2.encoder.pos_conv_embed(
hidden_states
)
hidden_states = position_embeddings + hidden_states
# hidden_states = w2v2_model.encoder.dropout(hidden_states)
# Dropout(p=0.0, inplace=False) bypassed

if ct2_w2v2_model.device == "cuda":
hidden_states = hidden_states.cpu()
else:
hidden_states.numpy()

hidden_states = np.ascontiguousarray(hidden_states)
hidden_states = np.ascontiguousarray(input_values.unsqueeze(0))
hidden_states = ctranslate2.StorageView.from_array(hidden_states)
to_cpu = (
ct2_w2v2_model.device == "cuda" and len(ct2_w2v2_model.device_index) > 1
)
ct2_output = ct2_w2v2_model.encode(
hidden_states,
to_cpu=to_cpu,
) # 24 x Wav2Vec2EncoderLayerStableLayerNorm processed
if ct2_w2v2_model.device == "cuda":
hidden_states = torch.as_tensor(
ct2_output,
device=ct2_w2v2_model.device,
)
to_cpu = model.device == "cuda" and len(model.device_index) > 1
output = model.encode(hidden_states, to_cpu=to_cpu)
if model.device == "cuda":
logits = torch.as_tensor(output, device=model.device)[0]
else:
hidden_states = torch.as_tensor(
np.array(ct2_output),
dtype=torch.float32,
device=ct2_w2v2_model.device,
)

encoder_outputs = transformers.modeling_outputs.BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=None,
attentions=None,
)
hidden_states = encoder_outputs[0]
outputs = transformers.modeling_outputs.Wav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
hidden_states = outputs[0]
# hidden_states = w2v2_model.dropout(hidden_states)
# Dropout(p=0.0, inplace=False) bypassed

with torch.no_grad():
logits = w2v2_model.lm_head(hidden_states.to(torch.float32))[0]
logits = torch.as_tensor(
np.array(output), dtype=torch.float32, device=model.device
)[0]

predicted_ids = torch.argmax(logits, dim=-1)
transcription = w2v2_processor.decode(predicted_ids, output_word_offsets=True)
transcription = processor.decode(predicted_ids, output_word_offsets=True)
transcription = transcription[0].replace(processor.tokenizer.unk_token, "")

assert transcription[0] == expected_transcription[0]
assert transcription == expected_transcription[0]
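
With the feature extractor, feature projection, positional conv embedding and lm_head all inside the converted model, end-to-end inference reduces to the shape of the rewritten test. A minimal sketch, where the checkpoint name, output directory and audio file are placeholders:

import ctranslate2
import numpy as np
import torch
import transformers

processor = transformers.AutoProcessor.from_pretrained(
    "facebook/wav2vec2-large-960h-lv60-self"  # placeholder checkpoint
)
model = ctranslate2.models.Wav2Vec2("ctranslate2_model")  # converted dir

speech = np.load("audio.npy")  # hypothetical 16 kHz mono waveform
inputs = processor(
    speech, padding=True, return_tensors="pt", sampling_rate=16000
).input_values
features = ctranslate2.StorageView.from_array(
    np.ascontiguousarray(inputs.unsqueeze(0))
)
logits = torch.as_tensor(np.array(model.encode(features)), dtype=torch.float32)[0]
predicted_ids = torch.argmax(logits, dim=-1)
text = processor.decode(predicted_ids, output_word_offsets=True)[0]
text = text.replace(processor.tokenizer.unk_token, "")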