Skip to content

Commit

Permalink
added filter to remove the short sequences from the training/validation datasets in pre-training
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaoPang committed Oct 11, 2024
1 parent 4495f67 commit 8c40a27
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,23 @@ def main():
# Persist the fully processed dataset for reuse on later runs.
# NOTE(review): skipped in streaming mode — presumably because an
# IterableDataset cannot be saved with save_to_disk; confirm.
if not data_args.streaming:
    processed_dataset.save_to_disk(prepared_ds_path)

def filter_func(examples):
    """Batched predicate: keep examples whose concept count meets the minimum.

    Returns a list of booleans, one per example in the batch, that is True
    when the example's ``num_of_concepts`` is at least
    ``data_args.min_num_tokens``.
    """
    return [
        num_concepts >= data_args.min_num_tokens
        for num_concepts in examples["num_of_concepts"]
    ]

# Build the keyword arguments for batched filtering.
filter_args = {"batched": True, "batch_size": data_args.preprocessing_batch_size}
# Only map-style (non-streaming) datasets support num_proc parallelization;
# streaming IterableDataset.filter does not take it.
if not data_args.streaming:
    filter_args["num_proc"] = data_args.preprocessing_num_workers

# A (Iterable)DatasetDict does not expose a single filter over all splits
# in streaming mode, so apply the filter to each split individually.
if isinstance(processed_dataset, (DatasetDict, IterableDatasetDict)):
    for split in processed_dataset:
        processed_dataset[split] = processed_dataset[split].filter(filter_func, **filter_args)
else:
    processed_dataset = processed_dataset.filter(filter_func, **filter_args)

model = load_and_create_model(model_args, tokenizer)

collator = CehrBertDataCollator(
Expand Down

0 comments on commit 8c40a27

Please sign in to comment.