From 8c40a27d7f541e332df9e01cc662d49188f884a0 Mon Sep 17 00:00:00 2001
From: Chao Pang
Date: Fri, 11 Oct 2024 06:47:46 -0400
Subject: [PATCH] Added a filter to remove short sequences from the
 training/validation datasets during pre-training

---
 .../runners/hf_cehrbert_pretrain_runner.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py b/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
index 08068b4d..bb83ca3e 100644
--- a/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
+++ b/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
@@ -226,6 +226,23 @@ def main():
     if not data_args.streaming:
         processed_dataset.save_to_disk(prepared_ds_path)
 
+    def filter_func(examples):
+        return [num >= data_args.min_num_tokens for num in examples["num_of_concepts"]]
+
+    # Create the arguments for batched filtering
+    filter_args = {"batched": True, "batch_size": data_args.preprocessing_batch_size}
+    # If the dataset is not in streaming mode, we can add num_proc to enable parallelization
+    if not data_args.streaming:
+        filter_args["num_proc"] = data_args.preprocessing_num_workers
+
+    # The filter can't be applied to a DatasetDict of IterableDataset (in the streaming case),
+    # so we need to iterate through the datasets and apply the filter to each one separately
+    if isinstance(processed_dataset, DatasetDict) or isinstance(processed_dataset, IterableDatasetDict):
+        for key in processed_dataset.keys():
+            processed_dataset[key] = processed_dataset[key].filter(filter_func, **filter_args)
+    else:
+        processed_dataset = processed_dataset.filter(filter_func, **filter_args)
+
     model = load_and_create_model(model_args, tokenizer)
 
     collator = CehrBertDataCollator(
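
Below is a minimal, self-contained sketch (not part of the patch) of how the batched filter introduced above behaves with the Hugging Face `datasets` API. The column name "num_of_concepts" and the `min_num_tokens` threshold are taken from the diff; the toy dataset, the hard-coded threshold value, and the extra "person_id" column are illustrative assumptions only.

    # Sketch only: demonstrates the batched filter pattern used in the patch on a toy dataset.
    from datasets import Dataset

    min_num_tokens = 5  # stands in for data_args.min_num_tokens (assumed value)

    toy = Dataset.from_dict(
        {
            "person_id": [1, 2, 3],          # hypothetical column for illustration
            "num_of_concepts": [2, 5, 12],   # column name taken from the diff
        }
    )

    def filter_func(examples):
        # With batched=True, `examples` is a dict mapping column name -> list of values;
        # the function must return one boolean per row in the batch.
        return [num >= min_num_tokens for num in examples["num_of_concepts"]]

    filtered = toy.filter(filter_func, batched=True, batch_size=2)
    print(filtered["num_of_concepts"])  # expected: [5, 12]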