created a property oov_token_index for CehrBertTokenizer
ChaoPang committed Sep 6, 2024
1 parent 5ad6d58 commit 3622db4
Showing 3 changed files with 29 additions and 25 deletions.
4 changes: 4 additions & 0 deletions src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py
@@ -84,6 +84,10 @@ def __init__(
     def vocab_size(self):
         return self._tokenizer.get_vocab_size()

+    @property
+    def oov_token_index(self):
+        return self._oov_token_index
+
     @property
     def mask_token_index(self):
         return self._mask_token_index
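Note (not part of the commit): a minimal usage sketch of the new public property. It assumes a trained tokenizer can be loaded via from_pretrained and that "nonexistent_concept" is absent from the vocabulary; both the loading path and the concept string are hypothetical, while encode() and oov_token_index come from this diff and its test.

from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer

# Hypothetical local path; loading with from_pretrained is assumed, not shown in this diff.
tokenizer = CehrBertTokenizer.from_pretrained("/path/to/saved/tokenizer")

# Unknown concepts encode to the out-of-vocabulary index, now exposed as a public property.
encoded = tokenizer.encode(["nonexistent_concept"])
assert encoded == [tokenizer.oov_token_index]  # previously required the private _oov_token_index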
48 changes: 24 additions & 24 deletions src/cehrbert/runners/runner_util.py
@@ -19,7 +19,7 @@


 def load_parquet_as_dataset(
-    data_folder, split="train", streaming=False
+    data_folder, split="train", streaming=False
 ) -> Union[Dataset, IterableDataset]:
     """
     Loads a dataset from Parquet files located within a specified folder into a Hugging Face `datasets.Dataset`.
@@ -101,14 +101,14 @@ def get_last_hf_checkpoint(training_args):
     last_checkpoint = None
     output_dir_abspath = os.path.abspath(training_args.output_dir)
     if (
-        os.path.isdir(output_dir_abspath)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
+        os.path.isdir(output_dir_abspath)
+        and training_args.do_train
+        and not training_args.overwrite_output_dir
     ):
         last_checkpoint = get_last_checkpoint(output_dir_abspath)
         if (
-            last_checkpoint is None
-            and len([_ for _ in os.listdir(output_dir_abspath) if os.path.isdir(_)]) > 0
+            last_checkpoint is None
+            and len([_ for _ in os.listdir(output_dir_abspath) if os.path.isdir(_)]) > 0
         ):
             raise ValueError(
                 f"Output directory ({output_dir_abspath}) already exists and is not empty. "
@@ -196,28 +196,28 @@ def generate_prepared_ds_path(data_args, model_args, data_folder=None) -> Path:
     """
     data_folder = data_folder if data_folder else data_args.data_folder
     concatenated_str = (
-        str(model_args.max_position_embeddings)
-        + "|"
-        + os.path.abspath(data_folder)
-        + "|"
-        + os.path.abspath(model_args.tokenizer_name_or_path)
-        + "|"
-        + (
-            str(data_args.validation_split_percentage)
-            if data_args.validation_split_percentage
-            else ""
-        )
-        + "|"
-        + f"test_eval_ratio={str(data_args.test_eval_ratio)}"
-        + "|"
-        + f"split_by_patient={str(data_args.split_by_patient)}"
-        + "|"
-        + f"chronological_split={str(data_args.chronological_split)}"
+        str(model_args.max_position_embeddings)
+        + "|"
+        + os.path.abspath(data_folder)
+        + "|"
+        + os.path.abspath(model_args.tokenizer_name_or_path)
+        + "|"
+        + (
+            str(data_args.validation_split_percentage)
+            if data_args.validation_split_percentage
+            else ""
+        )
+        + "|"
+        + f"test_eval_ratio={str(data_args.test_eval_ratio)}"
+        + "|"
+        + f"split_by_patient={str(data_args.split_by_patient)}"
+        + "|"
+        + f"chronological_split={str(data_args.chronological_split)}"
     )
     basename = os.path.basename(data_folder)
     cleaned_basename = re.sub(r"[^a-zA-Z0-9_]", "", basename)
     LOG.info(f"concatenated_str: {concatenated_str}")
-    ds_hash = f"{cleaned_basename}_{str(md5(concatenated_str, usedforsecurity=False))}"
+    ds_hash = f"{cleaned_basename}_{str(md5(concatenated_str))}"
     LOG.info(f"ds_hash: {ds_hash}")
     prepared_ds_path = Path(os.path.abspath(data_args.dataset_prepared_path)) / ds_hash
     return prepared_ds_path
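Note (not part of the commit): a rough, self-contained sketch of the caching pattern used by generate_prepared_ds_path, which hashes every setting that affects preprocessing and uses the digest as the prepared-dataset directory name. The helper below is illustrative only; its name and arguments are hypothetical, and it calls hashlib.md5 directly rather than the md5 imported in runner_util.py.

import hashlib
import os
import re
from pathlib import Path

def prepared_cache_path(dataset_prepared_path: str, data_folder: str, settings: list) -> Path:
    # Concatenate every setting that influences preprocessing so any change produces a new hash.
    concatenated = "|".join([os.path.abspath(data_folder)] + [str(s) for s in settings])
    digest = hashlib.md5(concatenated.encode("utf-8")).hexdigest()
    # Keep a readable prefix derived from the data folder, as the real code does.
    cleaned_basename = re.sub(r"[^a-zA-Z0-9_]", "", os.path.basename(data_folder))
    return Path(os.path.abspath(dataset_prepared_path)) / f"{cleaned_basename}_{digest}"

# Two different split settings map to two different prepared-dataset folders.
print(prepared_cache_path("/tmp/prepared", "/data/omop", ["split_by_patient=True"]))
print(prepared_cache_path("/tmp/prepared", "/data/omop", ["split_by_patient=False"]))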
2 changes: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ def test_convert_tokens_to_string(self):
     def test_oov_token(self):
         # Test the encoding of an out-of-vocabulary token
         encoded = self.tokenizer.encode(["nonexistent"])
-        self.assertEqual(encoded, [self.tokenizer._oov_token_index])
+        self.assertEqual(encoded, [self.tokenizer.oov_token_index])

     def test_convert_id_to_token_oov(self):
         # Test decoding an out-of-vocabulary token ID
