From 1356ceb1726a50d057757256cec9b7f9d703e665 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Thu, 15 Apr 2021 13:40:30 -0500 Subject: [PATCH 01/31] Adds parameter to specify model behaviour. --- tensorflow_asr/models/conformer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tensorflow_asr/models/conformer.py b/tensorflow_asr/models/conformer.py index 0fa3585ce4..978ec0b775 100755 --- a/tensorflow_asr/models/conformer.py +++ b/tensorflow_asr/models/conformer.py @@ -144,6 +144,7 @@ def __init__(self, depth_multiplier=1, kernel_regularizer=L2, bias_regularizer=L2, + streaming=False, name="conv_module", **kwargs): super(ConvModule, self).__init__(name=name, **kwargs) @@ -157,7 +158,8 @@ def __init__(self, self.glu = GLU(name=f"{name}_glu") self.dw_conv = tf.keras.layers.DepthwiseConv2D( kernel_size=(kernel_size, 1), strides=1, - padding="same", name=f"{name}_dw_conv", + padding="same" if not streaming else "causal", + name=f"{name}_dw_conv", depth_multiplier=depth_multiplier, depthwise_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer @@ -218,6 +220,7 @@ def __init__(self, depth_multiplier=1, kernel_regularizer=L2, bias_regularizer=L2, + streaming=False, name="conformer_block", **kwargs): super(ConformerBlock, self).__init__(name=name, **kwargs) @@ -239,7 +242,8 @@ def __init__(self, dropout=dropout, name=f"{name}_conv_module", depth_multiplier=depth_multiplier, kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, + streaming=streaming ) self.ffm2 = FFModule( input_dim=input_dim, dropout=dropout, @@ -287,6 +291,7 @@ def __init__(self, dropout=0.0, kernel_regularizer=L2, bias_regularizer=L2, + streaming=False, name="conformer_encoder", **kwargs): super(ConformerEncoder, self).__init__(name=name, **kwargs) @@ -339,6 +344,7 @@ def __init__(self, depth_multiplier=depth_multiplier, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + streaming=streaming, name=f"{name}_block_{i}" ) self.conformer_blocks.append(conformer_block) From 47eba7da7847463f84f60566cd248f9a203401f3 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Thu, 15 Apr 2021 13:41:22 -0500 Subject: [PATCH 02/31] Initial untested commit. Training script for Streaming Conformer Transducer. --- .../train_streaming_conformer.py | 107 ++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 examples/streaming_conformer/train_streaming_conformer.py diff --git a/examples/streaming_conformer/train_streaming_conformer.py b/examples/streaming_conformer/train_streaming_conformer.py new file mode 100644 index 0000000000..7a51015113 --- /dev/null +++ b/examples/streaming_conformer/train_streaming_conformer.py @@ -0,0 +1,107 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import argparse +from tensorflow_asr.utils import setup_environment, setup_strategy + +setup_environment() +import tensorflow as tf + +DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") + +tf.keras.backend.clear_session() + +parser = argparse.ArgumentParser(prog="Conformer Training") + +parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file") + +parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep") + +parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords") + +parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica") + +parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica") + +parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") + +parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") + +args = parser.parse_args() + +tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp}) + +strategy = setup_strategy(args.devices) + +from tensorflow_asr.configs.config import Config +from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset +from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer +from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer +from tensorflow_asr.runners.transducer_runners import TransducerTrainer +from tensorflow_asr.models.streaming_conformer import StreamingConformer + +config = Config(args.config) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) +text_featurizer = CharFeaturizer(config.decoder_config) + +if args.tfrecords: + train_dataset = ASRTFRecordDataset( + speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + **vars(config.learning_config.train_dataset_config) + ) + eval_dataset = ASRTFRecordDataset( + speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + **vars(config.learning_config.eval_dataset_config) + ) +else: + train_dataset = ASRSliceDataset( + speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + **vars(config.learning_config.train_dataset_config) + ) + eval_dataset = ASRSliceDataset( + speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + **vars(config.learning_config.eval_dataset_config) + ) + +streaming_transducer_trainer = TransducerTrainer( + config=config.learning_config.running_config, + text_featurizer=text_featurizer, strategy=strategy +) + +with streaming_transducer_trainer.strategy.scope(): + # build model + streaming_transducer = StreamingConformer( + **config.model_config, + vocabulary_size=text_featurizer.num_classes + ) + streaming_transducer._build(speech_featurizer.shape) + streaming_transducer.summary(line_length=150) + + optimizer_config = config.learning_config.optimizer_config + optimizer = tf.keras.optimizers.Adam( + TransformerSchedule( + d_model=conformer.dmodel, + warmup_steps=optimizer_config["warmup_steps"], + max_lr=(0.05 / math.sqrt(conformer.dmodel)) + ), + beta_1=optimizer_config["beta1"], + beta_2=optimizer_config["beta2"], + epsilon=optimizer_config["epsilon"] + ) + +streaming_transducer_trainer.compile(model=streaming_transducer, optimizer=optimizer, + max_to_keep=args.max_ckpts) + +streaming_transducer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs) From 
8d70d9ba9958b1db1776b244e741ff933b596df3 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Thu, 15 Apr 2021 14:09:27 -0500 Subject: [PATCH 03/31] Renames variables. --- .../train_streaming_conformer.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/streaming_conformer/train_streaming_conformer.py b/examples/streaming_conformer/train_streaming_conformer.py index 7a51015113..14fea12944 100644 --- a/examples/streaming_conformer/train_streaming_conformer.py +++ b/examples/streaming_conformer/train_streaming_conformer.py @@ -75,19 +75,19 @@ **vars(config.learning_config.eval_dataset_config) ) -streaming_transducer_trainer = TransducerTrainer( +streaming_conformer_trainer = TransducerTrainer( config=config.learning_config.running_config, text_featurizer=text_featurizer, strategy=strategy ) -with streaming_transducer_trainer.strategy.scope(): +with streaming_conformer_trainer.strategy.scope(): # build model - streaming_transducer = StreamingConformer( + streaming_conformer = StreamingConformer( **config.model_config, vocabulary_size=text_featurizer.num_classes ) - streaming_transducer._build(speech_featurizer.shape) - streaming_transducer.summary(line_length=150) + streaming_conformer._build(speech_featurizer.shape) + streaming_conformer.summary(line_length=150) optimizer_config = config.learning_config.optimizer_config optimizer = tf.keras.optimizers.Adam( @@ -101,7 +101,7 @@ epsilon=optimizer_config["epsilon"] ) -streaming_transducer_trainer.compile(model=streaming_transducer, optimizer=optimizer, +streaming_conformer_trainer.compile(model=streaming_conformer, optimizer=optimizer, max_to_keep=args.max_ckpts) -streaming_transducer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs) +streaming_conformer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs) From ac216b2a4c46e6ed20e3593c33a54340a2a19fd2 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Fri, 16 Apr 2021 16:17:43 -0500 Subject: [PATCH 04/31] Adds ASRMaskedSliceDataset, for generating rolling masks for Streaming Conformer. 
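Note: the rolling mask that `_calculate_mask` builds in the diff below can be sanity-checked in isolation. The following standalone NumPy sketch (illustrative only; in the dataset the real `frames_per_chunk` and `history_window_size` are derived from the featurizer and dataset config) reproduces the intended visibility pattern, where each frame attends to every frame of its own chunk plus the full chunks reached by its history window:

    import numpy as np

    def rolling_mask(num_frames, frames_per_chunk, history_window_size):
        # mask[i, j] == 1 iff frame i may attend to frame j
        mask = np.zeros((num_frames, num_frames), dtype=np.int32)
        for i in range(num_frames):
            current_chunk = i // frames_per_chunk
            history_chunk = (i - history_window_size) // frames_per_chunk
            for chunk in range(history_chunk, current_chunk + 1):
                start = max(chunk * frames_per_chunk, 0)
                end = min((chunk + 1) * frames_per_chunk, num_frames)
                mask[i, start:end] = 1
        return mask

    print(rolling_mask(num_frames=6, frames_per_chunk=2, history_window_size=1))
    # [[1 1 0 0 0 0]
    #  [1 1 0 0 0 0]
    #  [1 1 1 1 0 0]
    #  [0 0 1 1 0 0]
    #  [0 0 1 1 1 1]
    #  [0 0 0 0 1 1]]

A frame only drags the previous chunk into view while it sits within `history_window_size` frames of its chunk boundary (rows 2 and 4 above), so the effective attention span stays bounded, which is what makes chunk-wise streaming inference possible.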
--- tensorflow_asr/datasets/asr_dataset.py | 112 +++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index a8d8045680..ce0d2ff170 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -332,3 +332,115 @@ def create(self, batch_size: int): dataset = tf.data.Dataset.from_tensor_slices(self.entries) dataset = dataset.map(self.load, num_parallel_calls=AUTOTUNE) return self.process(dataset, batch_size) + + +class ASRMaskedSliceDataset(ASRSliceDataset): + """ Dataset for ASR with rolling mask """ + + def __init__(self, + stage: str, + speech_featurizer: SpeechFeaturizer, + text_featurizer: TextFeaturizer, + data_paths: list, + augmentations: Augmentation = Augmentation(None), + cache: bool = False, + shuffle: bool = False, + indefinite: bool = False, + drop_remainder: bool = True, + use_tf: bool = False, + buffer_size: int = BUFFER_SIZE, + history_window_size: int = 3, + input_chunk_duration: int = 250, + **kwargs): + super(ASRMaskedSliceDataset, self).__init__( + data_paths=data_paths, augmentations=augmentations, + cache=cache, shuffle=shuffle, stage=stage, buffer_size=buffer_size, + drop_remainder=drop_remainder, use_tf=use_tf, indefinite=indefinite, + speech_featurizer=speech_featurizer, text_featurizer=text_featurizer + ) + self.speech_featurizer = speech_featurizer + self.text_featurizer = text_featurizer + self.history_window_size = history_window_size + self.input_chunk_size = input_chunk_duration * self.speech_featurizer.sample_rate // 1000 + + def calculate_mask(self, num_frames): + frame_step = self.speech_featurizer.frame_step + frames_per_chunk = self.input_chunk_size // frame_step + + time_reduction_factor = 4 # TODO: Get time_reduction_factor from config or model. 
+ num_frames = tf.cast(tf.math.ceil(num_frames / time_reduction_factor), tf.int32) + + def _calculate_mask(num_frames, frames_per_chunk, history_window_size): + mask = np.zeros((num_frames, num_frames), dtype=np.int32) + for i in range(num_frames): + # Frames in the same chunk can see each other + # If frames in `history_window_size` are in other chunks, the full chunks are visible + current_chunk_index = i // frames_per_chunk + history_chunk_index = (i - history_window_size) // frames_per_chunk + for curr in range(history_chunk_index, current_chunk_index + 1): + for j in range(frames_per_chunk): + base_index = curr * frames_per_chunk + if base_index + j < 0 or base_index + j >= num_frames: + continue + mask[i, base_index + j] = 1 + return mask + + return tf.numpy_function( + _calculate_mask, inp=[num_frames, frames_per_chunk, self.history_window_size], Tout=tf.int32 + ) + + def preprocess(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): + preprocessed_inputs = super(ASRMaskedSliceDataset, self).preprocess(path, audio, indices) + + input_length = preprocessed_inputs[2] + mask = self.calculate_mask(input_length) + mask.set_shape((None, None)) + + return (*preprocessed_inputs, mask) + + def tf_preprocess(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): + preprocessed_inputs = super(ASRMaskedSliceDataset, self).tf_preprocess(path, audio, indices) + + input_length = preprocessed_inputs[2] + mask = self.calculate_mask(input_length) + mask.set_shape((None, None)) + + return (*preprocessed_inputs, mask) + + # -------------------------------- CREATION ------------------------------------- + + def process(self, dataset: tf.data.Dataset, batch_size: int): + dataset = dataset.map(self.parse, num_parallel_calls=AUTOTUNE) + + if self.cache: + dataset = dataset.cache() + + if self.shuffle: + dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True) + + if self.indefinite: + dataset = dataset.repeat() + + # PADDED BATCH the dataset + dataset = dataset.padded_batch( + batch_size=batch_size, + padded_shapes=( + tf.TensorShape([]), + tf.TensorShape(self.speech_featurizer.shape), + tf.TensorShape([]), + tf.TensorShape(self.text_featurizer.shape), + tf.TensorShape([]), + tf.TensorShape(self.text_featurizer.prepand_shape), + tf.TensorShape([]), + tf.TensorShape([self.speech_featurizer.shape[0], self.speech_featurizer.shape[0]]) + ), + padding_values=(None, 0., 0, self.text_featurizer.blank, 0, self.text_featurizer.blank, 0, 0), + drop_remainder=self.drop_remainder + ) + + # PREFETCH to improve speed of input length + dataset = dataset.prefetch(AUTOTUNE) + self.total_steps = get_num_batches(self.total_steps, batch_size, drop_remainders=self.drop_remainder) + return dataset + +# TODO: Create masked TFRecords dataset From f4988fb13a5271f52760c750c66efe0177bce309 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Fri, 16 Apr 2021 16:20:45 -0500 Subject: [PATCH 05/31] Changes DepthwiseConv2D for SeparableConv1D. 
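Why the padding switch matters for streaming can be seen with a toy impulse response. A minimal sketch, using a plain Conv1D with an all-ones kernel as a stand-in for the depthwise step: with padding="same" the output at a timestep already depends on future frames, while padding="causal" left-pads by kernel_size - 1 so the receptive field covers only the past.

    import numpy as np
    import tensorflow as tf

    x = np.zeros((1, 8, 1), dtype=np.float32)
    x[0, 3, 0] = 1.0  # impulse at t = 3

    kwargs = dict(filters=1, kernel_size=3, use_bias=False, kernel_initializer="ones")
    same = tf.keras.layers.Conv1D(padding="same", **kwargs)
    causal = tf.keras.layers.Conv1D(padding="causal", **kwargs)

    print(same(x).numpy().ravel())    # [0. 0. 1. 1. 1. 0. 0. 0.] -> t = 2 already sees the future impulse
    print(causal(x).numpy().ravel())  # [0. 0. 0. 1. 1. 1. 0. 0.] -> response starts at t = 3, never earlier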
--- tensorflow_asr/models/conformer.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tensorflow_asr/models/conformer.py b/tensorflow_asr/models/conformer.py index 978ec0b775..d47ffb7968 100755 --- a/tensorflow_asr/models/conformer.py +++ b/tensorflow_asr/models/conformer.py @@ -149,15 +149,16 @@ def __init__(self, **kwargs): super(ConvModule, self).__init__(name=name, **kwargs) self.ln = tf.keras.layers.LayerNormalization() - self.pw_conv_1 = tf.keras.layers.Conv2D( + self.pw_conv_1 = tf.keras.layers.Conv1D( filters=2 * input_dim, kernel_size=1, strides=1, padding="valid", name=f"{name}_pw_conv_1", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer ) self.glu = GLU(name=f"{name}_glu") - self.dw_conv = tf.keras.layers.DepthwiseConv2D( - kernel_size=(kernel_size, 1), strides=1, + self.dw_conv = tf.keras.layers.SeparableConv1D( + filters=input_dim, + kernel_size=(kernel_size), strides=1, padding="same" if not streaming else "causal", name=f"{name}_dw_conv", depth_multiplier=depth_multiplier, @@ -170,7 +171,7 @@ def __init__(self, beta_regularizer=bias_regularizer ) self.swish = tf.keras.layers.Activation(tf.nn.swish, name=f"{name}_swish_activation") - self.pw_conv_2 = tf.keras.layers.Conv2D( + self.pw_conv_2 = tf.keras.layers.Conv1D( filters=input_dim, kernel_size=1, strides=1, padding="valid", name=f"{name}_pw_conv_2", kernel_regularizer=kernel_regularizer, @@ -182,7 +183,6 @@ def __init__(self, def call(self, inputs, training=False, **kwargs): outputs = self.ln(inputs, training=training) B, T, E = shape_list(outputs) - outputs = tf.reshape(outputs, [B, T, 1, E]) outputs = self.pw_conv_1(outputs, training=training) outputs = self.glu(outputs) outputs = self.dw_conv(outputs, training=training) @@ -402,6 +402,7 @@ def __init__(self, joint_trainable: bool = True, kernel_regularizer=L2, bias_regularizer=L2, + streaming=False, name: str = "conformer", **kwargs): super(Conformer, self).__init__( @@ -420,6 +421,7 @@ def __init__(self, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, trainable=encoder_trainable, + streaming=streaming, name=f"{name}_encoder" ), vocabulary_size=vocabulary_size, From a6bc25fc195cf5334da086d7b78af19892aaffeb Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Fri, 16 Apr 2021 16:24:49 -0500 Subject: [PATCH 06/31] Adds StreamingConformer class. Cleanup pending. --- .../models/streaming_conformer_transducer.py | 254 ++++++++++++++++++ 1 file changed, 254 insertions(+) create mode 100644 tensorflow_asr/models/streaming_conformer_transducer.py diff --git a/tensorflow_asr/models/streaming_conformer_transducer.py b/tensorflow_asr/models/streaming_conformer_transducer.py new file mode 100644 index 0000000000..ccc5646974 --- /dev/null +++ b/tensorflow_asr/models/streaming_conformer_transducer.py @@ -0,0 +1,254 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" http://arxiv.org/abs/1811.06621 """ + +import tensorflow as tf + +from .layers.subsampling import TimeReduction +# from .transducer import Transducer +from ..utils.utils import get_rnn, merge_two_last_dims, shape_list +from .conformer import Conformer + +L2 = tf.keras.regularizers.l2(1e-6) + +class StreamingConformer(Conformer): + """ + Attempt at implementing Streaming Conformer Transducer. (see: https://arxiv.org/pdf/2010.11395.pdf). + + Three main differences: + - Inputs are splits into chunks. + - Masking is used for MHSA to select the chunks to be used at each timestep. (Allows for parallel training.) + - Added parameter `streaming` to ConformerEncoder, ConformerBlock and ConvModule. Inside ConvModule, the layer DepthwiseConv2D has padding changed to "causal" when `streaming==True`. + + NOTE: Masking is applied just as regular masking along with the inputs. + """ + def __init__(self, + vocabulary_size: int, + encoder_subsampling: dict, + encoder_positional_encoding: str = "sinusoid", + encoder_dmodel: int = 144, + encoder_num_blocks: int = 16, + encoder_head_size: int = 36, + encoder_num_heads: int = 4, + encoder_mha_type: str = "relmha", + encoder_kernel_size: int = 32, + encoder_depth_multiplier: int = 1, + encoder_fc_factor: float = 0.5, + encoder_dropout: float = 0, + encoder_trainable: bool = True, + prediction_embed_dim: int = 512, + prediction_embed_dropout: int = 0, + prediction_num_rnns: int = 1, + prediction_rnn_units: int = 320, + prediction_rnn_type: str = "lstm", + prediction_rnn_implementation: int = 2, + prediction_layer_norm: bool = True, + prediction_projection_units: int = 0, + prediction_trainable: bool = True, + joint_dim: int = 1024, + joint_activation: str = "tanh", + prejoint_linear: bool = True, + postjoint_linear: bool = False, + joint_mode: str = "add", + joint_trainable: bool = True, + kernel_regularizer=L2, + bias_regularizer=L2, + name: str = "streaming_conformer", + **kwargs): + + self.streaming = True # Hardcoded value. Initializes Conformer with `streaming = True`. 
+ super(StreamingConformer, self).__init__( + vocabulary_size=vocabulary_size, + encoder_subsampling=encoder_subsampling, + encoder_positional_encoding=encoder_positional_encoding, + encoder_dmodel=encoder_dmodel, + encoder_num_blocks=encoder_num_blocks, + encoder_head_size=encoder_head_size, + encoder_num_heads=encoder_num_heads, + encoder_mha_type=encoder_mha_type, + encoder_depth_multiplier=encoder_depth_multiplier, + encoder_kernel_size=encoder_kernel_size, + encoder_fc_factor=encoder_fc_factor, + encoder_dropout=encoder_dropout, + encoder_trainable=encoder_trainable, + prediction_embed_dim=prediction_embed_dim, + prediction_embed_dropout=prediction_embed_dropout, + prediction_num_rnns=prediction_num_rnns, + prediction_rnn_units=prediction_rnn_units, + prediction_rnn_type=prediction_rnn_type, + prediction_rnn_implementation=prediction_rnn_implementation, + prediction_layer_norm=prediction_layer_norm, + prediction_projection_units=prediction_projection_units, + prediction_trainable=prediction_trainable, + joint_dim=joint_dim, + joint_activation=joint_activation, + prejoint_linear=prejoint_linear, + postjoint_linear=postjoint_linear, + joint_mode=joint_mode, + joint_trainable=joint_trainable, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + streaming=self.streaming, + name=name, + **kwargs + ) + self.dmodel = encoder_dmodel + self.time_reduction_factor = self.encoder.conv_subsampling.time_reduction_factor + + def _build(self, input_shape, prediction_shape=[None], batch_size=None): + inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, dtype=tf.float32) + input_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) + pred = tf.keras.Input(shape=prediction_shape, batch_size=batch_size, dtype=tf.int32) + pred_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) + mask = tf.keras.Input(shape=[None, None], batch_size=batch_size, dtype=tf.int32) + self([inputs, input_length, pred, pred_length, mask], training=False) + + def call(self, inputs, training=False, **kwargs): + """ + Transducer Model call function + Args: + features: audio features in shape [B, T, F, C] + input_length: features time length in shape [B] + prediction: predicted sequence of ids, in shape [B, U] + prediction_length: predicted sequence of ids length in shape [B] + mask: mask for streamed input frames [B, T, T] + training: python boolean + **kwargs: sth else + + Returns: + `logits` with shape [B, T, U, vocab] + """ + features, _, prediction, prediction_length, mask = inputs + enc = self.encoder(features, training=training, mask=mask, **kwargs) # Passes mask to encoder + pred = self.predict_net([prediction, prediction_length], training=training, **kwargs) + outputs = self.joint_net([enc, pred], training=training, **kwargs) + return outputs + + def encoder_inference(self, features: tf.Tensor, states: tf.Tensor): + """Infer function for encoder (or encoders) + + Args: + features (tf.Tensor): features with shape [T, F, C] + states (tf.Tensor): previous states of encoders with shape [num_rnns, 1 or 2, 1, P] + + Returns: + tf.Tensor: output of encoders with shape [T, E] + tf.Tensor: states of encoders with shape [num_rnns, 1 or 2, 1, P] + """ + with tf.name_scope(f"{self.name}_encoder"): + outputs = tf.expand_dims(features, axis=0) + outputs, new_states = self.encoder.recognize(outputs, states) + return tf.squeeze(outputs, axis=0), new_states + + # -------------------------------- GREEDY ------------------------------------- + + @tf.function + def 
recognize(self,
+                  features: tf.Tensor,
+                  input_length: tf.Tensor,
+                  parallel_iterations: int = 10,
+                  swap_memory: bool = True):
+        """
+        RNN Transducer Greedy decoding
+        Args:
+            features (tf.Tensor): a batch of padded extracted features
+
+        Returns:
+            tf.Tensor: a batch of decoded transcripts
+        """
+        batch_size, _, _, _ = shape_list(features)
+        encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state(batch_size))
+        return self._perform_greedy_batch(encoded, input_length,
+                                          parallel_iterations=parallel_iterations, swap_memory=swap_memory)
+
+    def recognize_tflite(self, signal, predicted, encoder_states, prediction_states):
+        """
+        Function to convert to tflite using greedy decoding (default streaming mode)
+        Args:
+            signal: tf.Tensor with shape [None] indicating a single audio signal
+            predicted: last predicted character with shape []
+            encoder_states: latest encoder states with shape [num_rnns, 1 or 2, 1, P]
+            prediction_states: latest prediction states with shape [num_rnns, 1 or 2, 1, P]
+
+        Return:
+            transcript: tf.Tensor of Unicode Code Points with shape [None] and dtype tf.int32
+            predicted: last predicted character with shape []
+            encoder_states: latest encoder states with shape [num_rnns, 1 or 2, 1, P]
+            prediction_states: latest prediction states with shape [num_rnns, 1 or 2, 1, P]
+        """
+        features = self.speech_featurizer.tf_extract(signal)
+        encoded, new_encoder_states = self.encoder_inference(features, encoder_states)
+        hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states)
+        transcript = self.text_featurizer.indices2upoints(hypothesis.prediction)
+        return transcript, hypothesis.index, new_encoder_states, hypothesis.states
+
+    def recognize_tflite_with_timestamp(self, signal, predicted, encoder_states, prediction_states):
+        features = self.speech_featurizer.tf_extract(signal)
+        encoded, new_encoder_states = self.encoder_inference(features, encoder_states)
+        hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states)
+        indices = self.text_featurizer.normalize_indices(hypothesis.prediction)
+        upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1))  # [None, max_subword_length]
+
+        num_samples = tf.cast(tf.shape(signal)[0], dtype=tf.float32)
+        total_time_reduction_factor = self.time_reduction_factor * self.speech_featurizer.frame_step
+
+        stime = tf.range(0, num_samples, delta=total_time_reduction_factor, dtype=tf.float32)
+        stime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)
+
+        etime = tf.range(total_time_reduction_factor, num_samples, delta=total_time_reduction_factor, dtype=tf.float32)
+        etime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)
+
+        non_blank = tf.where(tf.not_equal(upoints, 0))
+        non_blank_transcript = tf.gather_nd(upoints, non_blank)
+        non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
+        non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
+
+        return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.index, new_encoder_states, hypothesis.states
+
+    # -------------------------------- BEAM SEARCH -------------------------------------
+
+    @tf.function
+    def recognize_beam(self,
+                       features: tf.Tensor,
+                       input_length: tf.Tensor,
+                       lm: bool = False,
+                       parallel_iterations: int = 10,
+                       swap_memory: bool = True):
+        """
+        RNN Transducer Beam Search
+        Args:
+            features 
(tf.Tensor): a batch of padded extracted features + lm (bool, optional): whether to use language model. Defaults to False. + + Returns: + tf.Tensor: a batch of decoded transcripts + """ + batch_size, _, _, _ = shape_list(features) + encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state(batch_size)) + return self._perform_beam_search_batch(encoded, input_length, lm, + parallel_iterations=parallel_iterations, swap_memory=swap_memory) + + # -------------------------------- TFLITE ------------------------------------- + + def make_tflite_function(self, timestamp: bool = True): + tflite_func = self.recognize_tflite_with_timestamp if timestamp else self.recognize_tflite + return tf.function( + tflite_func, + input_signature=[ + tf.TensorSpec([None], dtype=tf.float32), + tf.TensorSpec([], dtype=tf.int32), + tf.TensorSpec(self.encoder.get_initial_state().get_shape(), dtype=tf.float32), + tf.TensorSpec(self.predict_net.get_initial_state().get_shape(), dtype=tf.float32) + ] + ) From a4a2b18062ca2c9f4158a955f604b3d67e90f0c3 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Fri, 16 Apr 2021 16:27:37 -0500 Subject: [PATCH 07/31] Adds MaskedTransducerTrainer for training Streaming Conformer. --- .../runners/masked_transducer_runners.py | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 tensorflow_asr/runners/masked_transducer_runners.py diff --git a/tensorflow_asr/runners/masked_transducer_runners.py b/tensorflow_asr/runners/masked_transducer_runners.py new file mode 100644 index 0000000000..86f23f3281 --- /dev/null +++ b/tensorflow_asr/runners/masked_transducer_runners.py @@ -0,0 +1,90 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import tensorflow as tf + +from ..configs.config import RunningConfig +from ..optimizers.accumulation import GradientAccumulation +from .base_runners import BaseTrainer +from ..losses.rnnt_losses import rnnt_loss +from ..models.transducer import Transducer +from ..featurizers.text_featurizers import TextFeaturizer +from ..utils.utils import get_reduced_length + + +class MaskedTransducerTrainer(BaseTrainer): + def __init__(self, + config: RunningConfig, + text_featurizer: TextFeaturizer, + strategy: tf.distribute.Strategy = None): + self.text_featurizer = text_featurizer + super(MaskedTransducerTrainer, self).__init__(config, strategy=strategy) + + def set_train_metrics(self): + self.train_metrics = { + "transducer_loss": tf.keras.metrics.Mean("train_transducer_loss", dtype=tf.float32) + } + + def set_eval_metrics(self): + self.eval_metrics = { + "transducer_loss": tf.keras.metrics.Mean("eval_transducer_loss", dtype=tf.float32) + } + + def save_model_weights(self): + self.model.save_weights(os.path.join(self.config.outdir, "latest.h5")) + + @tf.function(experimental_relax_shapes=True) + def _train_step(self, batch): + _, features, input_length, labels, label_length, prediction, prediction_length, mask = batch + + with tf.GradientTape() as tape: + logits = self.model([features, input_length, prediction, prediction_length, mask], training=True) + tape.watch(logits) + per_train_loss = rnnt_loss( + logits=logits, labels=labels, label_length=label_length, + logit_length=get_reduced_length(input_length, self.model.time_reduction_factor), + blank=self.text_featurizer.blank + ) + train_loss = tf.nn.compute_average_loss(per_train_loss, + global_batch_size=self.global_batch_size) + + gradients = tape.gradient(train_loss, self.model.trainable_variables) + self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables)) + + self.train_metrics["transducer_loss"].update_state(per_train_loss) + + @tf.function(experimental_relax_shapes=True) + def _eval_step(self, batch): + _, features, input_length, labels, label_length, prediction, prediction_length, mask = batch + + logits = self.model([features, input_length, prediction, prediction_length, mask], training=False) + eval_loss = rnnt_loss( + logits=logits, labels=labels, label_length=label_length, + logit_length=get_reduced_length(input_length, self.model.time_reduction_factor), + blank=self.text_featurizer.blank + ) + + self.eval_metrics["transducer_loss"].update_state(eval_loss) + + def compile(self, + model: Transducer, + optimizer: any, + max_to_keep: int = 10): + with self.strategy.scope(): + self.model = model + self.optimizer = tf.keras.optimizers.get(optimizer) + self.create_checkpoint_manager(max_to_keep, model=self.model, optimizer=self.optimizer) + +# TODO: Create MaskedTransducerGATrainer for working with gradient accumulation From f8491e6f5a82693e523481e5935868420fac97b8 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Fri, 16 Apr 2021 16:30:44 -0500 Subject: [PATCH 08/31] Configures for training Streaming Conformer. --- .../train_streaming_conformer.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/streaming_conformer/train_streaming_conformer.py b/examples/streaming_conformer/train_streaming_conformer.py index 14fea12944..fdd4686a78 100644 --- a/examples/streaming_conformer/train_streaming_conformer.py +++ b/examples/streaming_conformer/train_streaming_conformer.py @@ -13,11 +13,14 @@ # limitations under the License. 
import os +import math import argparse from tensorflow_asr.utils import setup_environment, setup_strategy setup_environment() import tensorflow as tf +physical_devices = tf.config.list_physical_devices('GPU') +tf.config.experimental.set_memory_growth(physical_devices[0], True) DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") @@ -46,11 +49,12 @@ strategy = setup_strategy(args.devices) from tensorflow_asr.configs.config import Config -from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset +from tensorflow_asr.datasets.asr_dataset import ASRMaskedSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer -from tensorflow_asr.runners.transducer_runners import TransducerTrainer -from tensorflow_asr.models.streaming_conformer import StreamingConformer +from tensorflow_asr.runners.masked_transducer_runners import MaskedTransducerTrainer +from tensorflow_asr.models.streaming_conformer_transducer import StreamingConformer +from tensorflow_asr.optimizers.schedules import TransformerSchedule config = Config(args.config) speech_featurizer = TFSpeechFeaturizer(config.speech_config) @@ -66,16 +70,16 @@ **vars(config.learning_config.eval_dataset_config) ) else: - train_dataset = ASRSliceDataset( + train_dataset = ASRMaskedSliceDataset( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.train_dataset_config) ) - eval_dataset = ASRSliceDataset( + eval_dataset = ASRMaskedSliceDataset( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.eval_dataset_config) ) -streaming_conformer_trainer = TransducerTrainer( +streaming_conformer_trainer = MaskedTransducerTrainer( config=config.learning_config.running_config, text_featurizer=text_featurizer, strategy=strategy ) @@ -92,9 +96,9 @@ optimizer_config = config.learning_config.optimizer_config optimizer = tf.keras.optimizers.Adam( TransformerSchedule( - d_model=conformer.dmodel, + d_model=streaming_conformer.dmodel, warmup_steps=optimizer_config["warmup_steps"], - max_lr=(0.05 / math.sqrt(conformer.dmodel)) + max_lr=(0.05 / math.sqrt(streaming_conformer.dmodel)) ), beta_1=optimizer_config["beta1"], beta_2=optimizer_config["beta2"], From 52ad653010420f25c3be22b21b9bea8b3472ab61 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Mon, 19 Apr 2021 10:21:11 -0500 Subject: [PATCH 09/31] Adds streaming and changes Conv2D for Conv1D to refactored repo. 
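Dropping the `[B, T, 1, E]` reshape is safe because a pointwise Conv2D over a dummy spatial axis computes exactly the same projection as a pointwise Conv1D on the unexpanded tensor. A quick equivalence check (shapes illustrative; the kernel is copied across so both layers share weights):

    import numpy as np
    import tensorflow as tf

    x = np.random.rand(2, 10, 8).astype(np.float32)  # [B, T, E]

    conv2d = tf.keras.layers.Conv2D(filters=16, kernel_size=1, use_bias=False)
    conv1d = tf.keras.layers.Conv1D(filters=16, kernel_size=1, use_bias=False)

    y2d = conv2d(x.reshape(2, 10, 1, 8))  # old path: [B, T, 1, E] -> [B, T, 1, 16]
    conv1d.build(x.shape)                 # create the Conv1D kernel before copying
    conv1d.set_weights([conv2d.get_weights()[0].reshape(1, 8, 16)])
    y1d = conv1d(x)                       # new path: [B, T, E] -> [B, T, 16]

    print(np.allclose(y2d.numpy().squeeze(axis=2), y1d.numpy(), atol=1e-5))  # True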
--- tensorflow_asr/models/encoders/conformer.py | 22 ++++++++++++------- tensorflow_asr/models/transducer/conformer.py | 2 ++ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tensorflow_asr/models/encoders/conformer.py b/tensorflow_asr/models/encoders/conformer.py index de7b767fdd..fb37dd93c9 100644 --- a/tensorflow_asr/models/encoders/conformer.py +++ b/tensorflow_asr/models/encoders/conformer.py @@ -143,20 +143,23 @@ def __init__(self, depth_multiplier=1, kernel_regularizer=L2, bias_regularizer=L2, + streaming=False, name="conv_module", **kwargs): super(ConvModule, self).__init__(name=name, **kwargs) self.ln = tf.keras.layers.LayerNormalization() - self.pw_conv_1 = tf.keras.layers.Conv2D( + self.pw_conv_1 = tf.keras.layers.Conv1D( filters=2 * input_dim, kernel_size=1, strides=1, padding="valid", name=f"{name}_pw_conv_1", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer ) self.glu = GLU(name=f"{name}_glu") - self.dw_conv = tf.keras.layers.DepthwiseConv2D( - kernel_size=(kernel_size, 1), strides=1, - padding="same", name=f"{name}_dw_conv", + self.dw_conv = tf.keras.layers.SeparableConv1D( + filters=input_dim, + kernel_size=(kernel_size), strides=1, + padding="same" if not streaming else "causal", + name=f"{name}_dw_conv", depth_multiplier=depth_multiplier, depthwise_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer @@ -167,7 +170,7 @@ def __init__(self, beta_regularizer=bias_regularizer ) self.swish = tf.keras.layers.Activation(tf.nn.swish, name=f"{name}_swish_activation") - self.pw_conv_2 = tf.keras.layers.Conv2D( + self.pw_conv_2 = tf.keras.layers.Conv1D( filters=input_dim, kernel_size=1, strides=1, padding="valid", name=f"{name}_pw_conv_2", kernel_regularizer=kernel_regularizer, @@ -178,8 +181,7 @@ def __init__(self, def call(self, inputs, training=False, **kwargs): outputs = self.ln(inputs, training=training) - B, T, E = shape_util.shape_list(outputs) - outputs = tf.reshape(outputs, [B, T, 1, E]) + B, T, E = shape_list(outputs) outputs = self.pw_conv_1(outputs, training=training) outputs = self.glu(outputs) outputs = self.dw_conv(outputs, training=training) @@ -217,6 +219,7 @@ def __init__(self, depth_multiplier=1, kernel_regularizer=L2, bias_regularizer=L2, + streaming=False, name="conformer_block", **kwargs): super(ConformerBlock, self).__init__(name=name, **kwargs) @@ -238,7 +241,8 @@ def __init__(self, dropout=dropout, name=f"{name}_conv_module", depth_multiplier=depth_multiplier, kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, + streaming=streaming ) self.ffm2 = FFModule( input_dim=input_dim, dropout=dropout, @@ -286,6 +290,7 @@ def __init__(self, dropout=0.0, kernel_regularizer=L2, bias_regularizer=L2, + streaming=False, name="conformer_encoder", **kwargs): super(ConformerEncoder, self).__init__(name=name, **kwargs) @@ -338,6 +343,7 @@ def __init__(self, depth_multiplier=depth_multiplier, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + streaming=streaming, name=f"{name}_block_{i}" ) self.conformer_blocks.append(conformer_block) diff --git a/tensorflow_asr/models/transducer/conformer.py b/tensorflow_asr/models/transducer/conformer.py index b5d151e266..be5435444e 100644 --- a/tensorflow_asr/models/transducer/conformer.py +++ b/tensorflow_asr/models/transducer/conformer.py @@ -49,6 +49,7 @@ def __init__(self, joint_trainable: bool = True, kernel_regularizer=L2, bias_regularizer=L2, + streaming=False, name: str = "conformer", **kwargs): 
super(Conformer, self).__init__( @@ -67,6 +68,7 @@ def __init__(self, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, trainable=encoder_trainable, + streaming=streaming, name=f"{name}_encoder" ), vocabulary_size=vocabulary_size, From 9989dc491ff7a6566a2da3a12b832a527caf2518 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Mon, 19 Apr 2021 12:48:50 -0500 Subject: [PATCH 10/31] Adapts ASRMaskedSliceDataset to refactored repo. --- tensorflow_asr/datasets/asr_dataset.py | 73 +++++++++++++++++++++----- 1 file changed, 61 insertions(+), 12 deletions(-) diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index dd3833cc2f..e354f839b1 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -443,8 +443,33 @@ def tf_preprocess(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): # -------------------------------- CREATION ------------------------------------- + def parse(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): + """ + Returns: + path, features, input_lengths, labels, label_lengths, pred_inp, mask + """ + if self.use_tf: data = self.tf_preprocess(path, audio, indices) + else: data = self.preprocess(path, audio, indices) + + _, features, input_length, label, label_length, prediction, prediction_length, mask = data + + return ( + data_util.create_inputs( + inputs=features, + inputs_length=input_length, + predictions=prediction, + predictions_length=prediction_length, + mask = mask + ), + data_util.create_labels( + labels=label, + labels_length=label_length + ) + ) + def process(self, dataset: tf.data.Dataset, batch_size: int): dataset = dataset.map(self.parse, num_parallel_calls=AUTOTUNE) + self.total_steps = math_util.get_num_batches(self.total_steps, batch_size, drop_remainders=self.drop_remainder) if self.cache: dataset = dataset.cache() @@ -452,29 +477,53 @@ def process(self, dataset: tf.data.Dataset, batch_size: int): if self.shuffle: dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True) - if self.indefinite: + if self.indefinite and self.total_steps: dataset = dataset.repeat() # PADDED BATCH the dataset dataset = dataset.padded_batch( batch_size=batch_size, padded_shapes=( - tf.TensorShape([]), - tf.TensorShape(self.speech_featurizer.shape), - tf.TensorShape([]), - tf.TensorShape(self.text_featurizer.shape), - tf.TensorShape([]), - tf.TensorShape(self.text_featurizer.prepand_shape), - tf.TensorShape([]), - tf.TensorShape([self.speech_featurizer.shape[0], self.speech_featurizer.shape[0]]) + data_util.create_inputs( + inputs=tf.TensorShape(self.speech_featurizer.shape), + inputs_length=tf.TensorShape([]), + predictions=tf.TensorShape(self.text_featurizer.prepand_shape), + predictions_length=tf.TensorShape([]), + mask=tf.TensorShape([self.speech_featurizer.shape[0], self.speech_featurizer.shape[0]]) + ), + data_util.create_labels( + labels=tf.TensorShape(self.text_featurizer.shape), + labels_length=tf.TensorShape([]) + ), + ), + padding_values=( + data_util.create_inputs( + inputs= 0., + inputs_length=0, + predictions=self.text_featurizer.blank, + predictions_length=0, + mask=0 + ), + data_util.create_labels( + labels=self.text_featurizer.blank, + labels_length=0 + ) ), - padding_values=(None, 0., 0, self.text_featurizer.blank, 0, self.text_featurizer.blank, 0, 0), - drop_remainder=self.drop_remainder + drop_remainder = self.drop_remainder ) # PREFETCH to improve speed of input length dataset = dataset.prefetch(AUTOTUNE) - 
self.total_steps = get_num_batches(self.total_steps, batch_size, drop_remainders=self.drop_remainder) return dataset + def create(self, batch_size: int): + self.read_entries() + if not self.total_steps or self.total_steps == 0: return None + dataset = tf.data.Dataset.from_generator( + self.generator, + output_types=(tf.string, tf.string, tf.string), + output_shapes=(tf.TensorShape([]), tf.TensorShape([]), tf.TensorShape([])) + ) + return self.process(dataset, batch_size) + # TODO: Create masked TFRecords dataset From c2e3ec6d6bc8dde38ea6c541b3c072748e4d3481 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Mon, 19 Apr 2021 12:50:03 -0500 Subject: [PATCH 11/31] Bugfix uses shape_list from shape_util. --- tensorflow_asr/models/encoders/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_asr/models/encoders/conformer.py b/tensorflow_asr/models/encoders/conformer.py index fb37dd93c9..faeeb630e2 100644 --- a/tensorflow_asr/models/encoders/conformer.py +++ b/tensorflow_asr/models/encoders/conformer.py @@ -181,7 +181,7 @@ def __init__(self, def call(self, inputs, training=False, **kwargs): outputs = self.ln(inputs, training=training) - B, T, E = shape_list(outputs) + B, T, E = shape_util.shape_list(outputs) outputs = self.pw_conv_1(outputs, training=training) outputs = self.glu(outputs) outputs = self.dw_conv(outputs, training=training) From e3da4656427334a2e84012c4324283b93e3fffb5 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Mon, 19 Apr 2021 12:51:16 -0500 Subject: [PATCH 12/31] Adds mask compatibilty for create_inputs. --- tensorflow_asr/utils/data_util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow_asr/utils/data_util.py b/tensorflow_asr/utils/data_util.py index 2bcdca8d4e..7c7d05f9f5 100644 --- a/tensorflow_asr/utils/data_util.py +++ b/tensorflow_asr/utils/data_util.py @@ -20,7 +20,8 @@ def create_inputs(inputs: tf.Tensor, inputs_length: tf.Tensor, predictions: tf.Tensor = None, - predictions_length: tf.Tensor = None) -> dict: + predictions_length: tf.Tensor = None, + mask: tf.Tensor = None) -> dict: data = { "inputs": inputs, "inputs_length": inputs_length, @@ -29,6 +30,8 @@ def create_inputs(inputs: tf.Tensor, data["predictions"] = predictions if predictions_length is not None: data["predictions_length"] = predictions_length + if mask is not None: + data["mask"] = mask return data From 0c53bc7f825e94c0735559d8e281b44ff036fe23 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Mon, 19 Apr 2021 12:54:40 -0500 Subject: [PATCH 13/31] Adapts StreamingConformer to refactored repo. --- .../models/transducer/streaming_conformer.py | 134 ++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 tensorflow_asr/models/transducer/streaming_conformer.py diff --git a/tensorflow_asr/models/transducer/streaming_conformer.py b/tensorflow_asr/models/transducer/streaming_conformer.py new file mode 100644 index 0000000000..38d6c10a84 --- /dev/null +++ b/tensorflow_asr/models/transducer/streaming_conformer.py @@ -0,0 +1,134 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" http://arxiv.org/abs/1811.06621 """
+
+import tensorflow as tf
+
+from ..layers.subsampling import TimeReduction
+# from .transducer import Transducer
+from ...utils import data_util, math_util
+# from ...utils.utils import get_rnn, merge_two_last_dims, shape_list
+from .conformer import Conformer
+
+L2 = tf.keras.regularizers.l2(1e-6)
+
+class StreamingConformer(Conformer):
+    """
+    Attempt at implementing the Streaming Conformer Transducer (see: https://arxiv.org/pdf/2010.11395.pdf).
+
+    Three main differences:
+      - Inputs are split into chunks.
+      - Masking is used in MHSA to select the chunks visible at each timestep. (Allows for parallel training.)
+      - Added parameter `streaming` to ConformerEncoder, ConformerBlock and ConvModule. Inside ConvModule, the depthwise convolution layer has its padding changed to "causal" when `streaming==True`.
+
+    NOTE: Masking is applied just as regular masking along with the inputs.
+    """
+    def __init__(self,
+                 vocabulary_size: int,
+                 encoder_subsampling: dict,
+                 encoder_positional_encoding: str = "sinusoid",
+                 encoder_dmodel: int = 144,
+                 encoder_num_blocks: int = 16,
+                 encoder_head_size: int = 36,
+                 encoder_num_heads: int = 4,
+                 encoder_mha_type: str = "relmha",
+                 encoder_kernel_size: int = 32,
+                 encoder_depth_multiplier: int = 1,
+                 encoder_fc_factor: float = 0.5,
+                 encoder_dropout: float = 0,
+                 encoder_trainable: bool = True,
+                 prediction_embed_dim: int = 512,
+                 prediction_embed_dropout: int = 0,
+                 prediction_num_rnns: int = 1,
+                 prediction_rnn_units: int = 320,
+                 prediction_rnn_type: str = "lstm",
+                 prediction_rnn_implementation: int = 2,
+                 prediction_layer_norm: bool = True,
+                 prediction_projection_units: int = 0,
+                 prediction_trainable: bool = True,
+                 joint_dim: int = 1024,
+                 joint_activation: str = "tanh",
+                 prejoint_linear: bool = True,
+                 postjoint_linear: bool = False,
+                 joint_mode: str = "add",
+                 joint_trainable: bool = True,
+                 kernel_regularizer=L2,
+                 bias_regularizer=L2,
+                 name: str = "streaming_conformer",
+                 **kwargs):
+
+        self.streaming = True  # Hardcoded value. Initializes Conformer with `streaming = True`.
+ super(StreamingConformer, self).__init__( + vocabulary_size=vocabulary_size, + encoder_subsampling=encoder_subsampling, + encoder_positional_encoding=encoder_positional_encoding, + encoder_dmodel=encoder_dmodel, + encoder_num_blocks=encoder_num_blocks, + encoder_head_size=encoder_head_size, + encoder_num_heads=encoder_num_heads, + encoder_mha_type=encoder_mha_type, + encoder_depth_multiplier=encoder_depth_multiplier, + encoder_kernel_size=encoder_kernel_size, + encoder_fc_factor=encoder_fc_factor, + encoder_dropout=encoder_dropout, + encoder_trainable=encoder_trainable, + prediction_embed_dim=prediction_embed_dim, + prediction_embed_dropout=prediction_embed_dropout, + prediction_num_rnns=prediction_num_rnns, + prediction_rnn_units=prediction_rnn_units, + prediction_rnn_type=prediction_rnn_type, + prediction_rnn_implementation=prediction_rnn_implementation, + prediction_layer_norm=prediction_layer_norm, + prediction_projection_units=prediction_projection_units, + prediction_trainable=prediction_trainable, + joint_dim=joint_dim, + joint_activation=joint_activation, + prejoint_linear=prejoint_linear, + postjoint_linear=postjoint_linear, + joint_mode=joint_mode, + joint_trainable=joint_trainable, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + streaming=self.streaming, + name=name, + **kwargs + ) + self.dmodel = encoder_dmodel + self.time_reduction_factor = self.encoder.conv_subsampling.time_reduction_factor + + def _build(self, input_shape, prediction_shape=[None], batch_size=None): + inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, dtype=tf.float32) + inputs_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) + predictions = tf.keras.Input(shape=prediction_shape, batch_size=batch_size, dtype=tf.int32) + predictions_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) + mask = tf.keras.Input(shape=[None, None], batch_size=batch_size, dtype=tf.int32) + self( + data_util.create_inputs( + inputs=inputs, + inputs_length=inputs_length, + predictions=predictions, + predictions_length=predictions_length, + mask=mask + ), + training=False + ) + + def call(self, inputs, training=False, **kwargs): + enc = self.encoder(inputs["inputs"], training=training, mask=inputs["mask"], **kwargs) + pred = self.predict_net([inputs["predictions"], inputs["predictions_length"]], training=training, **kwargs) + logits = self.joint_net([enc, pred], training=training, **kwargs) + return data_util.create_logits( + logits=logits, + logits_length=math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor) + ) From 27ef8ddadd9fcb4e403d481a9e8da14f6f6b18be Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Mon, 19 Apr 2021 12:57:43 -0500 Subject: [PATCH 14/31] Adapts StreamingConformer training script to refactored repo. --- examples/streaming_conformer/train.py | 153 ++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 examples/streaming_conformer/train.py diff --git a/examples/streaming_conformer/train.py b/examples/streaming_conformer/train.py new file mode 100644 index 0000000000..0f62ca1a68 --- /dev/null +++ b/examples/streaming_conformer/train.py @@ -0,0 +1,153 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import math +import argparse +from tensorflow_asr.utils import env_util + +env_util.setup_environment() +import tensorflow as tf +physical_devices = tf.config.list_physical_devices('GPU') +tf.config.experimental.set_memory_growth(physical_devices[0], True) + +DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") + +tf.keras.backend.clear_session() + +parser = argparse.ArgumentParser(prog="Conformer Training") + +parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file") + +parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep") + +parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords") + +parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") + +parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords") + +parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica") + +parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica") + +parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance") + +parser.add_argument("--metadata", type=str, default=None, help="Path to file containing metadata") + +parser.add_argument("--static_length", default=False, action="store_true", help="Use static lengths") + +parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") + +parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") + +args = parser.parse_args() + +tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp}) + +strategy = env_util.setup_strategy(args.devices) + +from tensorflow_asr.configs.config import Config +from tensorflow_asr.datasets.asr_dataset import ASRMaskedSliceDataset +from tensorflow_asr.featurizers import speech_featurizers, text_featurizers +from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer +from tensorflow_asr.models.transducer.streaming_conformer import StreamingConformer +from tensorflow_asr.optimizers.schedules import TransformerSchedule + +config = Config(args.config) +speech_featurizer = speech_featurizers.TFSpeechFeaturizer(config.speech_config) + +if args.sentence_piece: + print("Loading SentencePiece model ...") + text_featurizer = text_featurizers.SentencePieceFeaturizer(config.decoder_config) +elif args.subwords: + print("Loading subwords ...") + text_featurizer = text_featurizers.SubwordFeaturizer(config.decoder_config) +else: + print("Use characters ...") + text_featurizer = text_featurizers.CharFeaturizer(config.decoder_config) + +if args.tfrecords: + train_dataset = ASRTFRecordDataset( + speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + **vars(config.learning_config.train_dataset_config) + ) + eval_dataset = ASRTFRecordDataset( + speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + 
**vars(config.learning_config.eval_dataset_config) + ) +else: + train_dataset = ASRMaskedSliceDataset( + speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + **vars(config.learning_config.train_dataset_config) + ) + eval_dataset = ASRMaskedSliceDataset( + speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + **vars(config.learning_config.eval_dataset_config) + ) + +train_dataset.load_metadata(args.metadata) +eval_dataset.load_metadata(args.metadata) + +if not args.static_length: + speech_featurizer.reset_length() + text_featurizer.reset_length() + +global_batch_size = args.tbs or config.learning_config.running_config.batch_size +global_batch_size *= strategy.num_replicas_in_sync + +global_eval_batch_size = args.ebs or config.learning_config.running_config.eval_batch_size +global_eval_batch_size *= strategy.num_replicas_in_sync + +train_data_loader = train_dataset.create(global_batch_size) +eval_data_loader = eval_dataset.create(global_eval_batch_size) + +with strategy.scope(): + # build model + streaming_conformer = StreamingConformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) + streaming_conformer._build(speech_featurizer.shape) + streaming_conformer.summary(line_length=150) + + optimizer = tf.keras.optimizers.Adam( + TransformerSchedule( + d_model=streaming_conformer.dmodel, + warmup_steps=config.learning_config.optimizer_config.pop("warmup_steps", 10000), + max_lr=(0.05 / math.sqrt(streaming_conformer.dmodel)) + ), + **config.learning_config.optimizer_config + ) + + streaming_conformer.compile( + optimizer=optimizer, + experimental_steps_per_execution=args.spx, + global_batch_size=global_batch_size, + blank=text_featurizer.blank + ) + +callbacks = [ + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) +] + +streaming_conformer.fit( + train_data_loader, + batch_size=global_batch_size, + epochs=config.learning_config.running_config.num_epochs, + steps_per_epoch=train_dataset.total_steps, + validation_data=eval_data_loader, + validation_batch_size=global_eval_batch_size, + validation_steps=eval_dataset.total_steps, + callbacks=callbacks, +) From 5a9aeb040c43b855c3bbc6559fcc78dfeb267b8a Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Mon, 19 Apr 2021 16:45:23 -0500 Subject: [PATCH 15/31] Adds eval_batch_size and default value. --- examples/streaming_conformer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/streaming_conformer/train.py b/examples/streaming_conformer/train.py index 0f62ca1a68..97b50871b1 100644 --- a/examples/streaming_conformer/train.py +++ b/examples/streaming_conformer/train.py @@ -107,7 +107,7 @@ global_batch_size = args.tbs or config.learning_config.running_config.batch_size global_batch_size *= strategy.num_replicas_in_sync -global_eval_batch_size = args.ebs or config.learning_config.running_config.eval_batch_size +global_eval_batch_size = args.ebs or global_batch_size global_eval_batch_size *= strategy.num_replicas_in_sync train_data_loader = train_dataset.create(global_batch_size) From 50def4735ed0b832b4d4f2a93f8b422b995bc79b Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Mon, 19 Apr 2021 16:47:26 -0500 Subject: [PATCH 16/31] Adds DepthwiseConv1D layer from github. 
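A note on why a dedicated layer is needed: `SeparableConv1D`, used in the patches above, fuses a depthwise convolution with a pointwise projection, but the Conformer convolution module already wraps the depthwise step in its own pointwise convolutions (`pw_conv_1`/`pw_conv_2`), so what the paper actually calls for is a depthwise-only 1D convolution, which tf.keras did not ship at the time. The sketch below (shapes illustrative, channels_last, stride 1) shows the depthwise-only operation this layer implements, emulated with the stock 2D op over a dummy spatial axis:

    import tensorflow as tf

    x = tf.random.normal([2, 10, 8])         # [B, T, E]
    kernel = tf.random.normal([3, 1, 8, 1])  # [kernel_size, 1, channels, depth_multiplier]

    # Each of the 8 channels is filtered independently by its own length-3 kernel;
    # there is no cross-channel mixing, unlike a regular or separable convolution.
    y = tf.nn.depthwise_conv2d(x[:, :, tf.newaxis, :], kernel,
                               strides=[1, 1, 1, 1], padding="SAME")
    print(tf.squeeze(y, axis=2).shape)       # (2, 10, 8)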
--- .../models/layers/DepthwiseConv1D.py | 221 ++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 tensorflow_asr/models/layers/DepthwiseConv1D.py diff --git a/tensorflow_asr/models/layers/DepthwiseConv1D.py b/tensorflow_asr/models/layers/DepthwiseConv1D.py new file mode 100644 index 0000000000..77928ee595 --- /dev/null +++ b/tensorflow_asr/models/layers/DepthwiseConv1D.py @@ -0,0 +1,221 @@ +""" + This implementation comes from github: https://github.com/tensorflow/tensorflow/issues/36935 + Slight modifications have been made to support causal padding. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import backend +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.engine.input_spec import InputSpec +from tensorflow.python.keras.layers.convolutional import Conv1D +# imports for backwards namespace compatibility +# pylint: disable=unused-import +# pylint: enable=unused-import +from tensorflow.python.keras.utils import conv_utils +from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.ops import array_ops +from tensorflow.python.util.tf_export import keras_export + + +@keras_export('keras.layers.DepthwiseConv1D') +class DepthwiseConv1D(Conv1D): + """Depthwise separable 1D convolution. + Depthwise Separable convolutions consist of performing + just the first step in a depthwise spatial convolution + (which acts on each input channel separately). + The `depth_multiplier` argument controls how many + output channels are generated per input channel in the depthwise step. + Arguments: + kernel_size: A single integer specifying the spatial + dimensions of the filters. + strides: A single integer specifying the strides + of the convolution. + Specifying any `stride` value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: one of `'valid'` or `'same'` (case-insensitive). + depth_multiplier: The number of depthwise convolution output channels + for each input channel. + The total number of depthwise convolution output + channels will be equal to `filters_in * depth_multiplier`. + data_format: A string, + one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, length)`. + The default is 'channels_last'. + activation: Activation function to use. + If you don't specify anything, no activation is applied + (ie. 'linear' activation: `a(x) = x`). + use_bias: Boolean, whether the layer uses a bias vector. + depthwise_initializer: Initializer for the depthwise kernel matrix. + bias_initializer: Initializer for the bias vector. + depthwise_regularizer: Regularizer function applied to + the depthwise kernel matrix. + bias_regularizer: Regularizer function applied to the bias vector. + activity_regularizer: Regularizer function applied to + the output of the layer (its 'activation'). + depthwise_constraint: Constraint function applied to + the depthwise kernel matrix. + bias_constraint: Constraint function applied to the bias vector. 
+ Input shape: + 3D tensor with shape: + `[batch, channels, length]` if data_format='channels_first' + or 4D tensor with shape: + `[batch, length, channels]` if data_format='channels_last'. + Output shape: + 3D tensor with shape: + `[batch, filters, new_length]` if data_format='channels_first' + or 3D tensor with shape: + `[batch, new_length, filters]` if data_format='channels_last'. + `length` values might have changed due to padding. + """ + + def __init__(self, + kernel_size, + strides=1, + padding='valid', + depth_multiplier=1, + data_format=None, + activation=None, + use_bias=True, + depthwise_initializer='glorot_uniform', + bias_initializer='zeros', + depthwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + bias_constraint=None, + **kwargs): + super(DepthwiseConv1D, self).__init__( + filters=None, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + activation=activation, + use_bias=use_bias, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + bias_constraint=bias_constraint, + # autocast=False, + **kwargs) + self.depth_multiplier = depth_multiplier + self.depthwise_initializer = initializers.get(depthwise_initializer) + self.depthwise_regularizer = regularizers.get(depthwise_regularizer) + self.depthwise_constraint = constraints.get(depthwise_constraint) + self.bias_initializer = initializers.get(bias_initializer) + + def build(self, input_shape): + if len(input_shape) < 3: + raise ValueError('Inputs to `DepthwiseConv1D` should have rank 3. ' + 'Received input shape:', str(input_shape)) + input_shape = tensor_shape.TensorShape(input_shape) + + #TODO(pj1989): replace with channel_axis = self._get_channel_axis() + if self.data_format == 'channels_last': + channel_axis = -1 + elif self.data_format == 'channels_first': + channel_axis = 1 + + if input_shape.dims[channel_axis].value is None: + raise ValueError('The channel dimension of the inputs to ' + '`DepthwiseConv1D` ' + 'should be defined. Found `None`.') + input_dim = int(input_shape[channel_axis]) + depthwise_kernel_shape = (self.kernel_size[0], + input_dim, + self.depth_multiplier) + + self.depthwise_kernel = self.add_weight( + shape=depthwise_kernel_shape, + initializer=self.depthwise_initializer, + name='depthwise_kernel', + regularizer=self.depthwise_regularizer, + constraint=self.depthwise_constraint) + + if self.use_bias: + self.bias = self.add_weight(shape=(input_dim * self.depth_multiplier,), + initializer=self.bias_initializer, + name='bias', + regularizer=self.bias_regularizer, + constraint=self.bias_constraint) + else: + self.bias = None + # Set input spec. + self.input_spec = InputSpec(ndim=3, axes={channel_axis: input_dim}) + self.built = True + + def call(self, inputs): + if self.padding == 'causal': + inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs)) + if self.data_format == 'channels_last': + spatial_start_dim = 1 + else: + spatial_start_dim = 2 + + # Explicitly broadcast inputs and kernels to 4D. + # TODO(fchollet): refactor when a native depthwise_conv2d op is available. 
+ strides = self.strides * 2 + inputs = array_ops.expand_dims(inputs, spatial_start_dim) + depthwise_kernel = array_ops.expand_dims(self.depthwise_kernel, 0) + dilation_rate = (1,) + self.dilation_rate + + outputs = backend.depthwise_conv2d( + inputs, + depthwise_kernel, + strides=strides, + padding=self.padding if not self.padding == 'causal' else 'valid', + dilation_rate=dilation_rate, + data_format=self.data_format) + + if self.use_bias: + outputs = backend.bias_add( + outputs, + self.bias, + data_format=self.data_format) + + outputs = array_ops.squeeze(outputs, [spatial_start_dim]) + + if self.activation is not None: + return self.activation(outputs) + + return outputs + + @tf_utils.shape_type_conversion + def compute_output_shape(self, input_shape): + if self.data_format == 'channels_first': + length = input_shape[2] + out_filters = input_shape[1] * self.depth_multiplier + elif self.data_format == 'channels_last': + length = input_shape[1] + out_filters = input_shape[2] * self.depth_multiplier + + length = conv_utils.conv_output_length(length, self.kernel_size, + self.padding, + self.strides) + if self.data_format == 'channels_first': + return (input_shape[0], out_filters, length) + elif self.data_format == 'channels_last': + return (input_shape[0], length, out_filters) + + def get_config(self): + config = super(DepthwiseConv1D, self).get_config() + config.pop('filters') + config.pop('kernel_initializer') + config.pop('kernel_regularizer') + config.pop('kernel_constraint') + config['depth_multiplier'] = self.depth_multiplier + config['depthwise_initializer'] = initializers.serialize( + self.depthwise_initializer) + config['depthwise_regularizer'] = regularizers.serialize( + self.depthwise_regularizer) + config['depthwise_constraint'] = constraints.serialize( + self.depthwise_constraint) \ No newline at end of file From 533e8ead4220d05cee16eef1d7de437835a7747a Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Mon, 19 Apr 2021 16:48:07 -0500 Subject: [PATCH 17/31] Sets Conformer model to use new DepthwiseConv1D layer. 
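A note on the swap below: the Conformer conv module already wraps its depthwise step in explicit pointwise convolutions (pw_conv_1 and pw_conv_2), so the extra pointwise kernel that SeparableConv1D bundles on top of its depthwise one was presumably redundant; the vendored DepthwiseConv1D keeps only the depthwise kernel. A rough parameter comparison, assuming illustrative sizes rather than values from config.yml (the import path predates the later rename to depthwise_conv1d.py):

    import tensorflow as tf
    from tensorflow_asr.models.layers.DepthwiseConv1D import DepthwiseConv1D

    x = tf.random.normal([1, 100, 144])  # illustrative (batch, time, channels)

    sep = tf.keras.layers.SeparableConv1D(filters=144, kernel_size=32, padding="same")
    dw = DepthwiseConv1D(kernel_size=32, padding="causal")
    _ = sep(x)  # build both layers
    _ = dw(x)

    print(sep.count_params())  # 25488 for these sizes: depthwise + pointwise kernels + bias
    print(dw.count_params())   # 4752 for these sizes: depthwise kernel + bias only
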
--- tensorflow_asr/models/encoders/conformer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow_asr/models/encoders/conformer.py b/tensorflow_asr/models/encoders/conformer.py index faeeb630e2..2dded6c8d4 100644 --- a/tensorflow_asr/models/encoders/conformer.py +++ b/tensorflow_asr/models/encoders/conformer.py @@ -19,6 +19,7 @@ from ..layers.positional_encoding import PositionalEncoding, PositionalEncodingConcat from ..layers.multihead_attention import MultiHeadAttention, RelPositionMultiHeadAttention from ...utils import shape_util +from ..layers.DepthwiseConv1D import DepthwiseConv1D L2 = tf.keras.regularizers.l2(1e-6) @@ -155,14 +156,14 @@ def __init__(self, bias_regularizer=bias_regularizer ) self.glu = GLU(name=f"{name}_glu") - self.dw_conv = tf.keras.layers.SeparableConv1D( - filters=input_dim, + self.dw_conv = DepthwiseConv1D( kernel_size=(kernel_size), strides=1, padding="same" if not streaming else "causal", name=f"{name}_dw_conv", depth_multiplier=depth_multiplier, depthwise_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, + data_format='channels_last', ) self.bn = tf.keras.layers.BatchNormalization( name=f"{name}_bn", From d784a1f95ad731b73ccc9346f53939dc51f49eef Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Mon, 19 Apr 2021 17:13:23 -0500 Subject: [PATCH 18/31] Loads time_reduction_factor dynamically into --- examples/streaming_conformer/train.py | 8 ++++---- tensorflow_asr/datasets/asr_dataset.py | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/streaming_conformer/train.py b/examples/streaming_conformer/train.py index 97b50871b1..9dbaea6c5a 100644 --- a/examples/streaming_conformer/train.py +++ b/examples/streaming_conformer/train.py @@ -107,12 +107,9 @@ global_batch_size = args.tbs or config.learning_config.running_config.batch_size global_batch_size *= strategy.num_replicas_in_sync -global_eval_batch_size = args.ebs or global_batch_size +global_eval_batch_size = args.ebs or global_batch_size global_eval_batch_size *= strategy.num_replicas_in_sync -train_data_loader = train_dataset.create(global_batch_size) -eval_data_loader = eval_dataset.create(global_eval_batch_size) - with strategy.scope(): # build model streaming_conformer = StreamingConformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) @@ -135,6 +132,9 @@ blank=text_featurizer.blank ) +train_data_loader = train_dataset.create(global_batch_size, time_reduction_factor=streaming_conformer.time_reduction_factor) +eval_data_loader = eval_dataset.create(global_eval_batch_size, time_reduction_factor=streaming_conformer.time_reduction_factor) + callbacks = [ tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index e354f839b1..f436075406 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -385,6 +385,7 @@ def __init__(self, buffer_size: int = BUFFER_SIZE, history_window_size: int = 3, input_chunk_duration: int = 250, + time_reduction_factor: int = 4, **kwargs): super(ASRMaskedSliceDataset, self).__init__( data_paths=data_paths, augmentations=augmentations, @@ -396,13 +397,13 @@ def __init__(self, self.text_featurizer = text_featurizer self.history_window_size = history_window_size self.input_chunk_size = 
input_chunk_duration * self.speech_featurizer.sample_rate // 1000 + self.time_reduction_factor = time_reduction_factor def calculate_mask(self, num_frames): frame_step = self.speech_featurizer.frame_step frames_per_chunk = self.input_chunk_size // frame_step - time_reduction_factor = 4 # TODO: Get time_reduction_factor from config or model. - num_frames = tf.cast(tf.math.ceil(num_frames / time_reduction_factor), tf.int32) + num_frames = tf.cast(tf.math.ceil(num_frames / self.time_reduction_factor), tf.int32) def _calculate_mask(num_frames, frames_per_chunk, history_window_size): mask = np.zeros((num_frames, num_frames), dtype=np.int32) From 3edfd47e2c3ed9cc69f6d206f0cca5e96b8375f6 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Mon, 19 Apr 2021 18:56:12 -0500 Subject: [PATCH 19/31] Bugfix. Loads time_reduction_factor dynamically. --- examples/streaming_conformer/train.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/streaming_conformer/train.py b/examples/streaming_conformer/train.py index 9dbaea6c5a..e5e458fecd 100644 --- a/examples/streaming_conformer/train.py +++ b/examples/streaming_conformer/train.py @@ -78,6 +78,7 @@ print("Use characters ...") text_featurizer = text_featurizers.CharFeaturizer(config.decoder_config) +time_reduction_factor = config.model_config['encoder_subsampling']['strides'] * 2 if args.tfrecords: train_dataset = ASRTFRecordDataset( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, @@ -90,10 +91,12 @@ else: train_dataset = ASRMaskedSliceDataset( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + time_reduction_factor=time_reduction_factor, **vars(config.learning_config.train_dataset_config) ) eval_dataset = ASRMaskedSliceDataset( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + time_reduction_factor=time_reduction_factor, **vars(config.learning_config.eval_dataset_config) ) @@ -110,6 +113,9 @@ global_eval_batch_size = args.ebs or global_batch_size global_eval_batch_size *= strategy.num_replicas_in_sync +train_data_loader = train_dataset.create(global_batch_size) +eval_data_loader = eval_dataset.create(global_eval_batch_size) + with strategy.scope(): # build model streaming_conformer = StreamingConformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) @@ -132,9 +138,6 @@ blank=text_featurizer.blank ) -train_data_loader = train_dataset.create(global_batch_size, time_reduction_factor=streaming_conformer.time_reduction_factor) -eval_data_loader = eval_dataset.create(global_eval_batch_size, time_reduction_factor=streaming_conformer.time_reduction_factor) - callbacks = [ tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), From 3ded7e133a703b5ff1e431d63a9b4945edbd03c3 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Tue, 20 Apr 2021 17:34:32 -0500 Subject: [PATCH 20/31] Removes problem causing imports. --- tensorflow_asr/models/layers/DepthwiseConv1D.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tensorflow_asr/models/layers/DepthwiseConv1D.py b/tensorflow_asr/models/layers/DepthwiseConv1D.py index 77928ee595..e9abe0d92a 100644 --- a/tensorflow_asr/models/layers/DepthwiseConv1D.py +++ b/tensorflow_asr/models/layers/DepthwiseConv1D.py @@ -3,10 +3,6 @@ Slight modifications have been made to support causal padding. 
""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import backend from tensorflow.python.keras import constraints @@ -14,16 +10,13 @@ from tensorflow.python.keras import regularizers from tensorflow.python.keras.engine.input_spec import InputSpec from tensorflow.python.keras.layers.convolutional import Conv1D -# imports for backwards namespace compatibility -# pylint: disable=unused-import -# pylint: enable=unused-import + from tensorflow.python.keras.utils import conv_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.layers.DepthwiseConv1D') class DepthwiseConv1D(Conv1D): """Depthwise separable 1D convolution. Depthwise Separable convolutions consist of performing @@ -218,4 +211,4 @@ def get_config(self): config['depthwise_regularizer'] = regularizers.serialize( self.depthwise_regularizer) config['depthwise_constraint'] = constraints.serialize( - self.depthwise_constraint) \ No newline at end of file + self.depthwise_constraint) From 00032a712d2583389832033892e7924b08b69f9d Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Mon, 26 Apr 2021 16:49:39 -0500 Subject: [PATCH 21/31] Removes unnecessary argument. --- tensorflow_asr/models/encoders/conformer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow_asr/models/encoders/conformer.py b/tensorflow_asr/models/encoders/conformer.py index 2dded6c8d4..92eb259903 100644 --- a/tensorflow_asr/models/encoders/conformer.py +++ b/tensorflow_asr/models/encoders/conformer.py @@ -162,8 +162,7 @@ def __init__(self, name=f"{name}_dw_conv", depth_multiplier=depth_multiplier, depthwise_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - data_format='channels_last', + bias_regularizer=bias_regularizer ) self.bn = tf.keras.layers.BatchNormalization( name=f"{name}_bn", From a0223ece6b5263a874b8254559c7e2d6ddadfee5 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Mon, 26 Apr 2021 16:50:38 -0500 Subject: [PATCH 22/31] Removes unused lines from DepthwiseConv2D. --- tensorflow_asr/models/encoders/conformer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow_asr/models/encoders/conformer.py b/tensorflow_asr/models/encoders/conformer.py index 92eb259903..1130271ced 100644 --- a/tensorflow_asr/models/encoders/conformer.py +++ b/tensorflow_asr/models/encoders/conformer.py @@ -181,14 +181,12 @@ def __init__(self, def call(self, inputs, training=False, **kwargs): outputs = self.ln(inputs, training=training) - B, T, E = shape_util.shape_list(outputs) outputs = self.pw_conv_1(outputs, training=training) outputs = self.glu(outputs) outputs = self.dw_conv(outputs, training=training) outputs = self.bn(outputs, training=training) outputs = self.swish(outputs) outputs = self.pw_conv_2(outputs, training=training) - outputs = tf.reshape(outputs, [B, T, E]) outputs = self.do(outputs, training=training) outputs = self.res_add([inputs, outputs]) return outputs From cfbc29ecafc490965aa727addaa00fc068447ecd Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Mon, 26 Apr 2021 16:53:56 -0500 Subject: [PATCH 23/31] Renames DepthwiseConv1D definition script. 
--- tensorflow_asr/models/encoders/conformer.py | 4 ++-- .../models/layers/{DepthwiseConv1D.py => depthwise_conv1d.py} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename tensorflow_asr/models/layers/{DepthwiseConv1D.py => depthwise_conv1d.py} (100%) diff --git a/tensorflow_asr/models/encoders/conformer.py b/tensorflow_asr/models/encoders/conformer.py index 1130271ced..782ec1c67d 100644 --- a/tensorflow_asr/models/encoders/conformer.py +++ b/tensorflow_asr/models/encoders/conformer.py @@ -19,7 +19,7 @@ from ..layers.positional_encoding import PositionalEncoding, PositionalEncodingConcat from ..layers.multihead_attention import MultiHeadAttention, RelPositionMultiHeadAttention from ...utils import shape_util -from ..layers.DepthwiseConv1D import DepthwiseConv1D +from ..layers.depthwise_conv1d import DepthwiseConv1D L2 = tf.keras.regularizers.l2(1e-6) @@ -186,7 +186,7 @@ def call(self, inputs, training=False, **kwargs): outputs = self.dw_conv(outputs, training=training) outputs = self.bn(outputs, training=training) outputs = self.swish(outputs) - outputs = self.pw_conv_2(outputs, training=training) + outputs = self.pw_conv_2(outputs, training=training)f outputs = self.do(outputs, training=training) outputs = self.res_add([inputs, outputs]) return outputs diff --git a/tensorflow_asr/models/layers/DepthwiseConv1D.py b/tensorflow_asr/models/layers/depthwise_conv1d.py similarity index 100% rename from tensorflow_asr/models/layers/DepthwiseConv1D.py rename to tensorflow_asr/models/layers/depthwise_conv1d.py From aadcf91085674a937efd7c18e9165d3e8748d4ac Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Tue, 27 Apr 2021 09:45:38 -0500 Subject: [PATCH 24/31] Bugfix: typo --- tensorflow_asr/models/encoders/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_asr/models/encoders/conformer.py b/tensorflow_asr/models/encoders/conformer.py index 782ec1c67d..6bc6cde65c 100644 --- a/tensorflow_asr/models/encoders/conformer.py +++ b/tensorflow_asr/models/encoders/conformer.py @@ -186,7 +186,7 @@ def call(self, inputs, training=False, **kwargs): outputs = self.dw_conv(outputs, training=training) outputs = self.bn(outputs, training=training) outputs = self.swish(outputs) - outputs = self.pw_conv_2(outputs, training=training)f + outputs = self.pw_conv_2(outputs, training=training) outputs = self.do(outputs, training=training) outputs = self.res_add([inputs, outputs]) return outputs From d593105bb62db534868d6db5fc5dedc9bac97ae3 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Tue, 27 Apr 2021 09:46:38 -0500 Subject: [PATCH 25/31] Renames model _build() to make().
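The rename presumably aligns the streaming model with the model-building API used elsewhere in the library. Caller-side, building the graph now reads as follows, mirroring train.py (the shape comment is illustrative, not a config value):

    # speech_featurizer.shape is typically [None, num_feature_bins, 1]
    streaming_conformer = StreamingConformer(**config.model_config,
                                             vocabulary_size=text_featurizer.num_classes)
    streaming_conformer.make(speech_featurizer.shape)  # was streaming_conformer._build(...)
    streaming_conformer.summary(line_length=150)
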
--- examples/streaming_conformer/train.py | 2 +- tensorflow_asr/models/transducer/streaming_conformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/streaming_conformer/train.py b/examples/streaming_conformer/train.py index e5e458fecd..377d9a5968 100644 --- a/examples/streaming_conformer/train.py +++ b/examples/streaming_conformer/train.py @@ -119,7 +119,7 @@ with strategy.scope(): # build model streaming_conformer = StreamingConformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) - streaming_conformer._build(speech_featurizer.shape) + streaming_conformer.make(speech_featurizer.shape) streaming_conformer.summary(line_length=150) optimizer = tf.keras.optimizers.Adam( diff --git a/tensorflow_asr/models/transducer/streaming_conformer.py b/tensorflow_asr/models/transducer/streaming_conformer.py index 38d6c10a84..d26d4681b4 100644 --- a/tensorflow_asr/models/transducer/streaming_conformer.py +++ b/tensorflow_asr/models/transducer/streaming_conformer.py @@ -107,7 +107,7 @@ def __init__(self, self.dmodel = encoder_dmodel self.time_reduction_factor = self.encoder.conv_subsampling.time_reduction_factor - def _build(self, input_shape, prediction_shape=[None], batch_size=None): + def make(self, input_shape, prediction_shape=[None], batch_size=None): inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, dtype=tf.float32) inputs_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) predictions = tf.keras.Input(shape=prediction_shape, batch_size=batch_size, dtype=tf.int32) From 9d66c2de00f617831fe5b763a97dd51087abb2d1 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Tue, 27 Apr 2021 10:27:42 -0500 Subject: [PATCH 26/31] Adds ASRMaskedTFRecordDataset. Fixes ASRTFRecordDataset. 
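With this fix, parse() repacks the flat preprocess tuple into the (inputs, labels) pair that the Keras training loop consumes. A sketch of the resulting element structure, assuming data_util.create_inputs and data_util.create_labels key their dicts by the argument names used in the diff below, and with illustrative dimensions (T = feature frames, T' = frames after time reduction, U = label length; the 80 mel bins are an assumption, not read from the config):

    import tensorflow as tf

    element_spec = (
        {
            "inputs": tf.TensorSpec([None, 80, 1], tf.float32),  # [T, bins, 1] features
            "inputs_length": tf.TensorSpec([], tf.int32),
            "predictions": tf.TensorSpec([None], tf.int32),      # prediction-network inputs
            "predictions_length": tf.TensorSpec([], tf.int32),
            "mask": tf.TensorSpec([None, None], tf.int32),       # [T', T'] rolling mask
        },
        {
            "labels": tf.TensorSpec([None], tf.int32),           # [U] target token ids
            "labels_length": tf.TensorSpec([], tf.int32),
        },
    )
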
--- tensorflow_asr/datasets/asr_dataset.py | 138 ++++++++++++++++++++++++- 1 file changed, 135 insertions(+), 3 deletions(-) diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index f436075406..4427465aa0 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -334,8 +334,24 @@ def parse(self, record: tf.Tensor): "indices": tf.io.FixedLenFeature([], tf.string) } example = tf.io.parse_single_example(record, feature_description) - if self.use_tf: return self.tf_preprocess(**example) - return self.preprocess(**example) + if self.use_tf: data = self.tf_preprocess(**example) + else: data = self.preprocess(**example) + + _, features, input_length, label, label_length, prediction, prediction_length, mask = data + + return ( + data_util.create_inputs( + inputs=features, + inputs_length=input_length, + predictions=prediction, + predictions_length=prediction_length, + mask=mask + ), + data_util.create_labels( + labels=label, + labels_length=label_length + ) + ) def create(self, batch_size: int): have_data = self.create_tfrecords() @@ -527,4 +543,120 @@ def create(self, batch_size: int): ) return self.process(dataset, batch_size) -# TODO: Create masked TFRecords dataset +class ASRMaskedTFRecordDataset(ASRMaskedSliceDataset): + """ Dataset for ASR using TFRecords with rolling mask """ + def __init__(self, + data_paths: list, + tfrecords_dir: str, + speech_featurizer: SpeechFeaturizer, + text_featurizer: TextFeaturizer, + stage: str, + augmentations: Augmentation = Augmentation(None), + tfrecords_shards: int = TFRECORD_SHARDS, + cache: bool = False, + shuffle: bool = False, + use_tf: bool = False, + indefinite: bool = False, + drop_remainder: bool = True, + buffer_size: int = BUFFER_SIZE, + history_window_size: int = 3, + input_chunk_duration: int = 250, + time_reduction_factor: int = 4, + **kwargs): + super(ASRMaskedTFRecordDataset, self).__init__( + stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, + data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle, buffer_size=buffer_size, + drop_remainder=drop_remainder, use_tf=use_tf, indefinite=indefinite, history_window_size = history_window_size, + input_chunk_duration = input_chunk_duration, time_reduction_factor = time_reduction_factor, + ) + if not self.stage: raise ValueError("stage must be defined, either 'train', 'eval' or 'test'") + self.tfrecords_dir = tfrecords_dir + if tfrecords_shards <= 0: raise ValueError("tfrecords_shards must be positive") + self.tfrecords_shards = tfrecords_shards + if not tf.io.gfile.exists(self.tfrecords_dir): tf.io.gfile.makedirs(self.tfrecords_dir) + + @staticmethod + def write_tfrecord_file(splitted_entries): + shard_path, entries = splitted_entries + + def parse(record): + def fn(path, indices): + audio = load_and_convert_to_wav(path.decode("utf-8")).numpy() + feature = { + "path": feature_util.bytestring_feature([path]), + "audio": feature_util.bytestring_feature([audio]), + "indices": feature_util.bytestring_feature([indices]) + } + example = tf.train.Example(features=tf.train.Features(feature=feature)) + return example.SerializeToString() + return tf.numpy_function(fn, inp=[record[0], record[2]], Tout=tf.string) + + dataset = tf.data.Dataset.from_tensor_slices(entries) + dataset = dataset.map(parse, num_parallel_calls=AUTOTUNE) + writer = tf.data.experimental.TFRecordWriter(shard_path, compression_type="ZLIB") + print(f"Processing {shard_path} ...") + 
writer.write(dataset) + print(f"Created {shard_path}") + + def create_tfrecords(self): + if not tf.io.gfile.exists(self.tfrecords_dir): + tf.io.gfile.makedirs(self.tfrecords_dir) + + if tf.io.gfile.glob(os.path.join(self.tfrecords_dir, f"{self.stage}*.tfrecord")): + print(f"TFRecords're already existed: {self.stage}") + return True + + print(f"Creating {self.stage}.tfrecord ...") + + self.read_entries() + if not self.total_steps or self.total_steps == 0: return False + + def get_shard_path(shard_id): + return os.path.join(self.tfrecords_dir, f"{self.stage}_{shard_id}.tfrecord") + + shards = [get_shard_path(idx) for idx in range(1, self.tfrecords_shards + 1)] + + splitted_entries = np.array_split(self.entries, self.tfrecords_shards) + for entries in zip(shards, splitted_entries): + self.write_tfrecord_file(entries) + + return True + + def parse(self, record: tf.Tensor): + feature_description = { + "path": tf.io.FixedLenFeature([], tf.string), + "audio": tf.io.FixedLenFeature([], tf.string), + "indices": tf.io.FixedLenFeature([], tf.string) + } + example = tf.io.parse_single_example(record, feature_description) + if self.use_tf: data = self.tf_preprocess(**example) + else: data = self.preprocess(**example) + + _, features, input_length, label, label_length, prediction, prediction_length, mask = data + + return ( + data_util.create_inputs( + inputs=features, + inputs_length=input_length, + predictions=prediction, + predictions_length=prediction_length, + mask=mask + ), + data_util.create_labels( + labels=label, + labels_length=label_length + ) + ) + + def create(self, batch_size: int): + have_data = self.create_tfrecords() + if not have_data: return None + + pattern = os.path.join(self.tfrecords_dir, f"{self.stage}*.tfrecord") + files_ds = tf.data.Dataset.list_files(pattern) + ignore_order = tf.data.Options() + ignore_order.experimental_deterministic = False + files_ds = files_ds.with_options(ignore_order) + dataset = tf.data.TFRecordDataset(files_ds, compression_type="ZLIB", num_parallel_reads=AUTOTUNE) + + return self.process(dataset, batch_size) From 34525a1e81ff0a88d33b7b38e004b60daac4313b Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Wed, 28 Apr 2021 11:10:09 -0500 Subject: [PATCH 27/31] Fixes pep8 formatting. --- .../models/layers/depthwise_conv1d.py | 391 +++++++++--------- 1 file changed, 195 insertions(+), 196 deletions(-) diff --git a/tensorflow_asr/models/layers/depthwise_conv1d.py b/tensorflow_asr/models/layers/depthwise_conv1d.py index e9abe0d92a..9fc99c4e15 100644 --- a/tensorflow_asr/models/layers/depthwise_conv1d.py +++ b/tensorflow_asr/models/layers/depthwise_conv1d.py @@ -1,6 +1,6 @@ """ - This implementation comes from github: https://github.com/tensorflow/tensorflow/issues/36935 - Slight modifications have been made to support causal padding. + This implementation comes from github: https://github.com/tensorflow/tensorflow/issues/36935 + Slight modifications have been made to support causal padding. """ from tensorflow.python.framework import tensor_shape @@ -18,197 +18,196 @@ class DepthwiseConv1D(Conv1D): - """Depthwise separable 1D convolution. - Depthwise Separable convolutions consist of performing - just the first step in a depthwise spatial convolution - (which acts on each input channel separately). - The `depth_multiplier` argument controls how many - output channels are generated per input channel in the depthwise step. - Arguments: - kernel_size: A single integer specifying the spatial - dimensions of the filters. 
- strides: A single integer specifying the strides - of the convolution. - Specifying any `stride` value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: one of `'valid'` or `'same'` (case-insensitive). - depth_multiplier: The number of depthwise convolution output channels - for each input channel. - The total number of depthwise convolution output - channels will be equal to `filters_in * depth_multiplier`. - data_format: A string, - one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, length)`. - The default is 'channels_last'. - activation: Activation function to use. - If you don't specify anything, no activation is applied - (ie. 'linear' activation: `a(x) = x`). - use_bias: Boolean, whether the layer uses a bias vector. - depthwise_initializer: Initializer for the depthwise kernel matrix. - bias_initializer: Initializer for the bias vector. - depthwise_regularizer: Regularizer function applied to - the depthwise kernel matrix. - bias_regularizer: Regularizer function applied to the bias vector. - activity_regularizer: Regularizer function applied to - the output of the layer (its 'activation'). - depthwise_constraint: Constraint function applied to - the depthwise kernel matrix. - bias_constraint: Constraint function applied to the bias vector. - Input shape: - 3D tensor with shape: - `[batch, channels, length]` if data_format='channels_first' - or 4D tensor with shape: - `[batch, length, channels]` if data_format='channels_last'. - Output shape: - 3D tensor with shape: - `[batch, filters, new_length]` if data_format='channels_first' - or 3D tensor with shape: - `[batch, new_length, filters]` if data_format='channels_last'. - `length` values might have changed due to padding. - """ - - def __init__(self, - kernel_size, - strides=1, - padding='valid', - depth_multiplier=1, - data_format=None, - activation=None, - use_bias=True, - depthwise_initializer='glorot_uniform', - bias_initializer='zeros', - depthwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - bias_constraint=None, - **kwargs): - super(DepthwiseConv1D, self).__init__( - filters=None, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - activation=activation, - use_bias=use_bias, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - bias_constraint=bias_constraint, - # autocast=False, - **kwargs) - self.depth_multiplier = depth_multiplier - self.depthwise_initializer = initializers.get(depthwise_initializer) - self.depthwise_regularizer = regularizers.get(depthwise_regularizer) - self.depthwise_constraint = constraints.get(depthwise_constraint) - self.bias_initializer = initializers.get(bias_initializer) - - def build(self, input_shape): - if len(input_shape) < 3: - raise ValueError('Inputs to `DepthwiseConv1D` should have rank 3. 
' - 'Received input shape:', str(input_shape)) - input_shape = tensor_shape.TensorShape(input_shape) - - #TODO(pj1989): replace with channel_axis = self._get_channel_axis() - if self.data_format == 'channels_last': - channel_axis = -1 - elif self.data_format == 'channels_first': - channel_axis = 1 - - if input_shape.dims[channel_axis].value is None: - raise ValueError('The channel dimension of the inputs to ' - '`DepthwiseConv1D` ' - 'should be defined. Found `None`.') - input_dim = int(input_shape[channel_axis]) - depthwise_kernel_shape = (self.kernel_size[0], - input_dim, - self.depth_multiplier) - - self.depthwise_kernel = self.add_weight( - shape=depthwise_kernel_shape, - initializer=self.depthwise_initializer, - name='depthwise_kernel', - regularizer=self.depthwise_regularizer, - constraint=self.depthwise_constraint) - - if self.use_bias: - self.bias = self.add_weight(shape=(input_dim * self.depth_multiplier,), - initializer=self.bias_initializer, - name='bias', - regularizer=self.bias_regularizer, - constraint=self.bias_constraint) - else: - self.bias = None - # Set input spec. - self.input_spec = InputSpec(ndim=3, axes={channel_axis: input_dim}) - self.built = True - - def call(self, inputs): - if self.padding == 'causal': - inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs)) - if self.data_format == 'channels_last': - spatial_start_dim = 1 - else: - spatial_start_dim = 2 - - # Explicitly broadcast inputs and kernels to 4D. - # TODO(fchollet): refactor when a native depthwise_conv2d op is available. - strides = self.strides * 2 - inputs = array_ops.expand_dims(inputs, spatial_start_dim) - depthwise_kernel = array_ops.expand_dims(self.depthwise_kernel, 0) - dilation_rate = (1,) + self.dilation_rate - - outputs = backend.depthwise_conv2d( - inputs, - depthwise_kernel, - strides=strides, - padding=self.padding if not self.padding == 'causal' else 'valid', - dilation_rate=dilation_rate, - data_format=self.data_format) - - if self.use_bias: - outputs = backend.bias_add( - outputs, - self.bias, - data_format=self.data_format) - - outputs = array_ops.squeeze(outputs, [spatial_start_dim]) - - if self.activation is not None: - return self.activation(outputs) - - return outputs - - @tf_utils.shape_type_conversion - def compute_output_shape(self, input_shape): - if self.data_format == 'channels_first': - length = input_shape[2] - out_filters = input_shape[1] * self.depth_multiplier - elif self.data_format == 'channels_last': - length = input_shape[1] - out_filters = input_shape[2] * self.depth_multiplier - - length = conv_utils.conv_output_length(length, self.kernel_size, - self.padding, - self.strides) - if self.data_format == 'channels_first': - return (input_shape[0], out_filters, length) - elif self.data_format == 'channels_last': - return (input_shape[0], length, out_filters) - - def get_config(self): - config = super(DepthwiseConv1D, self).get_config() - config.pop('filters') - config.pop('kernel_initializer') - config.pop('kernel_regularizer') - config.pop('kernel_constraint') - config['depth_multiplier'] = self.depth_multiplier - config['depthwise_initializer'] = initializers.serialize( - self.depthwise_initializer) - config['depthwise_regularizer'] = regularizers.serialize( - self.depthwise_regularizer) - config['depthwise_constraint'] = constraints.serialize( - self.depthwise_constraint) + """Depthwise separable 1D convolution. 
+ Depthwise Separable convolutions consist of performing + just the first step in a depthwise spatial convolution + (which acts on each input channel separately). + The `depth_multiplier` argument controls how many + output channels are generated per input channel in the depthwise step. + Arguments: + kernel_size: A single integer specifying the spatial + dimensions of the filters. + strides: A single integer specifying the strides + of the convolution. + Specifying any `stride` value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: one of `'valid'` or `'same'` (case-insensitive). + depth_multiplier: The number of depthwise convolution output channels + for each input channel. + The total number of depthwise convolution output + channels will be equal to `filters_in * depth_multiplier`. + data_format: A string, + one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, length)`. + The default is 'channels_last'. + activation: Activation function to use. + If you don't specify anything, no activation is applied + (ie. 'linear' activation: `a(x) = x`). + use_bias: Boolean, whether the layer uses a bias vector. + depthwise_initializer: Initializer for the depthwise kernel matrix. + bias_initializer: Initializer for the bias vector. + depthwise_regularizer: Regularizer function applied to + the depthwise kernel matrix. + bias_regularizer: Regularizer function applied to the bias vector. + activity_regularizer: Regularizer function applied to + the output of the layer (its 'activation'). + depthwise_constraint: Constraint function applied to + the depthwise kernel matrix. + bias_constraint: Constraint function applied to the bias vector. + Input shape: + 3D tensor with shape: + `[batch, channels, length]` if data_format='channels_first' + or 4D tensor with shape: + `[batch, length, channels]` if data_format='channels_last'. + Output shape: + 3D tensor with shape: + `[batch, filters, new_length]` if data_format='channels_first' + or 3D tensor with shape: + `[batch, new_length, filters]` if data_format='channels_last'. + `length` values might have changed due to padding. + """ + + def __init__(self, + kernel_size, + strides=1, + padding='valid', + depth_multiplier=1, + data_format=None, + activation=None, + use_bias=True, + depthwise_initializer='glorot_uniform', + bias_initializer='zeros', + depthwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + bias_constraint=None, + **kwargs): + super(DepthwiseConv1D, self).__init__( + filters=None, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + activation=activation, + use_bias=use_bias, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + bias_constraint=bias_constraint, + # autocast=False, + **kwargs) + self.depth_multiplier = depth_multiplier + self.depthwise_initializer = initializers.get(depthwise_initializer) + self.depthwise_regularizer = regularizers.get(depthwise_regularizer) + self.depthwise_constraint = constraints.get(depthwise_constraint) + self.bias_initializer = initializers.get(bias_initializer) + + def build(self, input_shape): + if len(input_shape) < 3: + raise ValueError('Inputs to `DepthwiseConv1D` should have rank 3. 
' + 'Received input shape:', str(input_shape)) + input_shape = tensor_shape.TensorShape(input_shape) + + # TODO(pj1989): replace with channel_axis = self._get_channel_axis() + if self.data_format == 'channels_last': + channel_axis = -1 + elif self.data_format == 'channels_first': + channel_axis = 1 + + if input_shape.dims[channel_axis].value is None: + raise ValueError('The channel dimension of the inputs to ' + '`DepthwiseConv1D` ' + 'should be defined. Found `None`.') + input_dim = int(input_shape[channel_axis]) + depthwise_kernel_shape = (self.kernel_size[0], input_dim, self.depth_multiplier) + + self.depthwise_kernel = self.add_weight( + shape=depthwise_kernel_shape, + initializer=self.depthwise_initializer, + name='depthwise_kernel', + regularizer=self.depthwise_regularizer, + constraint=self.depthwise_constraint) + + if self.use_bias: + self.bias = self.add_weight( + shape=(input_dim * self.depth_multiplier,), + initializer=self.bias_initializer, + name='bias', + regularizer=self.bias_regularizer, + constraint=self.bias_constraint) + else: + self.bias = None + # Set input spec. + self.input_spec = InputSpec(ndim=3, axes={channel_axis: input_dim}) + self.built = True + + def call(self, inputs): + if self.padding == 'causal': + inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs)) + if self.data_format == 'channels_last': + spatial_start_dim = 1 + else: + spatial_start_dim = 2 + + # Explicitly broadcast inputs and kernels to 4D. + # TODO(fchollet): refactor when a native depthwise_conv2d op is available. + strides = self.strides * 2 + inputs = array_ops.expand_dims(inputs, spatial_start_dim) + depthwise_kernel = array_ops.expand_dims(self.depthwise_kernel, 0) + dilation_rate = (1,) + self.dilation_rate + + outputs = backend.depthwise_conv2d( + inputs, + depthwise_kernel, + strides=strides, + padding=self.padding if not self.padding == 'causal' else 'valid', + dilation_rate=dilation_rate, + data_format=self.data_format) + + if self.use_bias: + outputs = backend.bias_add( + outputs, + self.bias, + data_format=self.data_format) + + outputs = array_ops.squeeze(outputs, [spatial_start_dim]) + + if self.activation is not None: + return self.activation(outputs) + + return outputs + + @tf_utils.shape_type_conversion + def compute_output_shape(self, input_shape): + if self.data_format == 'channels_first': + length = input_shape[2] + out_filters = input_shape[1] * self.depth_multiplier + elif self.data_format == 'channels_last': + length = input_shape[1] + out_filters = input_shape[2] * self.depth_multiplier + + length = conv_utils.conv_output_length(length, self.kernel_size, + self.padding, + self.strides) + if self.data_format == 'channels_first': + return (input_shape[0], out_filters, length) + elif self.data_format == 'channels_last': + return (input_shape[0], length, out_filters) + + def get_config(self): + config = super(DepthwiseConv1D, self).get_config() + config.pop('filters') + config.pop('kernel_initializer') + config.pop('kernel_regularizer') + config.pop('kernel_constraint') + config['depth_multiplier'] = self.depth_multiplier + config['depthwise_initializer'] = initializers.serialize( + self.depthwise_initializer) + config['depthwise_regularizer'] = regularizers.serialize( + self.depthwise_regularizer) + config['depthwise_constraint'] = constraints.serialize( + self.depthwise_constraint) From 6d4bfac598fbb6439c79c369b98ec6eab2e097ec Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Thu, 29 Apr 2021 11:36:46 -0500 Subject: [PATCH 28/31] Adds _create_mask_tf for 
pure TF mask creation. Adds mask pre-compute when input max_length is defined. --- tensorflow_asr/datasets/asr_dataset.py | 81 +++++++++++++++++++++++--- 1 file changed, 72 insertions(+), 9 deletions(-) diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index 4427465aa0..d086a7d848 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -415,11 +415,17 @@ def __init__(self, self.input_chunk_size = input_chunk_duration * self.speech_featurizer.sample_rate // 1000 self.time_reduction_factor = time_reduction_factor - def calculate_mask(self, num_frames): - frame_step = self.speech_featurizer.frame_step - frames_per_chunk = self.input_chunk_size // frame_step + self.base_mask = tf.constant(0, dtype=tf.int32) + self.use_base_mask = False - num_frames = tf.cast(tf.math.ceil(num_frames / self.time_reduction_factor), tf.int32) + # If max input length is known, pre-compute mask + if self.speech_featurizer.max_length: + num_frames = 1 + (self.speech_featurizer.max_length - self.speech_featurizer.frame_length) // self.speech_featurizer.frame_step + self.base_mask = self.calculate_mask(num_frames) + self.use_base_mask = True + + def _recalculate_mask(self, num_frames): + frames_per_chunk = self.input_chunk_size // self.speech_featurizer.frame_step def _calculate_mask(num_frames, frames_per_chunk, history_window_size): mask = np.zeros((num_frames, num_frames), dtype=np.int32) @@ -436,16 +442,73 @@ def _calculate_mask(num_frames, frames_per_chunk, history_window_size): mask[i, base_index + j] = 1 return mask - return tf.numpy_function( - _calculate_mask, inp=[num_frames, frames_per_chunk, self.history_window_size], Tout=tf.int32 - ) + @tf.function(autograph=True) + def _calculate_mask_tf(num_frames, frames_per_chunk, history_window_size): + chunk_ids = tf.range(num_frames) // frames_per_chunk + num_chunks = tf.cast(tf.math.ceil(num_frames / frames_per_chunk), dtype=tf.int32) + + # Create first `frames_per_chunk` rows + current = tf.ones((frames_per_chunk), dtype=tf.int32) + trailing = tf.ones(((num_chunks - 1) * frames_per_chunk), dtype=tf.int32) + tmp_row = tf.concat((current, trailing), axis=0) + row = tf.slice(tmp_row, [0], [num_frames]) + mask = tf.expand_dims(row, axis=0) + + for i in range(1, frames_per_chunk): + mask = tf.concat((mask, [row]), axis=0) + + # Create the following rows + for i in range(frames_per_chunk, num_frames): + tf.autograph.experimental.set_loop_options( + shape_invariants=[(mask, tf.TensorShape([None, None]))] + ) + curr_chunk_id = chunk_ids[i] + hist_i = tf.math.maximum(i - history_window_size, 0) + hist_chunk_id = chunk_ids[hist_i] + + # Build the left-most part + leading_chunk_id = hist_chunk_id - 1 + num_leading_chunks = tf.math.maximum(leading_chunk_id + 1, 0) + leftmost_row = tf.zeros((num_leading_chunks * frames_per_chunk), dtype=tf.int32) + + # Build the current visible chunks + num_hist_chunks = curr_chunk_id - hist_chunk_id + num_visible_chunks = num_hist_chunks + 1 + curr_chunk_row = tf.ones((num_visible_chunks * frames_per_chunk), dtype=tf.int32) + + # Build the trailing 0s + num_trailing_chunks = tf.math.maximum((num_chunks - curr_chunk_id) - 1, 0) + trailing_chunk_row = tf.zeros((num_trailing_chunks * frames_per_chunk), dtype=tf.int32) + + # Merge chunks, clip to output size + tmp_row = tf.concat([leftmost_row, curr_chunk_row, trailing_chunk_row], axis=0) + row = tf.slice(tmp_row, [0], [num_frames]) + + mask = tf.concat((mask, [row]), axis=0) + return mask + + if self.use_tf: 
+ mask = _calculate_mask_tf(num_frames, frames_per_chunk, self.history_window_size) + else: + mask = tf.numpy_function( + _calculate_mask, inp=[num_frames, frames_per_chunk, self.history_window_size], Tout=tf.int32 + ) + mask.set_shape((None, None)) + + return mask + + def calculate_mask(self, num_frames): + num_frames = tf.cast(tf.math.ceil(num_frames / self.time_reduction_factor), tf.int32) + + if self.use_base_mask: + return tf.slice(self.base_mask, [0, 0], [num_frames, num_frames]) + return self._recalculate_mask(num_frames) def preprocess(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): preprocessed_inputs = super(ASRMaskedSliceDataset, self).preprocess(path, audio, indices) input_length = preprocessed_inputs[2] mask = self.calculate_mask(input_length) - mask.set_shape((None, None)) return (*preprocessed_inputs, mask) @@ -454,7 +517,6 @@ def tf_preprocess(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor): input_length = preprocessed_inputs[2] mask = self.calculate_mask(input_length) - mask.set_shape((None, None)) return (*preprocessed_inputs, mask) @@ -543,6 +605,7 @@ def create(self, batch_size: int): ) return self.process(dataset, batch_size) + class ASRMaskedTFRecordDataset(ASRMaskedSliceDataset): """ Dataset for ASR using TFRecords with rolling mask """ def __init__(self, From b4f3d7203344e5d2d607fe555fcd402b760754ea Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Thu, 29 Apr 2021 11:38:36 -0500 Subject: [PATCH 29/31] Adds use of ASRMaskedTFRecordDataset. --- examples/streaming_conformer/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/streaming_conformer/train.py b/examples/streaming_conformer/train.py index 377d9a5968..94a6297117 100644 --- a/examples/streaming_conformer/train.py +++ b/examples/streaming_conformer/train.py @@ -59,7 +59,7 @@ strategy = env_util.setup_strategy(args.devices) from tensorflow_asr.configs.config import Config -from tensorflow_asr.datasets.asr_dataset import ASRMaskedSliceDataset +from tensorflow_asr.datasets.asr_dataset import ASRMaskedSliceDataset, ASRMaskedTFRecordDataset from tensorflow_asr.featurizers import speech_featurizers, text_featurizers from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer from tensorflow_asr.models.transducer.streaming_conformer import StreamingConformer @@ -80,11 +80,11 @@ time_reduction_factor = config.model_config['encoder_subsampling']['strides'] * 2 if args.tfrecords: - train_dataset = ASRTFRecordDataset( + train_dataset = ASRMaskedTFRecordDataset( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.train_dataset_config) ) - eval_dataset = ASRTFRecordDataset( + eval_dataset = ASRMaskedTFRecordDataset( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.eval_dataset_config) ) @@ -110,7 +110,7 @@ global_batch_size = args.tbs or config.learning_config.running_config.batch_size global_batch_size *= strategy.num_replicas_in_sync -global_eval_batch_size = args.ebs or global_batch_size +global_eval_batch_size = args.ebs or global_batch_size global_eval_batch_size *= strategy.num_replicas_in_sync train_data_loader = train_dataset.create(global_batch_size) From 73959ddcf6c3208194a725924e57463ccc4233b9 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Tue, 18 May 2021 11:59:05 -0500 Subject: [PATCH 30/31] Change request: Use math_util.get_reduced_length. 
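The helper should be behaviour-preserving here; judging from the expression it replaces, math_util.get_reduced_length is presumably equivalent to the following sketch (not copied from math_util):

    import tensorflow as tf

    def get_reduced_length(length, reduction_factor):
        # ceil(length / factor), returned as int32, matching the replaced line
        return tf.cast(tf.math.ceil(tf.divide(length, reduction_factor)), tf.int32)
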
--- tensorflow_asr/datasets/asr_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index a834cb061a..595aa0f129 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -482,7 +482,7 @@ def _calculate_mask_tf(num_frames, frames_per_chunk, history_window_size): return mask def calculate_mask(self, num_frames): - num_frames = tf.cast(tf.math.ceil(num_frames / self.time_reduction_factor), tf.int32) + num_frames = math_util.get_reduced_length(num_frames, self.time_reduction_factor) if self.use_base_mask: return tf.slice(self.base_mask, [0, 0], [num_frames, num_frames]) From 7d743eee61771793aefa40a8bfb85111d97748a5 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Tue, 18 May 2021 14:19:32 -0500 Subject: [PATCH 31/31] Change request: ASRMaskedTFRecordDataset inherits from two classes. --- tensorflow_asr/datasets/asr_dataset.py | 62 +------------------------- 1 file changed, 1 insertion(+), 61 deletions(-) diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py index 595aa0f129..c0edd02e92 100755 --- a/tensorflow_asr/datasets/asr_dataset.py +++ b/tensorflow_asr/datasets/asr_dataset.py @@ -590,7 +590,7 @@ def create(self, batch_size: int): return self.process(dataset, batch_size) -class ASRMaskedTFRecordDataset(ASRMaskedSliceDataset): +class ASRMaskedTFRecordDataset(ASRMaskedSliceDataset, ASRTFRecordDataset): """ Dataset for ASR using TFRecords with rolling mask """ def __init__(self, data_paths: list, @@ -622,53 +622,6 @@ def __init__(self, self.tfrecords_shards = tfrecords_shards if not tf.io.gfile.exists(self.tfrecords_dir): tf.io.gfile.makedirs(self.tfrecords_dir) - @staticmethod - def write_tfrecord_file(splitted_entries): - shard_path, entries = splitted_entries - - def parse(record): - def fn(path, indices): - audio = load_and_convert_to_wav(path.decode("utf-8")).numpy() - feature = { - "path": feature_util.bytestring_feature([path]), - "audio": feature_util.bytestring_feature([audio]), - "indices": feature_util.bytestring_feature([indices]) - } - example = tf.train.Example(features=tf.train.Features(feature=feature)) - return example.SerializeToString() - return tf.numpy_function(fn, inp=[record[0], record[2]], Tout=tf.string) - - dataset = tf.data.Dataset.from_tensor_slices(entries) - dataset = dataset.map(parse, num_parallel_calls=AUTOTUNE) - writer = tf.data.experimental.TFRecordWriter(shard_path, compression_type="ZLIB") - print(f"Processing {shard_path} ...") - writer.write(dataset) - print(f"Created {shard_path}") - - def create_tfrecords(self): - if not tf.io.gfile.exists(self.tfrecords_dir): - tf.io.gfile.makedirs(self.tfrecords_dir) - - if tf.io.gfile.glob(os.path.join(self.tfrecords_dir, f"{self.stage}*.tfrecord")): - print(f"TFRecords're already existed: {self.stage}") - return True - - print(f"Creating {self.stage}.tfrecord ...") - - self.read_entries() - if not self.total_steps or self.total_steps == 0: return False - - def get_shard_path(shard_id): - return os.path.join(self.tfrecords_dir, f"{self.stage}_{shard_id}.tfrecord") - - shards = [get_shard_path(idx) for idx in range(1, self.tfrecords_shards + 1)] - - splitted_entries = np.array_split(self.entries, self.tfrecords_shards) - for entries in zip(shards, splitted_entries): - self.write_tfrecord_file(entries) - - return True - def parse(self, record: tf.Tensor): feature_description = { "path": tf.io.FixedLenFeature([], tf.string), @@ 
-694,16 +647,3 @@ def parse(self, record: tf.Tensor): labels_length=label_length ) ) - - def create(self, batch_size: int): - have_data = self.create_tfrecords() - if not have_data: return None - - pattern = os.path.join(self.tfrecords_dir, f"{self.stage}*.tfrecord") - files_ds = tf.data.Dataset.list_files(pattern) - ignore_order = tf.data.Options() - ignore_order.experimental_deterministic = False - files_ds = files_ds.with_options(ignore_order) - dataset = tf.data.TFRecordDataset(files_ds, compression_type="ZLIB", num_parallel_reads=AUTOTUNE) - - return self.process(dataset, batch_size)
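
For reference, the rolling chunk mask that these masked datasets attach to every example can be pictured with a small NumPy reconstruction of the _calculate_mask_tf loop from PATCH 28 (the special-cased all-ones rows for the first chunk are simplified away here): frame i may attend to every frame from the chunk containing frame i - history_window_size up to the end of its own chunk, and to nothing in later chunks.

    import numpy as np

    def chunk_mask(num_frames, frames_per_chunk, history_window_size):
        mask = np.zeros((num_frames, num_frames), dtype=np.int32)
        for i in range(num_frames):
            curr_chunk = i // frames_per_chunk
            hist_chunk = max(i - history_window_size, 0) // frames_per_chunk
            start = hist_chunk * frames_per_chunk                        # first visible frame
            stop = min((curr_chunk + 1) * frames_per_chunk, num_frames)  # end of own chunk
            mask[i, start:stop] = 1
        return mask

    # Each row i is the attention footprint of frame i.
    print(chunk_mask(num_frames=8, frames_per_chunk=2, history_window_size=3))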