
Support for Streaming Conformer Transducer #178

Draft: wants to merge 35 commits into base: main
Changes from 28 commits (35 commits total)
1356ceb
Adds parameter to specify model behaviour.
andreselizondo-adestech Apr 15, 2021
47eba7d
Initial untested commit. Training script for Streaming Conformer Tran…
andreselizondo-adestech Apr 15, 2021
8d70d9b
Renames variables.
andreselizondo-adestech Apr 15, 2021
ac216b2
Adds ASRMaskedSliceDataset, for generating rolling masks for Streamin…
andreselizondo-adestech Apr 16, 2021
f4988fb
Changes DepthwiseConv2D for SeparableConv1D.
andreselizondo-adestech Apr 16, 2021
a6bc25f
Adds StreamingConformer class. Cleanup pending.
andreselizondo-adestech Apr 16, 2021
a4a2b18
Adds MaskedTransducerTrainer for training Streaming Conformer.
andreselizondo-adestech Apr 16, 2021
f8491e6
Configures for training Streaming Conformer.
andreselizondo-adestech Apr 16, 2021
cabbd43
Merge pull request #1 from TensorSpeech/main
andreselizondo-adestech Apr 19, 2021
64a354f
Adds ASRMaskedSliceDataset to refactored repo.
andreselizondo-adestech Apr 19, 2021
52ad653
Adds streaming and changes Conv2D for Conv1D to refactored repo.
andreselizondo-adestech Apr 19, 2021
9989dc4
Adapts ASRMaskedSliceDataset to refactored repo.
andreselizondo-adestech Apr 19, 2021
c2e3ec6
Bugfix uses shape_list from shape_util.
andreselizondo-adestech Apr 19, 2021
e3da465
Adds mask compatibilty for create_inputs.
andreselizondo-adestech Apr 19, 2021
0c53bc7
Adapts StreamingConformer to refactored repo.
andreselizondo-adestech Apr 19, 2021
27ef8dd
Adapts StreamingConformer training script to refactored repo.
andreselizondo-adestech Apr 19, 2021
5a9aeb0
Adds eval_batch_size and default value.
andreselizondo-adestech Apr 19, 2021
50def47
Adds DepthwiseConv1D layer from github.
andreselizondo-adestech Apr 19, 2021
533e8ea
Sets Conformer model to use new DepthwiseConv1D layer.
andreselizondo-adestech Apr 19, 2021
d784a1f
Loads time_reduction_factor dynamically into
andreselizondo-adestech Apr 19, 2021
3edfd47
Bugfix. Loads time_reduction_factor dynamically.
andreselizondo-adestech Apr 19, 2021
3ded7e1
Removes problem causing imports.
andreselizondo-adestech Apr 20, 2021
00032a7
Removes unnecessary argument.
andreselizondo-adestech Apr 26, 2021
a0223ec
Removes unused lines from DepthwiseConv2D.
andreselizondo-adestech Apr 26, 2021
cfbc29e
Renames DepthwiseConv1D definition script.
andreselizondo-adestech Apr 26, 2021
aadcf91
Bufgix, typo
andreselizondo-adestech Apr 27, 2021
d593105
Renames model _build() to make().
andreselizondo-adestech Apr 27, 2021
9d66c2d
Adds ASRMaskedTFRecordDataset. Fixes ASRTFRecordDataset.
andreselizondo-adestech Apr 27, 2021
34525a1
Fixes pep8 formatting.
andreselizondo-adestech Apr 28, 2021
6d4bfac
Adds _create_mask_tf for pure TF mask creation.
andreselizondo-adestech Apr 29, 2021
b4f3d72
Adds use of ASRMaskedTFRecordDataset.
andreselizondo-adestech Apr 29, 2021
6a6d5a3
Merge pull request #2 from TensorSpeech/main
andreselizondo-adestech Apr 29, 2021
3254450
Merge branch 'main' into tmp_merge
andreselizondo-adestech Apr 29, 2021
73959dd
Change request: Use math_util.get_reduced_length.
andreselizondo-adestech May 18, 2021
7d743ee
Change request: ASRMaskedTFRecordDataset inherits from two classes.
andreselizondo-adestech May 18, 2021
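Several of the commits above replace DepthwiseConv2D with a DepthwiseConv1D layer. As a rough illustration of what a depthwise 1-D convolution computes (each channel filtered by its own kernel, with no cross-channel mixing), here is a minimal NumPy sketch; this is hypothetical and not the PR's actual layer:

```python
import numpy as np

def depthwise_conv1d(x, kernels):
    """Depthwise 1-D convolution, 'valid' padding, stride 1.

    x:       (time, channels) input
    kernels: (kernel_size, channels), one independent filter per channel
    """
    t, c = x.shape
    k, kc = kernels.shape
    assert c == kc, "expect exactly one kernel per input channel"
    out = np.empty((t - k + 1, c))
    for i in range(t - k + 1):
        # Each channel is convolved only with its own kernel column.
        out[i] = np.sum(x[i:i + k] * kernels, axis=0)
    return out

# Two channels, a 2-tap moving-sum filter per channel:
x = np.array([[0., 10.], [1., 20.], [2., 30.], [3., 40.]])
k = np.ones((2, 2))
depthwise_conv1d(x, k)  # [[1., 30.], [3., 50.], [5., 70.]]
```

Unlike a regular Conv1D, the number of multiply-adds here scales with `channels * kernel_size` rather than `channels**2 * kernel_size`, which is why depthwise convolutions are the standard choice inside Conformer convolution modules.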
156 changes: 156 additions & 0 deletions examples/streaming_conformer/train.py
@@ -0,0 +1,156 @@
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
import argparse
from tensorflow_asr.utils import env_util

env_util.setup_environment()
import tensorflow as tf
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:  # guard: avoids an IndexError on CPU-only hosts
    tf.config.experimental.set_memory_growth(gpu, True)

DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

tf.keras.backend.clear_session()

parser = argparse.ArgumentParser(prog="Streaming Conformer Training")

parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")

parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep")

parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords")

parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")

parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")

parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica")

parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica")

parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance")

parser.add_argument("--metadata", type=str, default=None, help="Path to file containing metadata")

parser.add_argument("--static_length", default=False, action="store_true", help="Use static lengths")

parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")

parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")

args = parser.parse_args()

tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})

strategy = env_util.setup_strategy(args.devices)

from tensorflow_asr.configs.config import Config
from tensorflow_asr.datasets.asr_dataset import ASRMaskedSliceDataset, ASRTFRecordDataset
from tensorflow_asr.featurizers import speech_featurizers, text_featurizers
from tensorflow_asr.models.transducer.streaming_conformer import StreamingConformer
from tensorflow_asr.optimizers.schedules import TransformerSchedule

config = Config(args.config)
speech_featurizer = speech_featurizers.TFSpeechFeaturizer(config.speech_config)

if args.sentence_piece:
print("Loading SentencePiece model ...")
text_featurizer = text_featurizers.SentencePieceFeaturizer(config.decoder_config)
elif args.subwords:
print("Loading subwords ...")
text_featurizer = text_featurizers.SubwordFeaturizer(config.decoder_config)
else:
print("Use characters ...")
text_featurizer = text_featurizers.CharFeaturizer(config.decoder_config)

time_reduction_factor = config.model_config['encoder_subsampling']['strides'] * 2
if args.tfrecords:
train_dataset = ASRTFRecordDataset(
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
**vars(config.learning_config.train_dataset_config)
)
eval_dataset = ASRTFRecordDataset(
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
**vars(config.learning_config.eval_dataset_config)
)
else:
train_dataset = ASRMaskedSliceDataset(
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
time_reduction_factor=time_reduction_factor,
**vars(config.learning_config.train_dataset_config)
)
eval_dataset = ASRMaskedSliceDataset(
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
time_reduction_factor=time_reduction_factor,
**vars(config.learning_config.eval_dataset_config)
)

train_dataset.load_metadata(args.metadata)
eval_dataset.load_metadata(args.metadata)

if not args.static_length:
speech_featurizer.reset_length()
text_featurizer.reset_length()

global_batch_size = args.tbs or config.learning_config.running_config.batch_size
global_batch_size *= strategy.num_replicas_in_sync

# Start from the per-replica size so the replica count is applied exactly once.
global_eval_batch_size = args.ebs or config.learning_config.running_config.batch_size
global_eval_batch_size *= strategy.num_replicas_in_sync

train_data_loader = train_dataset.create(global_batch_size)
eval_data_loader = eval_dataset.create(global_eval_batch_size)

with strategy.scope():
# build model
streaming_conformer = StreamingConformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
streaming_conformer.make(speech_featurizer.shape)
streaming_conformer.summary(line_length=150)

optimizer = tf.keras.optimizers.Adam(
TransformerSchedule(
d_model=streaming_conformer.dmodel,
warmup_steps=config.learning_config.optimizer_config.pop("warmup_steps", 10000),
max_lr=(0.05 / math.sqrt(streaming_conformer.dmodel))
),
**config.learning_config.optimizer_config
)

streaming_conformer.compile(
optimizer=optimizer,
experimental_steps_per_execution=args.spx,
global_batch_size=global_batch_size,
blank=text_featurizer.blank
)

callbacks = [
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
]

streaming_conformer.fit(
train_data_loader,
batch_size=global_batch_size,
epochs=config.learning_config.running_config.num_epochs,
steps_per_epoch=train_dataset.total_steps,
validation_data=eval_data_loader,
validation_batch_size=global_eval_batch_size,
validation_steps=eval_dataset.total_steps,
callbacks=callbacks,
)
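The Adam optimizer in the script wraps a `TransformerSchedule`. Assuming it follows the standard transformer (Noam) learning-rate curve with the `max_lr` cap the script passes in, a minimal standalone sketch looks like this; `d_model=144` is only an illustrative Conformer-small value, not taken from this diff:

```python
import math

def transformer_lr(step, d_model=144, warmup_steps=10000, max_lr=None):
    # Linear warmup for `warmup_steps` steps, then inverse-sqrt decay.
    lr = (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)
    if max_lr is not None:
        lr = min(lr, max_lr)  # cap, mirroring max_lr = 0.05 / sqrt(d_model)
    return lr

# The learning rate rises until the end of warmup, then decays:
assert transformer_lr(1000) < transformer_lr(10000)
assert transformer_lr(50000) < transformer_lr(10000)
```

With these defaults the peak rate at step 10000 is `144**-0.5 * 10000**-0.5`, well below the `0.05 / sqrt(144)` cap, so the cap only bites for much shorter warmups.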