Skip to content

Commit

Permalink
Cast index_date to the Unix Epoch time in HFFineTuningMapping
Browse files (browse the repository at this point in the history)
  • Loading branch information
ChaoPang committed Sep 11, 2024
1 parent a19a370 commit 51cde66
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 26 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -143,9 +143,9 @@ def __call__(self, examples):

if "index_date" in examples[0]:
batch["index_date"] = torch.cat(
[self._convert_to_tensor(example["index_date"].timestamp()).reshape(-1, 1) for example in examples],
[self._convert_to_tensor(example["index_date"]).reshape(-1, 1) for example in examples],
dim=0,
).to(torch.long)
).to(torch.float32)

if "age_at_index" in examples[0]:
batch["age_at_index"] = torch.cat(
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -97,8 +97,7 @@ class MedToCehrBertDatasetMapping(DatasetMapping):
def __init__(self, data_args: DataTrainingArguments, is_pretraining: bool = True):
self._time_token_function = get_att_function(data_args.att_function_type)
self._include_auxiliary_token = data_args.include_auxiliary_token
self._inpatient_time_token_function = get_att_function(
data_args.inpatient_att_function_type)
self._inpatient_time_token_function = get_att_function(data_args.inpatient_att_function_type)
self._include_demographic_prompt = data_args.include_demographic_prompt
self._is_pretraining = is_pretraining

Expand Down Expand Up @@ -128,16 +127,16 @@ def remove_columns(self):

@staticmethod
def _update_cehrbert_record(
cehrbert_record: Dict[str, Any],
code: str,
visit_segment: int = 0,
date: int = 0,
age: int = -1,
visit_concept_order: int = 0,
visit_concept_id: str = "0",
concept_value_mask: int = 0,
concept_value: float = -1.0,
mlm_skip_value: int = 0,
cehrbert_record: Dict[str, Any],
code: str,
visit_segment: int = 0,
date: int = 0,
age: int = -1,
visit_concept_order: int = 0,
visit_concept_id: str = "0",
concept_value_mask: int = 0,
concept_value: float = -1.0,
mlm_skip_value: int = 0,
) -> None:
cehrbert_record["concept_ids"].append(code)
cehrbert_record["visit_concept_orders"].append(visit_concept_order)
Expand Down Expand Up @@ -188,8 +187,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
date_cursor = None

# Loop through all the visits excluding the first event containing the demographics
for i, visit in enumerate(
sorted(record["visits"], key=lambda e: e["visit_start_datetime"])):
for i, visit in enumerate(sorted(record["visits"], key=lambda e: e["visit_start_datetime"])):

events = visit["events"]

Expand Down Expand Up @@ -217,8 +215,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
# Add the VS token to the patient timeline to mark the start of a visit
age = relativedelta(visit["visit_start_datetime"], birth_datetime).years
# Calculate the number of whole weeks elapsed since the Unix epoch (1970-01-01)
date = (visit["visit_start_datetime"] - datetime.datetime(year=1970, month=1,
day=1)).days // 7
date = (visit["visit_start_datetime"] - datetime.datetime(year=1970, month=1, day=1)).days // 7
visit_segment = int(visit_segment_indicator) + 1

self._update_cehrbert_record(
Expand Down Expand Up @@ -390,8 +387,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
column_names.append(k)
column_values.append(v)

sorted_list = sorted(zip(sorting_columns, *column_values),
key=lambda tup2: (tup2[0], tup2[1]))
sorted_list = sorted(zip(sorting_columns, *column_values), key=lambda tup2: (tup2[0], tup2[1]))

# uses a combination of zip() and unpacking (*) to transpose the list of tuples. This means converting rows
# into columns: the first tuple formed from all the first elements of the sorted tuples, the second tuple
Expand Down Expand Up @@ -425,10 +421,10 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
if np.any(np.asarray(concept_value_masks) > 0):
normalized_concept_values = copy.deepcopy(concept_values)
for i, (
concept_id,
token_id,
concept_value_mask,
concept_value,
concept_id,
token_id,
concept_value_mask,
concept_value,
) in enumerate(
zip(
record["concept_ids"],
Expand All @@ -438,8 +434,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
)
):
if token_id in self._lab_token_ids:
normalized_concept_value = self._concept_tokenizer.normalize(concept_id,
concept_value)
normalized_concept_value = self._concept_tokenizer.normalize(concept_id, concept_value)
normalized_concept_values[i] = normalized_concept_value
record["concept_values"] = normalized_concept_values

Expand All @@ -458,6 +453,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
return {
"age_at_index": record["age"] if "age" in record else record["age_at_index"],
"classifier_label": record["label"],
"index_date": record["index_date"].timestamp() if "index_date" in record else None,
}

def remove_columns(self):
Expand Down

0 comments on commit 51cde66

Please sign in to comment.