Commit

Converted data_folder and dataset_prepared_path to absolute paths; converted tokenizer_name_or_path and model_name_or_path to expanded-user paths.
ChaoPang committed Oct 9, 2024
1 parent dcf9715 commit e24584d
Showing 3 changed files with 20 additions and 19 deletions.
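The two path conversions in this commit behave differently, and the distinction is the point of the change: os.path.abspath resolves a relative path against the current working directory but leaves "~" untouched, while os.path.expanduser resolves "~" to the home directory. A minimal sketch of the difference (the paths shown are made up for illustration):

    import os

    # abspath resolves a relative path against the current working directory,
    # but keeps "~" as a literal character.
    print(os.path.abspath("data/meds"))       # e.g. /home/user/project/data/meds
    print(os.path.abspath("~/tokenizer"))     # e.g. /home/user/project/~/tokenizer

    # expanduser is what resolves "~" to the home directory.
    print(os.path.expanduser("~/tokenizer"))  # e.g. /home/user/tokenizer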
6 changes: 3 additions & 3 deletions src/cehrbert/data_generators/hf_data_generator/meds_utils.py
@@ -433,21 +433,21 @@ def _create_cehrbert_data_from_meds(
     assert split in ["held_out", "train", "tuning"]
     batches = []
     if data_args.cohort_folder:
-        cohort = pd.read_parquet(os.path.join(data_args.cohort_folder, split))
+        cohort = pd.read_parquet(os.path.join(os.path.abspath(data_args.cohort_folder), split))
         for cohort_row in cohort.itertuples():
             subject_id = cohort_row.subject_id
             prediction_time = cohort_row.prediction_time
             label = int(cohort_row.boolean_value)
             batches.append((subject_id, prediction_time, label))
     else:
-        patient_split = get_subject_split(data_args.data_folder)
+        patient_split = get_subject_split(os.path.abspath(data_args.data_folder))
         for subject_id in patient_split[split]:
             batches.append((subject_id, None, None))

     split_batches = np.array_split(np.asarray(batches), data_args.preprocessing_num_workers)
     batch_func = functools.partial(
         _meds_to_cehrbert_generator,
-        path_to_db=data_args.data_folder,
+        path_to_db=os.path.abspath(data_args.data_folder),
         default_visit_id=default_visit_id,
         meds_to_cehrbert_conversion_type=data_args.meds_to_cehrbert_conversion_type,
     )
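Making path_to_db absolute before binding it into batch_func matters here because the batches are fanned out to preprocessing_num_workers worker processes, and a relative path would be resolved against each worker's own working directory. A minimal sketch of the same pattern, with a made-up _load_batch helper standing in for _meds_to_cehrbert_generator:

    import functools
    import os
    from multiprocessing import Pool

    def _load_batch(batch_id, path_to_db):
        # A relative path_to_db would be joined against whatever working
        # directory this worker process happens to run in.
        return os.path.join(path_to_db, f"batch_{batch_id}")

    if __name__ == "__main__":
        data_folder = os.path.abspath("meds_db")  # made-up folder name
        batch_func = functools.partial(_load_batch, path_to_db=data_folder)
        with Pool(processes=2) as pool:
            print(pool.map(batch_func, range(4)))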
7 changes: 4 additions & 3 deletions src/cehrbert/runners/hf_cehrbert_finetune_runner.py
@@ -1,4 +1,5 @@
 import json
+import os
 from datetime import datetime
 from pathlib import Path
 from typing import Any, Dict, List, Union
@@ -102,8 +103,8 @@ def main():
     # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
     if data_args.is_data_in_med:
         meds_extension_path = get_meds_extension_path(
-            data_folder=data_args.cohort_folder,
-            dataset_prepared_path=data_args.dataset_prepared_path,
+            data_folder=os.path.abspath(data_args.cohort_folder),
+            dataset_prepared_path=os.path.abspath(data_args.dataset_prepared_path),
         )
         try:
             LOG.info(f"Trying to load the MEDS extension from disk at {meds_extension_path}...")
@@ -125,7 +126,7 @@ def main():
         validation_set = dataset["validation"]
         test_set = dataset["test"]
     else:
-        dataset = load_parquet_as_dataset(data_args.data_folder)
+        dataset = load_parquet_as_dataset(os.path.abspath(data_args.data_folder))
         test_set = None
         if data_args.test_data_folder:
             test_set = load_parquet_as_dataset(data_args.test_data_folder)
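The same normalization keeps get_meds_extension_path deterministic regardless of where the runner is launched from, so the cached MEDS extension written by an earlier run is found again. A hypothetical sketch of such a path derivation, for illustration only (the real get_meds_extension_path lives in cehrbert's runner utilities and may compute its result differently):

    import os

    def get_meds_extension_path_sketch(data_folder: str, dataset_prepared_path: str) -> str:
        # Hypothetical: derive the cache directory from the cohort folder's name.
        basename = os.path.basename(os.path.normpath(data_folder))
        return os.path.join(dataset_prepared_path, f"{basename}_meds_extension")

    # With absolute inputs, the cache location no longer depends on the
    # directory the runner was launched from.
    print(get_meds_extension_path_sketch(
        os.path.abspath("cohorts/readmission"),  # made-up cohort folder
        os.path.abspath("dataset_prepared"),     # made-up prepared-data folder
    ))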
26 changes: 13 additions & 13 deletions src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
@@ -53,21 +53,22 @@ def load_and_create_tokenizer(
         tokenizer = load_and_create_tokenizer(data_args, model_args, dataset)
     """
     # Try to load the pretrained tokenizer
+    tokenizer_name_or_path = os.path.expanduser(model_args.tokenizer_name_or_path)
     try:
-        tokenizer = CehrBertTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
+        tokenizer = CehrBertTokenizer.from_pretrained(tokenizer_name_or_path)
     except (OSError, RuntimeError, FileNotFoundError, json.JSONDecodeError) as e:
         LOG.warning(
             "Failed to load the tokenizer from %s with the error "
             "\n%s\nTried to create the tokenizer, however the dataset is not provided.",
-            model_args.tokenizer_name_or_path,
+            tokenizer_name_or_path,
             e,
         )
         if dataset is None:
             raise e
         tokenizer = CehrBertTokenizer.train_tokenizer(
             dataset, feature_names=["concept_ids"], concept_name_mapping={}, data_args=data_args
         )
-        tokenizer.save_pretrained(model_args.tokenizer_name_or_path)
+        tokenizer.save_pretrained(tokenizer_name_or_path)

     return tokenizer

@@ -93,7 +94,7 @@ def load_and_create_model(model_args: ModelArguments, tokenizer: CehrBertTokeniz
         model = load_and_create_model(model_args, tokenizer)
     """
     try:
-        model_config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+        model_config = AutoConfig.from_pretrained(os.path.expanduser(model_args.model_name_or_path))
     except (OSError, ValueError, FileNotFoundError, json.JSONDecodeError) as e:
         LOG.warning(e)
         model_config = CehrBertConfig(
@@ -168,8 +169,8 @@ def main():
     # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
     if data_args.is_data_in_med:
         meds_extension_path = get_meds_extension_path(
-            data_folder=data_args.data_folder,
-            dataset_prepared_path=data_args.dataset_prepared_path,
+            data_folder=os.path.abspath(data_args.data_folder),
+            dataset_prepared_path=os.path.abspath(data_args.dataset_prepared_path),
         )
         try:
             LOG.info(
@@ -180,22 +181,21 @@ def main():
             if data_args.streaming:
                 if isinstance(dataset, DatasetDict):
                     dataset = {
-                        k: v.to_iterable_dataset(
-                            num_shards=training_args.dataloader_num_workers
-                        ) for k, v in dataset.items()
+                        k: v.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
+                        for k, v in dataset.items()
                     }
                 else:
-                    dataset = dataset.to_iterable_dataset(
-                        num_shards=training_args.dataloader_num_workers
-                    )
+                    dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
         except FileNotFoundError as e:
             LOG.exception(e)
             dataset = create_dataset_from_meds_reader(data_args, is_pretraining=True)
             if not data_args.streaming:
                 dataset.save_to_disk(meds_extension_path)
     else:
         # Load the dataset from the parquet files
-        dataset = load_parquet_as_dataset(data_args.data_folder, split="train", streaming=data_args.streaming)
+        dataset = load_parquet_as_dataset(
+            os.path.abspath(data_args.data_folder), split="train", streaming=data_args.streaming
+        )
         # If streaming is enabled, we need to manually split the data into train/val
         if data_args.streaming and data_args.validation_split_num:
             dataset = dataset.shuffle(buffer_size=10_000, seed=training_args.seed)
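Expanding the user directory before CehrBertTokenizer.from_pretrained and save_pretrained matters because plain filesystem calls treat "~" literally: without expansion, a tokenizer saved under a "~/..." config value would land in a directory literally named "~" inside the working directory. A small illustration with a made-up path:

    import os

    raw = "~/cehrbert_tokenizer"  # made-up config value

    # os.makedirs(raw) would create a directory literally named "~" in the CWD;
    # expanding first gives the intended location under the home directory.
    expanded = os.path.expanduser(raw)
    print(raw, "->", expanded)  # e.g. ~/cehrbert_tokenizer -> /home/user/cehrbert_tokenizer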
