From 12049f1fe3e693db45655d8b8a556971abfd2adc Mon Sep 17 00:00:00 2001
From: Chao Pang
Date: Wed, 9 Oct 2024 12:35:30 -0400
Subject: [PATCH] switched all absolute paths to expanduser paths

---
 .../hf_data_generator/meds_utils.py        |  6 +++---
 .../runners/hf_cehrbert_finetune_runner.py |  6 +++---
 .../runners/hf_cehrbert_pretrain_runner.py |  6 +++---
 src/cehrbert/runners/runner_util.py        | 18 +++++++++---------
 4 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py
index 398e4526..f064fe08 100644
--- a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py
+++ b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py
@@ -433,21 +433,21 @@ def _create_cehrbert_data_from_meds(
     assert split in ["held_out", "train", "tuning"]
     batches = []
     if data_args.cohort_folder:
-        cohort = pd.read_parquet(os.path.join(os.path.abspath(data_args.cohort_folder), split))
+        cohort = pd.read_parquet(os.path.join(os.path.expanduser(data_args.cohort_folder), split))
         for cohort_row in cohort.itertuples():
             subject_id = cohort_row.subject_id
             prediction_time = cohort_row.prediction_time
             label = int(cohort_row.boolean_value)
             batches.append((subject_id, prediction_time, label))
     else:
-        patient_split = get_subject_split(os.path.abspath(data_args.data_folder))
+        patient_split = get_subject_split(os.path.expanduser(data_args.data_folder))
         for subject_id in patient_split[split]:
             batches.append((subject_id, None, None))

     split_batches = np.array_split(np.asarray(batches), data_args.preprocessing_num_workers)
     batch_func = functools.partial(
         _meds_to_cehrbert_generator,
-        path_to_db=os.path.abspath(data_args.data_folder),
+        path_to_db=os.path.expanduser(data_args.data_folder),
         default_visit_id=default_visit_id,
         meds_to_cehrbert_conversion_type=data_args.meds_to_cehrbert_conversion_type,
     )
diff --git a/src/cehrbert/runners/hf_cehrbert_finetune_runner.py b/src/cehrbert/runners/hf_cehrbert_finetune_runner.py
index 7fe1dbd7..ee380295 100644
--- a/src/cehrbert/runners/hf_cehrbert_finetune_runner.py
+++ b/src/cehrbert/runners/hf_cehrbert_finetune_runner.py
@@ -103,8 +103,8 @@ def main():
     # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
     if data_args.is_data_in_med:
         meds_extension_path = get_meds_extension_path(
-            data_folder=os.path.abspath(data_args.cohort_folder),
-            dataset_prepared_path=os.path.abspath(data_args.dataset_prepared_path),
+            data_folder=os.path.expanduser(data_args.cohort_folder),
+            dataset_prepared_path=os.path.expanduser(data_args.dataset_prepared_path),
         )
         try:
             LOG.info(f"Trying to load the MEDS extension from disk at {meds_extension_path}...")
@@ -126,7 +126,7 @@ def main():
         validation_set = dataset["validation"]
         test_set = dataset["test"]
     else:
-        dataset = load_parquet_as_dataset(os.path.abspath(data_args.data_folder))
+        dataset = load_parquet_as_dataset(os.path.expanduser(data_args.data_folder))
         test_set = None
         if data_args.test_data_folder:
             test_set = load_parquet_as_dataset(data_args.test_data_folder)
diff --git a/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py b/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
index c9b8cb20..a0d9b80f 100644
--- a/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
+++ b/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
@@ -169,8 +169,8 @@ def main():
     # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
     if data_args.is_data_in_med:
         meds_extension_path = get_meds_extension_path(
-            data_folder=os.path.abspath(data_args.data_folder),
-            dataset_prepared_path=os.path.abspath(data_args.dataset_prepared_path),
+            data_folder=os.path.expanduser(data_args.data_folder),
+            dataset_prepared_path=os.path.expanduser(data_args.dataset_prepared_path),
         )
         try:
             LOG.info(
@@ -194,7 +194,7 @@ def main():
     else:
         # Load the dataset from the parquet files
         dataset = load_parquet_as_dataset(
-            os.path.abspath(data_args.data_folder), split="train", streaming=data_args.streaming
+            os.path.expanduser(data_args.data_folder), split="train", streaming=data_args.streaming
         )
         # If streaming is enabled, we need to manually split the data into train/val
         if data_args.streaming and data_args.validation_split_num:
diff --git a/src/cehrbert/runners/runner_util.py b/src/cehrbert/runners/runner_util.py
index e6cf0f5e..702404fa 100644
--- a/src/cehrbert/runners/runner_util.py
+++ b/src/cehrbert/runners/runner_util.py
@@ -48,7 +48,7 @@ def load_parquet_as_dataset(data_folder, split="train", streaming=False) -> Unio
     files differ in schema or are meant to represent different splits, separate calls and directory structuring
     are advised.
     """
-    data_abspath = os.path.abspath(data_folder)
+    data_abspath = os.path.expanduser(data_folder)
     data_files = glob.glob(os.path.join(data_abspath, "*.parquet"))
     dataset = load_dataset("parquet", data_files=data_files, split=split, streaming=streaming)
     return dataset
@@ -89,7 +89,7 @@ def get_last_hf_checkpoint(training_args):
     ...     )
     >>> last_checkpoint = get_last_hf_checkpoint(training_args)
     >>> print(last_checkpoint)
-    '/absolute/path/to/results/checkpoint-500'
+    '/path/to/results/checkpoint-500'

     Note:
     If `last_checkpoint` is detected and `resume_from_checkpoint` is None, training will automatically
@@ -174,12 +174,12 @@ def generate_prepared_ds_path(data_args, model_args, data_folder=None) -> Path:
     >>> model_args = ModelArguments(max_position_embeddings=512, tokenizer_name_or_path='bert-base-uncased')
     >>> path = generate_prepared_ds_path(data_args, model_args)
     >>> print(path)
-    PosixPath('/absolute/path/to/prepared/datafoldername_hash')
+    PosixPath('/path/to/prepared/datafoldername_hash')

     Note:
     The hash is generated from a combination of the following:
     - model_args.max_position_embeddings
-    - Absolute paths of `data_folder` and `model_args.tokenizer_name_or_path`
+    - paths of `data_folder` and `model_args.tokenizer_name_or_path`
     - `data_args.validation_split_percentage` (if provided)
     - `data_args.test_eval_ratio`, `data_args.split_by_patient`, and `data_args.chronological_split`

@@ -189,9 +189,9 @@
     concatenated_str = (
         str(model_args.max_position_embeddings)
         + "|"
-        + os.path.abspath(data_folder)
+        + os.path.expanduser(data_folder)
         + "|"
-        + os.path.abspath(model_args.tokenizer_name_or_path)
+        + os.path.expanduser(model_args.tokenizer_name_or_path)
         + "|"
         + (str(data_args.validation_split_percentage) if data_args.validation_split_percentage else "")
         + "|"
@@ -206,7 +206,7 @@
     LOG.info(f"concatenated_str: {concatenated_str}")
     ds_hash = f"{cleaned_basename}_{str(md5(concatenated_str))}"
     LOG.info(f"ds_hash: {ds_hash}")
-    prepared_ds_path = Path(os.path.abspath(data_args.dataset_prepared_path)) / ds_hash
+    prepared_ds_path = Path(os.path.expanduser(data_args.dataset_prepared_path)) / ds_hash
     return prepared_ds_path


@@ -252,9 +252,9 @@ def parse_runner_args() -> Tuple[DataTrainingArguments, ModelArguments, Training
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
-        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.expanduser(sys.argv[1]))
     elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
-        model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[1]))
+        model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.expanduser(sys.argv[1]))
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
     return data_args, model_args, training_args
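
---

Context note (illustration only, not part of the patch): the change hinges on the difference
between os.path.abspath and os.path.expanduser. abspath resolves a path against the current
working directory but treats a leading "~" as a literal directory name, while expanduser
replaces "~" (or "~user") with the home directory and leaves the rest of the path untouched.
The sketch below demonstrates that difference; the resolve_path helper is hypothetical and is
not introduced anywhere in this patch, it only shows how the two calls could be combined if
both tilde expansion and an absolute result were needed.

    import os

    p = "~/cehrbert/data"

    # abspath: joins the CWD onto relative paths, but "~" stays literal,
    # e.g. /current/working/dir/~/cehrbert/data
    print(os.path.abspath(p))

    # expanduser: "~" becomes the home directory, e.g. /home/alice/cehrbert/data,
    # but a path without a leading "~" is returned unchanged (possibly still relative)
    print(os.path.expanduser(p))

    def resolve_path(path: str) -> str:
        """Hypothetical helper: expand '~' first, then absolutize the result."""
        return os.path.abspath(os.path.expanduser(path))

    print(resolve_path("~/cehrbert/data"))  # e.g. /home/alice/cehrbert/data
    print(resolve_path("relative/dir"))     # e.g. /current/working/dir/relative/dir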