Add Whisper and FLEURS #10

Merged · 53 commits · Aug 29, 2023

Commits
All 53 commits are by saattrupdan.

08fad7c  refactor: Move `load_model_setup` to separate `model_setup` module (Jul 31, 2023)
5859c15  style: Use getattr for Processor.tokenizer (Jul 31, 2023)
ecd9f8e  style: Remove comment (Jul 31, 2023)
cf07771  style: Remove tokenizer attribute from Processor protocol (Jul 31, 2023)
dfc5575  chore: Rename script names to avoid name conflict (Jul 31, 2023)
d22ae95  chore: Do not import `finetune` function in `__init__` (Jul 31, 2023)
f6b2333  feat: Enable `save_safetensors` (Jul 31, 2023)
7deef7d  chore: Update lock file (Jul 31, 2023)
6be03a3  feat: Set up whisper finetuning (Jul 31, 2023)
50cf83b  chore: More whisper configs, split up whisper/wav2vec2 specific options (Jul 31, 2023)
92cccf8  chore: Only use sampling_rate in model config (Jul 31, 2023)
d923798  feat: Add Nota config (Jul 31, 2023)
a54715c  chore: Change config (Jul 31, 2023)
582a54c  feat: Set up more dataset configs, and set seed in config (Jul 31, 2023)
9e032a8  chore: Ignore wandb folder (Jul 31, 2023)
748c787  docs: Update code coverage (Jul 31, 2023)
8152ca4  chore: Update configs (Jul 31, 2023)
d871012  tests: Add `finetune` test (Jul 31, 2023)
a199df8  tests: Fix logging test (Jul 31, 2023)
62a73ef  fix: Wandb import (Jul 31, 2023)
d28000c  fix: In preprocessing, allow both `input_values` and `input_features` (Jul 31, 2023)
8ab3d62  fix: Do not apply processor in wav2vec2 data collator, as it has been… (Jul 31, 2023)
7871e38  docs: Update coverage badge (Jul 31, 2023)
3058598  chore: Config renaming (Jul 31, 2023)
4b47821  tests: Test whisper models too (Jul 31, 2023)
071b0f3  docs: Update coverage badge (Jul 31, 2023)
5252792  chore: Allow Python 3.11 (Jul 31, 2023)
b56a1ab  chore: Add Python 3.11 to CI (Jul 31, 2023)
389b4c9  chore: Update lock file (Jul 31, 2023)
d1e17dd  chore: KenLM installation (Jul 31, 2023)
f611731  chore: Try using `use_auth_token` instead of `token` (Jul 31, 2023)
a414038  fix: Default `token` to True in `train_ngram_model` (Jul 31, 2023)
47c5494  chore: Revert back to `token` (Jul 31, 2023)
d39a421  tests: Rename test dataset name to `test_dataset` (Jul 31, 2023)
d42fefb  fix: Do not keep datasets in memory (Jul 31, 2023)
1b2a67d  chore: Update configs (Aug 1, 2023)
67b77af  chore: Update `make tree` (Aug 1, 2023)
7c8269d  fix: Allow 2-dimensional inputs to compute_metrics, to cater to Whisp… (Aug 1, 2023)
f19d9fb  chore: Disable `accelerate` logging (Aug 1, 2023)
f887d40  fix: Add `mask_time_length` to Wav2Vec2ForCTC (Aug 1, 2023)
93c30cb  fix: Add `mask_time_length` and `apply_spec_augment` to WhisperForCon… (Aug 1, 2023)
44f0ead  tests: Add `mask_time_length` to test model configs (Aug 1, 2023)
8cec6b5  tests: Add `dropout` to whisper test model config (Aug 1, 2023)
51a3ce7  fix: Add `dropout` to whisper models config (Aug 1, 2023)
04f92a8  feat: Disable tqdm during training (Aug 1, 2023)
fef542d  chore: Change config (Aug 7, 2023)
1421262  fix: Handle WANDB (Aug 7, 2023)
7db7bb7  style: Load WhisperProcessor directly from pretrained (Aug 7, 2023)
d8c15dc  tests: Disable fp16 while testing (Aug 7, 2023)
e6093cd  fix: Config name typo (Aug 9, 2023)
efecdf6  Merge branch 'main' of github.com:alexandrainst/CoRal-models into fea… (Aug 24, 2023)
016d346  feat: Add ü to vocab (Aug 29, 2023)
5c02af3  Merge branch 'main' into feat/add-whisper (Aug 29, 2023)
Files changed
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -17,7 +17,7 @@ jobs:
strategy:
  matrix:
    os: [windows-latest, macos-latest, ubuntu-latest]
-    python-version: ["3.10"]
+    python-version: ["3.10", "3.11"]
runs-on: ${{ matrix.os }}
steps:
  - uses: actions/checkout@v3
2 changes: 1 addition & 1 deletion .gitignore
@@ -106,4 +106,4 @@ data/
models/

# Weights and Biases experiment tracking
-wandb/
\ No newline at end of file
+wandb/
2 changes: 1 addition & 1 deletion README.md
@@ -6,7 +6,7 @@ ______________________________________________________________________
[![Documentation](https://img.shields.io/badge/docs-passing-green)](https://alexandrainst.github.io/CoRal-models/coral_models.html)
[![License](https://img.shields.io/github/license/alexandrainst/CoRal-models)](https://github.com/alexandrainst/CoRal-models/blob/main/LICENSE)
[![LastCommit](https://img.shields.io/github/last-commit/alexandrainst/CoRal-models)](https://github.com/alexandrainst/CoRal-models/commits/main)
-[![Code Coverage](https://img.shields.io/badge/Coverage-53%25-orange.svg)](https://github.com/alexandrainst/CoRal-models/tree/main/tests)
+[![Code Coverage](https://img.shields.io/badge/Coverage-61%25-yellow.svg)](https://github.com/alexandrainst/CoRal-models/tree/main/tests)


Developers:
16 changes: 12 additions & 4 deletions config/config.yaml
@@ -11,17 +11,25 @@ dirs:
  final: final
  models: models

seed: 4242

# Model parameters
pipeline_id: ${model.name}-${dataset.name}
hub_id: alexandrainst/${pipeline_id}
model_dir: ${dirs.models}/${pipeline_id}
push_to_hub: false

# Data parameters
characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789é|'

# Training parameters
resume_from_checkpoint: false
ignore_data_skip: false
wandb: false
-wandb_name: default
wandb_project: CoRal
wandb_group: null
+wandb_name: ${pipeline_id}
logging_steps: 10
eval_steps: 100
save_steps: 100
save_total_limit: 2
early_stopping: true
early_stopping_patience: 10
fp16: true
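
The `${...}` references in this file are OmegaConf/Hydra interpolations, so a single `pipeline_id` drives the Hub ID, the model directory and the wandb run name. A minimal sketch of how they resolve, using plain omegaconf with example model/dataset names (the project presumably loads this file through Hydra, which resolves them the same way):

from omegaconf import OmegaConf

yaml_snippet = """
dirs:
  models: models
model:
  name: whisper_small
dataset:
  name: fleurs_da
pipeline_id: ${model.name}-${dataset.name}
hub_id: alexandrainst/${pipeline_id}
model_dir: ${dirs.models}/${pipeline_id}
wandb_name: ${pipeline_id}
"""
cfg = OmegaConf.create(yaml_snippet)

print(cfg.pipeline_id)  # whisper_small-fleurs_da
print(cfg.hub_id)       # alexandrainst/whisper_small-fleurs_da
print(cfg.model_dir)    # models/whisper_small-fleurs_da
print(cfg.wandb_name)   # whisper_small-fleurs_da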
1 change: 0 additions & 1 deletion config/dataset/common_voice_da.yaml
@@ -5,4 +5,3 @@ train_name: train
val_name: validation
test_name: test
text_column: sentence
-sampling_rate: 16_000
7 changes: 7 additions & 0 deletions config/dataset/common_voice_nn.yaml
@@ -0,0 +1,7 @@
name: common_voice_nn
id: mozilla-foundation/common_voice_13_0
subset: nn-NO
train_name: train
val_name: validation
test_name: test
text_column: sentence
7 changes: 7 additions & 0 deletions config/dataset/common_voice_sv.yaml
@@ -0,0 +1,7 @@
name: common_voice_sv
id: mozilla-foundation/common_voice_13_0
subset: sv-SE
train_name: train
val_name: validation
test_name: test
text_column: sentence
7 changes: 7 additions & 0 deletions config/dataset/fleurs_da.yaml
@@ -0,0 +1,7 @@
name: fleurs_da
id: google/fleurs
subset: da_dk
train_name: train
val_name: validation
test_name: test
text_column: raw_transcription
7 changes: 7 additions & 0 deletions config/dataset/fleurs_nb.yaml
@@ -0,0 +1,7 @@
name: fleurs_nb
id: google/fleurs
subset: nb_no
train_name: train
val_name: validation
test_name: test
text_column: raw_transcription
7 changes: 7 additions & 0 deletions config/dataset/fleurs_sv.yaml
@@ -0,0 +1,7 @@
name: fleurs_sv
id: google/fleurs
subset: sv_se
train_name: train
val_name: validation
test_name: test
text_column: raw_transcription
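
Each dataset config maps directly onto a Hugging Face `datasets` call. A sketch of how such a config is presumably consumed, using the `fleurs_da` values above — `load_dataset` and its `keep_in_memory` flag are real `datasets` API, but treating the YAML fields exactly this way is an assumption about the project's loading code:

from datasets import load_dataset

# Field values from config/dataset/fleurs_da.yaml above.
cfg = {
    "id": "google/fleurs",
    "subset": "da_dk",
    "train_name": "train",
    "text_column": "raw_transcription",
}

# keep_in_memory=False matches the "Do not keep datasets in memory" commit.
train = load_dataset(
    cfg["id"], name=cfg["subset"], split=cfg["train_name"], keep_in_memory=False
)
print(train[0][cfg["text_column"]])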
1 change: 0 additions & 1 deletion config/dataset/ftspeech.yaml
@@ -5,4 +5,3 @@ train_name: train
val_name: dev_balanced
test_name: test_balanced
text_column: sentence
-sampling_rate: 16_000
7 changes: 7 additions & 0 deletions config/dataset/nota.yaml
@@ -0,0 +1,7 @@
name: nota
id: arpelarpe/nota
subset: null
train_name: train
val_name: null
test_name: null
text_column: sentence
@@ -1,8 +1,7 @@
-name: test
+name: test_dataset
id: alexandrainst/audio_test_dataset
subset: null
train_name: train
val_name: validation
test_name: test
text_column: sentence
-sampling_rate: 16_000
11 changes: 6 additions & 5 deletions config/model/test.yaml → config/model/test_wav2vec2.yaml
@@ -1,8 +1,12 @@
-name: wav2vec2-300m-ngram
+name: test_wav2vec2
type: wav2vec2
pretrained_model_id: chcaa/xls-r-300m-danish
freeze_feature_encoder: true

+# Data hyperparameters
+clean_dataset: true
+characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789é|'

# Model hyperparameters
sampling_rate: 16_000
activation_dropout: 0.1
@@ -11,6 +15,7 @@ hidden_dropout: 0.1
feat_proj_dropout: 0.1
final_dropout: 0.1
mask_time_prob: 0.075
+mask_time_length: 10
mask_feature_prob: 0.075
mask_feature_length: 10
layerdrop: 0.1
@@ -28,7 +33,3 @@ warmup_steps: 1
early_stopping: true
early_stopping_patience: 5
fp16: false
-eval_steps: 500
-save_steps: 500
-logging_steps: 100
-save_total_limit: 2
28 changes: 28 additions & 0 deletions config/model/test_whisper.yaml
@@ -0,0 +1,28 @@
name: test_whisper
type: whisper
pretrained_model_id: openai/whisper-tiny
freeze_feature_encoder: true

# Data hyperparameters
clean_dataset: false

# Model hyperparameters
sampling_rate: 16_000
dropout: 0.1
activation_dropout: 0.1
attention_dropout: 0.1
mask_time_prob: 0.5
mask_time_length: 10
mask_feature_prob: 0.5
mask_feature_length: 64

# Training hyperparameters
batch_size: 1
gradient_accumulation: 1
max_steps: 3
learning_rate: 4e-5
warmup_steps: 1
early_stopping: true
early_stopping_patience: 5
fp16: false
generation_max_length: 1
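
The masking hyperparameters in the whisper configs correspond to SpecAugment fields on `transformers`' `WhisperConfig` (cf. the "Add `mask_time_length` and `apply_spec_augment` to WhisperForCon…" commit). A sketch of how they would plausibly be applied when loading a checkpoint — the keyword arguments are real `WhisperConfig` attributes, but the exact wiring in this repo is an assumption:

from transformers import WhisperForConditionalGeneration

# Overriding config attributes via from_pretrained kwargs; values taken from
# config/model/test_whisper.yaml above. SpecAugment-style masking only runs
# during training, and only when apply_spec_augment is True.
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny",
    dropout=0.1,
    activation_dropout=0.1,
    attention_dropout=0.1,
    apply_spec_augment=True,
    mask_time_prob=0.5,
    mask_time_length=10,
    mask_feature_prob=0.5,
    mask_feature_length=64,
)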
23 changes: 13 additions & 10 deletions config/model/wav2vec2.yaml
@@ -1,8 +1,12 @@
-name: wav2vec2-300m-ngram
+name: wav2vec2
type: wav2vec2
pretrained_model_id: chcaa/xls-r-300m-danish
freeze_feature_encoder: false

+# Data hyperparameters
+clean_dataset: true
+characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789é|ü'

# Model hyperparameters
sampling_rate: 16_000
activation_dropout: 0.1
@@ -11,24 +15,23 @@ hidden_dropout: 0.1
feat_proj_dropout: 0.1
final_dropout: 0.1
mask_time_prob: 0.5
+mask_time_length: 10
mask_feature_prob: 0.5
mask_feature_length: 64
layerdrop: 0.1
ctc_loss_reduction: sum

# Decoder hyperparameters
-language_model_decoder: null
+language_model_decoder: ngram
+decoder:
+  dataset_id: DDSC/reddit-da-asr-preprocessed
+  dataset_subset: null
+  dataset_split: train
+  n: 5

# Training hyperparameters
batch_size: 2
gradient_accumulation: 16
max_steps: 120_000
learning_rate: 3e-5
-warmup_steps: 1000
-early_stopping: true
-early_stopping_patience: 10
-fp16: true
-eval_steps: 1000
-save_steps: 1000
-logging_steps: 100
-save_total_limit: 2
+warmup_steps: 500
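
`language_model_decoder: ngram` together with the `decoder` block points at a 5-gram KenLM model trained on `DDSC/reddit-da-asr-preprocessed` (cf. the `train_ngram_model` and "KenLM installation" commits). A sketch of how such a decoder is typically attached to a wav2vec2 processor — `pyctcdecode` and `Wav2Vec2ProcessorWithLM` are real APIs, but the processor directory and ARPA path below are hypothetical stand-ins:

from pyctcdecode import build_ctcdecoder
from transformers import Wav2Vec2Processor, Wav2Vec2ProcessorWithLM

# Hypothetical path to a finetuned wav2vec2 model directory with a saved processor.
processor = Wav2Vec2Processor.from_pretrained("models/wav2vec2-fleurs_da")

# pyctcdecode expects the CTC labels in vocabulary-index order.
vocab = processor.tokenizer.get_vocab()
labels = [token for token, _ in sorted(vocab.items(), key=lambda kv: kv[1])]

decoder = build_ctcdecoder(
    labels,
    kenlm_model_path="5gram.arpa",  # hypothetical path to the trained n-gram model
)
processor_with_lm = Wav2Vec2ProcessorWithLM(
    feature_extractor=processor.feature_extractor,
    tokenizer=processor.tokenizer,
    decoder=decoder,
)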
39 changes: 0 additions & 39 deletions config/model/wav2vec2_with_lm.yaml

This file was deleted.

1 change: 0 additions & 1 deletion config/model/whisper.yaml

This file was deleted.

25 changes: 25 additions & 0 deletions config/model/whisper_large.yaml
@@ -0,0 +1,25 @@
name: whisper_large
type: whisper
pretrained_model_id: openai/whisper-large-v2
freeze_feature_encoder: false

# Data hyperparameters
clean_dataset: false

# Model hyperparameters
sampling_rate: 16_000
dropout: 0.1
activation_dropout: 0.1
attention_dropout: 0.1
mask_time_prob: 0.5
mask_time_length: 10
mask_feature_prob: 0.5
mask_feature_length: 64

# Training hyperparameters
batch_size: 1
gradient_accumulation: 32
max_steps: 120_000
learning_rate: 3e-5
warmup_steps: 500
generation_max_length: 225
25 changes: 25 additions & 0 deletions config/model/whisper_medium.yaml
@@ -0,0 +1,25 @@
name: whisper_medium
type: whisper
pretrained_model_id: openai/whisper-medium
freeze_feature_encoder: false

# Data hyperparameters
clean_dataset: false

# Model hyperparameters
sampling_rate: 16_000
dropout: 0.1
activation_dropout: 0.1
attention_dropout: 0.1
mask_time_prob: 0.5
mask_time_length: 10
mask_feature_prob: 0.5
mask_feature_length: 64

# Training hyperparameters
batch_size: 8
gradient_accumulation: 4
max_steps: 120_000
learning_rate: 3e-5
warmup_steps: 500
generation_max_length: 225
25 changes: 25 additions & 0 deletions config/model/whisper_small.yaml
@@ -0,0 +1,25 @@
name: whisper_small
type: whisper
pretrained_model_id: openai/whisper-small
freeze_feature_encoder: false

# Data hyperparameters
clean_dataset: false

# Model hyperparameters
sampling_rate: 16_000
dropout: 0.1
activation_dropout: 0.1
attention_dropout: 0.1
mask_time_prob: 0.5
mask_time_length: 10
mask_feature_prob: 0.5
mask_feature_length: 64

# Training hyperparameters
batch_size: 32
gradient_accumulation: 1
max_steps: 120_000
learning_rate: 3e-5
warmup_steps: 500
generation_max_length: 225
25 changes: 25 additions & 0 deletions config/model/whisper_xsmall.yaml
@@ -0,0 +1,25 @@
name: whisper_xsmall
type: whisper
pretrained_model_id: openai/whisper-base
freeze_feature_encoder: false

# Data hyperparameters
clean_dataset: false

# Model hyperparameters
sampling_rate: 16_000
dropout: 0.1
activation_dropout: 0.1
attention_dropout: 0.1
mask_time_prob: 0.5
mask_time_length: 10
mask_feature_prob: 0.5
mask_feature_length: 64

# Training hyperparameters
batch_size: 32
gradient_accumulation: 1
max_steps: 120_000
learning_rate: 3e-5
warmup_steps: 500
generation_max_length: 225
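
Across the four whisper configs only `batch_size` and `gradient_accumulation` vary, and they are balanced so the effective batch size stays constant: 1 × 32 (large), 8 × 4 (medium) and 32 × 1 (small, xsmall) all give 32 samples per optimiser step, with a shared schedule of `max_steps: 120_000`, `learning_rate: 3e-5` and `warmup_steps: 500`.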