From fa6d0aed881422a5dd51a99c23477d7f2b155634 Mon Sep 17 00:00:00 2001
From: Chao Pang
Date: Wed, 9 Oct 2024 23:50:08 -0400
Subject: [PATCH] used the existing subject_split to split the cohort automatically instead of doing it manually (#63)

---
 .../data_generators/hf_data_generator/meds_utils.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py
index 4f975812..dcdb30e5 100644
--- a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py
+++ b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py
@@ -436,8 +436,12 @@ def _create_cehrbert_data_from_meds(
     assert split in ["held_out", "train", "tuning"]
     batches = []
     if data_args.cohort_folder:
-        cohort = pd.read_parquet(os.path.join(os.path.expanduser(data_args.cohort_folder), split))
-        for cohort_row in cohort.itertuples():
+        # Load the entire cohort
+        cohort = pd.read_parquet(os.path.expanduser(data_args.cohort_folder))
+        patient_split = get_subject_split(os.path.expanduser(data_args.data_folder))
+        subject_ids = patient_split[split]
+        cohort_split = cohort[cohort.subject_id.isin(subject_ids)]
+        for cohort_row in cohort_split.itertuples():
             subject_id = cohort_row.subject_id
             prediction_time = cohort_row.prediction_time
             label = int(cohort_row.boolean_value)