diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py index 64880cce..166966c7 100644 --- a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py +++ b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py @@ -256,6 +256,13 @@ def _create_cehrbert_data_from_meds( writer_batch_size=data_args.preprocessing_batch_size, streaming=data_args.streaming, ) + # Remove patients without any records + dataset = dataset.filter( + lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_visits"]], + num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None, + batched=True, + batch_size=data_args.preprocessing_batch_size, + ) # Convert the CehrBertPatient to CehrBert data inputs dataset = apply_cehrbert_dataset_mapping( dataset,