diff --git a/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py b/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py
index b51cbc38..783007b7 100644
--- a/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py
+++ b/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py
@@ -84,6 +84,10 @@ def __init__(
     def vocab_size(self):
         return self._tokenizer.get_vocab_size()
 
+    @property
+    def oov_token_index(self):
+        return self._oov_token_index
+
     @property
     def mask_token_index(self):
         return self._mask_token_index
diff --git a/src/cehrbert/runners/runner_util.py b/src/cehrbert/runners/runner_util.py
index 9ee87c38..d2b7bad8 100644
--- a/src/cehrbert/runners/runner_util.py
+++ b/src/cehrbert/runners/runner_util.py
@@ -19,7 +19,7 @@
 
 
 def load_parquet_as_dataset(
-        data_folder, split="train", streaming=False
+    data_folder, split="train", streaming=False
 ) -> Union[Dataset, IterableDataset]:
     """
     Loads a dataset from Parquet files located within a specified folder into a Hugging Face `datasets.Dataset`.
@@ -101,14 +101,14 @@ def get_last_hf_checkpoint(training_args):
     last_checkpoint = None
     output_dir_abspath = os.path.abspath(training_args.output_dir)
     if (
-            os.path.isdir(output_dir_abspath)
-            and training_args.do_train
-            and not training_args.overwrite_output_dir
+        os.path.isdir(output_dir_abspath)
+        and training_args.do_train
+        and not training_args.overwrite_output_dir
     ):
         last_checkpoint = get_last_checkpoint(output_dir_abspath)
         if (
-                last_checkpoint is None
-                and len([_ for _ in os.listdir(output_dir_abspath) if os.path.isdir(_)]) > 0
+            last_checkpoint is None
+            and len([_ for _ in os.listdir(output_dir_abspath) if os.path.isdir(_)]) > 0
         ):
             raise ValueError(
                 f"Output directory ({output_dir_abspath}) already exists and is not empty. "
@@ -196,28 +196,28 @@ def generate_prepared_ds_path(data_args, model_args, data_folder=None) -> Path:
     """
     data_folder = data_folder if data_folder else data_args.data_folder
     concatenated_str = (
-            str(model_args.max_position_embeddings)
-            + "|"
-            + os.path.abspath(data_folder)
-            + "|"
-            + os.path.abspath(model_args.tokenizer_name_or_path)
-            + "|"
-            + (
-                str(data_args.validation_split_percentage)
-                if data_args.validation_split_percentage
-                else ""
-            )
-            + "|"
-            + f"test_eval_ratio={str(data_args.test_eval_ratio)}"
-            + "|"
-            + f"split_by_patient={str(data_args.split_by_patient)}"
-            + "|"
-            + f"chronological_split={str(data_args.chronological_split)}"
+        str(model_args.max_position_embeddings)
+        + "|"
+        + os.path.abspath(data_folder)
+        + "|"
+        + os.path.abspath(model_args.tokenizer_name_or_path)
+        + "|"
+        + (
+            str(data_args.validation_split_percentage)
+            if data_args.validation_split_percentage
+            else ""
+        )
+        + "|"
+        + f"test_eval_ratio={str(data_args.test_eval_ratio)}"
+        + "|"
+        + f"split_by_patient={str(data_args.split_by_patient)}"
+        + "|"
+        + f"chronological_split={str(data_args.chronological_split)}"
     )
     basename = os.path.basename(data_folder)
     cleaned_basename = re.sub(r"[^a-zA-Z0-9_]", "", basename)
     LOG.info(f"concatenated_str: {concatenated_str}")
-    ds_hash = f"{cleaned_basename}_{str(md5(concatenated_str, usedforsecurity=False))}"
+    ds_hash = f"{cleaned_basename}_{str(md5(concatenated_str))}"
     LOG.info(f"ds_hash: {ds_hash}")
     prepared_ds_path = Path(os.path.abspath(data_args.dataset_prepared_path)) / ds_hash
     return prepared_ds_path
diff --git a/tests/unit_tests/models/hf_models/tokenization_hf_cehrbert_test.py b/tests/unit_tests/models/hf_models/tokenization_hf_cehrbert_test.py
index ddc5ffa3..c8c74345 100644
--- a/tests/unit_tests/models/hf_models/tokenization_hf_cehrbert_test.py
+++ b/tests/unit_tests/models/hf_models/tokenization_hf_cehrbert_test.py
@@ -58,7 +58,7 @@ def test_convert_tokens_to_string(self):
     def test_oov_token(self):
         # Test the encoding of an out-of-vocabulary token
        encoded = self.tokenizer.encode(["nonexistent"])
-        self.assertEqual(encoded, [self.tokenizer._oov_token_index])
+        self.assertEqual(encoded, [self.tokenizer.oov_token_index])
 
     def test_convert_id_to_token_oov(self):
         # Test decoding an out-of-vocabulary token ID
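
The net effect of the tokenizer change is that callers and tests can read the out-of-vocabulary index through a public property instead of reaching into the private _oov_token_index attribute. A minimal usage sketch follows; it assumes the class defined in tokenization_hf_cehrbert.py is named CehrBertTokenizer and exposes a from_pretrained loader alongside the encode method used in the test, neither of which is visible in this diff:

# Sketch only: CehrBertTokenizer and the from_pretrained path are assumptions,
# not taken from this diff.
from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer

tokenizer = CehrBertTokenizer.from_pretrained("/path/to/saved/tokenizer")

# A concept code missing from the vocabulary encodes to the OOV index,
# now readable via the public property rather than the private attribute.
encoded = tokenizer.encode(["nonexistent"])
assert encoded == [tokenizer.oov_token_index]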