switched all absolute paths to expanduser paths
ChaoPang committed Oct 9, 2024
1 parent e24584d commit 12049f1
Showing 4 changed files with 18 additions and 18 deletions.
6 changes: 3 additions & 3 deletions src/cehrbert/data_generators/hf_data_generator/meds_utils.py
@@ -433,21 +433,21 @@ def _create_cehrbert_data_from_meds(
     assert split in ["held_out", "train", "tuning"]
     batches = []
     if data_args.cohort_folder:
-        cohort = pd.read_parquet(os.path.join(os.path.abspath(data_args.cohort_folder), split))
+        cohort = pd.read_parquet(os.path.join(os.path.expanduser(data_args.cohort_folder), split))
         for cohort_row in cohort.itertuples():
             subject_id = cohort_row.subject_id
             prediction_time = cohort_row.prediction_time
             label = int(cohort_row.boolean_value)
             batches.append((subject_id, prediction_time, label))
     else:
-        patient_split = get_subject_split(os.path.abspath(data_args.data_folder))
+        patient_split = get_subject_split(os.path.expanduser(data_args.data_folder))
         for subject_id in patient_split[split]:
             batches.append((subject_id, None, None))

     split_batches = np.array_split(np.asarray(batches), data_args.preprocessing_num_workers)
     batch_func = functools.partial(
         _meds_to_cehrbert_generator,
-        path_to_db=os.path.abspath(data_args.data_folder),
+        path_to_db=os.path.expanduser(data_args.data_folder),
         default_visit_id=default_visit_id,
         meds_to_cehrbert_conversion_type=data_args.meds_to_cehrbert_conversion_type,
     )
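
For context, a minimal standalone sketch (not part of this commit) of the behavioral difference being fixed: os.path.abspath leaves a leading "~" as a literal character, while os.path.expanduser resolves it to the user's home directory. The folder name below is purely illustrative.

import os

user_path = "~/cehrbert_data"
print(os.path.abspath(user_path))       # e.g. /current/working/dir/~/cehrbert_data -- "~" is kept literally
print(os.path.expanduser(user_path))    # e.g. /home/<user>/cehrbert_data -- "~" expanded to the home directory
print(os.path.expanduser("/tmp/data"))  # paths without "~" are returned unchanged
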
6 changes: 3 additions & 3 deletions src/cehrbert/runners/hf_cehrbert_finetune_runner.py
@@ -103,8 +103,8 @@ def main():
     # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
     if data_args.is_data_in_med:
         meds_extension_path = get_meds_extension_path(
-            data_folder=os.path.abspath(data_args.cohort_folder),
-            dataset_prepared_path=os.path.abspath(data_args.dataset_prepared_path),
+            data_folder=os.path.expanduser(data_args.cohort_folder),
+            dataset_prepared_path=os.path.expanduser(data_args.dataset_prepared_path),
         )
         try:
             LOG.info(f"Trying to load the MEDS extension from disk at {meds_extension_path}...")
@@ -126,7 +126,7 @@ def main():
         validation_set = dataset["validation"]
         test_set = dataset["test"]
     else:
-        dataset = load_parquet_as_dataset(os.path.abspath(data_args.data_folder))
+        dataset = load_parquet_as_dataset(os.path.expanduser(data_args.data_folder))
         test_set = None
         if data_args.test_data_folder:
             test_set = load_parquet_as_dataset(data_args.test_data_folder)
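
To illustrate the failure mode this fixes in the data-loading path, here is a hedged sketch mirroring the glob pattern used by load_parquet_as_dataset (see the runner_util.py diff below); the folder name is a hypothetical example, not taken from the repository.

import glob
import os

data_folder = "~/cehrbert_data/train"  # hypothetical "~"-prefixed folder

# Before this commit: abspath keeps the literal "~", so the glob scans a
# non-existent "<cwd>/~/cehrbert_data/train" directory and returns [].
print(glob.glob(os.path.join(os.path.abspath(data_folder), "*.parquet")))

# After this commit: expanduser resolves "~" first, so the glob scans the
# intended folder under the user's home directory.
print(glob.glob(os.path.join(os.path.expanduser(data_folder), "*.parquet")))
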
6 changes: 3 additions & 3 deletions src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
@@ -169,8 +169,8 @@ def main():
     # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
     if data_args.is_data_in_med:
         meds_extension_path = get_meds_extension_path(
-            data_folder=os.path.abspath(data_args.data_folder),
-            dataset_prepared_path=os.path.abspath(data_args.dataset_prepared_path),
+            data_folder=os.path.expanduser(data_args.data_folder),
+            dataset_prepared_path=os.path.expanduser(data_args.dataset_prepared_path),
         )
         try:
             LOG.info(
@@ -194,7 +194,7 @@ def main():
     else:
         # Load the dataset from the parquet files
         dataset = load_parquet_as_dataset(
-            os.path.abspath(data_args.data_folder), split="train", streaming=data_args.streaming
+            os.path.expanduser(data_args.data_folder), split="train", streaming=data_args.streaming
         )
         # If streaming is enabled, we need to manually split the data into train/val
         if data_args.streaming and data_args.validation_split_num:
18 changes: 9 additions & 9 deletions src/cehrbert/runners/runner_util.py
@@ -48,7 +48,7 @@ def load_parquet_as_dataset(data_folder, split="train", streaming=False) -> Unio
     files differ in schema or are meant to represent different splits, separate calls and directory structuring
     are advised.
     """
-    data_abspath = os.path.abspath(data_folder)
+    data_abspath = os.path.expanduser(data_folder)
     data_files = glob.glob(os.path.join(data_abspath, "*.parquet"))
     dataset = load_dataset("parquet", data_files=data_files, split=split, streaming=streaming)
     return dataset
@@ -89,7 +89,7 @@ def get_last_hf_checkpoint(training_args):
         ... )
         >>> last_checkpoint = get_last_hf_checkpoint(training_args)
         >>> print(last_checkpoint)
-        '/absolute/path/to/results/checkpoint-500'
+        '/path/to/results/checkpoint-500'

     Note:
         If `last_checkpoint` is detected and `resume_from_checkpoint` is None, training will automatically
@@ -174,12 +174,12 @@ def generate_prepared_ds_path(data_args, model_args, data_folder=None) -> Path:
         >>> model_args = ModelArguments(max_position_embeddings=512, tokenizer_name_or_path='bert-base-uncased')
         >>> path = generate_prepared_ds_path(data_args, model_args)
         >>> print(path)
-        PosixPath('/absolute/path/to/prepared/datafoldername_hash')
+        PosixPath('/path/to/prepared/datafoldername_hash')

     Note:
         The hash is generated from a combination of the following:
         - model_args.max_position_embeddings
-        - Absolute paths of `data_folder` and `model_args.tokenizer_name_or_path`
+        - paths of `data_folder` and `model_args.tokenizer_name_or_path`
         - `data_args.validation_split_percentage` (if provided)
         - `data_args.test_eval_ratio`, `data_args.split_by_patient`, and `data_args.chronological_split`
@@ -189,9 +189,9 @@ def generate_prepared_ds_path(data_args, model_args, data_folder=None) -> Path:
     concatenated_str = (
         str(model_args.max_position_embeddings)
         + "|"
-        + os.path.abspath(data_folder)
+        + os.path.expanduser(data_folder)
         + "|"
-        + os.path.abspath(model_args.tokenizer_name_or_path)
+        + os.path.expanduser(model_args.tokenizer_name_or_path)
         + "|"
         + (str(data_args.validation_split_percentage) if data_args.validation_split_percentage else "")
         + "|"
@@ -206,7 +206,7 @@ def generate_prepared_ds_path(data_args, model_args, data_folder=None) -> Path:
     LOG.info(f"concatenated_str: {concatenated_str}")
     ds_hash = f"{cleaned_basename}_{str(md5(concatenated_str))}"
     LOG.info(f"ds_hash: {ds_hash}")
-    prepared_ds_path = Path(os.path.abspath(data_args.dataset_prepared_path)) / ds_hash
+    prepared_ds_path = Path(os.path.expanduser(data_args.dataset_prepared_path)) / ds_hash
     return prepared_ds_path


@@ -252,9 +252,9 @@ def parse_runner_args() -> Tuple[DataTrainingArguments, ModelArguments, Training
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
-        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.expanduser(sys.argv[1]))
     elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
-        model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[1]))
+        model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.expanduser(sys.argv[1]))
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
     return data_args, model_args, training_args
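
As a side note, expanding "~" before hashing also keeps the prepared-dataset cache directory stable regardless of how the folder is spelled. A simplified standalone sketch (using hashlib directly rather than the repository's own md5 helper, and with a made-up folder name):

import hashlib
import os

def toy_ds_hash(data_folder: str, max_position_embeddings: int = 512) -> str:
    # Hypothetical simplification of the concatenation hashed in generate_prepared_ds_path.
    concatenated_str = str(max_position_embeddings) + "|" + os.path.expanduser(data_folder)
    return hashlib.md5(concatenated_str.encode("utf-8")).hexdigest()

# "~/cehrbert_data" and its already-expanded home-directory form hash to the same
# value, so both spellings map to the same prepared dataset directory.
print(toy_ds_hash("~/cehrbert_data"))
print(toy_ds_hash(os.path.expanduser("~/cehrbert_data")))
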
