diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py index 4f975812..dcdb30e5 100644 --- a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py +++ b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py @@ -436,8 +436,12 @@ def _create_cehrbert_data_from_meds( assert split in ["held_out", "train", "tuning"] batches = [] if data_args.cohort_folder: - cohort = pd.read_parquet(os.path.join(os.path.expanduser(data_args.cohort_folder), split)) - for cohort_row in cohort.itertuples(): + # Load the entire cohort + cohort = pd.read_parquet(os.path.expanduser(data_args.cohort_folder)) + patient_split = get_subject_split(os.path.expanduser(data_args.data_folder)) + subject_ids = patient_split[split] + cohort_split = cohort[cohort.subject_id.isin(subject_ids)] + for cohort_row in cohort_split.itertuples(): subject_id = cohort_row.subject_id prediction_time = cohort_row.prediction_time label = int(cohort_row.boolean_value)