Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework PIT example train.py and data.py #125

Open
wants to merge 4 commits into
base: master
Choose a base branch
from

Conversation

sibange
Copy link

@sibange sibange commented Nov 8, 2021

No description provided.

audio_keys = ['observation', 'speech_source']
def prepare_dataset(db, dataset_name: str, batch_size, prefetch=True, shuffle=True):
"""
Prepares the dataset for the training process (loading audio data, SFTF)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: SFTF -> STFT

shuffle: should the data be shuffeled

Returns:
desired dataset of the database in prepared for the training
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something is wrong with the grammar in this sentence

_config: Configuration dict of the experiment
_run: Run object of the current run of the experiment

Returns:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be left out when there is not return value

None
"""
init(_config, _run)
(trainer, train_dataset, validate_dataset) = prepare(_config)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parentheses on the left-hand-side are redundant

# Test run to detects possible errors in the trainer/datasets
trainer.test_run(train_dataset, validate_dataset)

# path where the checkpoints of the training are stored
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is lower case, others are upper case. Stick to one (I prefer upper case)

if shuffle:
dataset = dataset.shuffle(reshuffle=True)

#Splitting the dataset in batches and sorting the frames in the batch
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better write "... and sorts examples in a batch w.r.t. their duration" or something similar.
The frames themselves are not sorted

def pre_batch_transform(inputs, return_keys=None):
def pre_batch_transform(inputs):
"""
Prepares the data through creating a dictionary with various data, which is computed through STFT.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"... by creating a dictionary with all data that is necessary for the model (e.g. STFT of observation)"

""" Prepares the train and validation dataset from the database object """
def prepare(_config):
"""
Preparation of the train and validation datasets for the training and initialisation of the padertorch trainer,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We try to stick to American English. intitialisation -> initialization


sacred.commands.print_config(_run)
# Initialisation of the trainer
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

intitialisation -> initialization

checkpoint_path = trainer.checkpoint_dir / 'ckpt_latest.pth'

# Start of the training
trainer.register_validation_hook(validate_dataset)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you repeat the most important default arguments of the validation hook, so that it becomes clear, what options can be easily modified for the validation (number of checkpoints, metric for the best checkpoint, ...)

@thequilo
Copy link
Member

Now that the maybe_add function disappeared from the data preparation, can you add an example_to_device method to the model that only transfers to the GPU those keys from the example that are required for training? You can use pt.data.batch.example_to_device for that

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants