Skip to content

Commit

Permalink
Added units to meds to cehrbert data transformation and added unit in…
Browse files Browse the repository at this point in the history
… normalizing numeric values (#60)

* added units to data transformation and added unit in normalizing the numeric values

* moved the constant NA from tokenization_hf_cehrbert.py to tokenization_utils.py

* upgraded cehrbert_data to v0.0.3
  • Loading branch information
ChaoPang authored Oct 9, 2024
1 parent 7bdc753 commit 4ba50e2
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 34 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ dependencies = [
"Werkzeug==3.0.1",
"wandb==0.17.8",
"xgboost==2.0.3",
"cehrbert_data==0.0.1"
"cehrbert_data==0.0.3"
]

[tool.setuptools_scm]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

import numpy as np
import pandas as pd
from cehrbert_data.decorators.patient_event_decorator import get_att_function
from cehrbert_data.const.common import NA
from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
from datasets.formatting.formatting import LazyBatch
from dateutil.relativedelta import relativedelta
from meds.schema import birth_code, death_code
Expand Down Expand Up @@ -137,6 +138,7 @@ def _update_cehrbert_record(
concept_value_mask: int = 0,
concept_value: float = -1.0,
mlm_skip_value: int = 0,
unit: str = NA,
) -> None:
cehrbert_record["concept_ids"].append(code)
cehrbert_record["visit_concept_orders"].append(visit_concept_order)
Expand All @@ -146,6 +148,7 @@ def _update_cehrbert_record(
cehrbert_record["visit_concept_ids"].append(visit_concept_id)
cehrbert_record["concept_value_masks"].append(concept_value_mask)
cehrbert_record["concept_values"].append(concept_value)
cehrbert_record["units"].append(unit)
cehrbert_record["mlm_skip_values"].append(mlm_skip_value)

def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -160,6 +163,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
"visit_concept_orders": [],
"concept_value_masks": [],
"concept_values": [],
"units": [],
"mlm_skip_values": [],
"visit_concept_ids": [],
}
Expand Down Expand Up @@ -274,6 +278,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:

# If numeric_value exists, this is a concept/value tuple, we indicate this using a concept_value_mask
numeric_value = e.get("numeric_value", None)
unit = e.get("unit", NA)
concept_value_mask = int(numeric_value is not None)
concept_value = numeric_value if concept_value_mask == 1 else -1.0
code = replace_escape_chars(e["code"])
Expand All @@ -295,6 +300,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
visit_concept_id=visit_type,
concept_value_mask=concept_value_mask,
concept_value=concept_value,
unit=unit,
mlm_skip_value=concept_value_mask,
)

Expand Down Expand Up @@ -419,22 +425,25 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:

# If any concept has a value associated with it, we normalize the value
if np.any(np.asarray(concept_value_masks) > 0):
units = record["units"]
normalized_concept_values = copy.deepcopy(concept_values)
for i, (
concept_id,
unit,
token_id,
concept_value_mask,
concept_value,
) in enumerate(
zip(
record["concept_ids"],
units,
input_ids,
concept_value_masks,
concept_values,
)
):
if token_id in self._lab_token_ids:
normalized_concept_value = self._concept_tokenizer.normalize(concept_id, concept_value)
normalized_concept_value = self._concept_tokenizer.normalize(concept_id, unit, concept_value)
normalized_concept_values[i] = normalized_concept_value
record["concept_values"] = normalized_concept_values

Expand Down
3 changes: 3 additions & 0 deletions src/cehrbert/data_generators/hf_data_generator/meds_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def _convert_event(self, event) -> List[Event]:
time = getattr(event, "time", None)
text_value = getattr(event, "text_value", None)
numeric_value = getattr(event, "numeric_value", None)
unit = getattr(event, "unit", None)

if numeric_value is None and text_value is not None:
conversion_rule = self.conversion.get_text_event_to_numeric_events_rule(code)
Expand All @@ -188,6 +189,7 @@ def _convert_event(self, event) -> List[Event]:
code=label,
time=time,
numeric_value=float(value),
unit=unit,
properties={"visit_id": self.visit_id, "table": "meds"},
)
for label, value in zip(conversion_rule.mapped_event_labels, match.groups())
Expand All @@ -200,6 +202,7 @@ def _convert_event(self, event) -> List[Event]:
code=code,
time=time,
numeric_value=numeric_value,
unit=unit,
text_value=text_value,
properties={"visit_id": self.visit_id, "table": "meds"},
)
Expand Down
1 change: 1 addition & 0 deletions src/cehrbert/med_extension/schema_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"code": str,
"text_value": NotRequired[Optional[str]],
"numeric_value": NotRequired[Optional[float]],
"unit": NotRequired[Optional[str]],
"datetime_value": NotRequired[datetime.datetime],
"properties": NotRequired[Optional[Mapping[str, Any]]],
},
Expand Down
112 changes: 84 additions & 28 deletions src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import collections
import json
import os
import pickle
from functools import partial
from itertools import islice
from typing import Any, Dict, List, Sequence, Union
from typing import Any, Dict, List, Sequence, Tuple, Union

import numpy as np
import transformers
from cehrbert_data.const.common import NA
from datasets import Dataset, DatasetDict
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
Expand All @@ -25,7 +28,7 @@

TOKENIZER_FILE_NAME = "tokenizer.json"
CONCEPT_MAPPING_FILE_NAME = "concept_name_mapping.json"
LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.json"
LAB_STATS_FILE_NAME = "cehrbert_lab_stats.json"


def load_json_file(json_file) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
Expand All @@ -50,18 +53,31 @@ def load_json_file(json_file) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
raise RuntimeError(f"Can't load the json file at {json_file}") from e


class CehrBertTokenizer(PushToHubMixin):
def create_numeric_concept_unit_mapping(
lab_stats: List[Dict[str, Any]]
) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
numeric_concept_unit_mapping = collections.defaultdict(list)
for each_lab_stat in lab_stats:
numeric_concept_unit_mapping[each_lab_stat["concept_id"]].append(
(each_lab_stat["count"], each_lab_stat["unit"])
)

def __init__(
self,
tokenizer: Tokenizer,
lab_stats: List[Dict[str, Any]],
concept_name_mapping: Dict[str, str],
):
self._tokenizer = tokenizer
concept_prob_mapping = dict()
concept_unit_mapping = dict()
for concept_id in numeric_concept_unit_mapping.keys():
counts, units = zip(*numeric_concept_unit_mapping[concept_id])
total_count = sum(counts)
probs = [float(c) / total_count for c in counts]
concept_prob_mapping[concept_id] = probs
concept_unit_mapping[concept_id] = units
return concept_prob_mapping, concept_unit_mapping


class NumericEventStatistics:
def __init__(self, lab_stats: List[Dict[str, Any]]):
self._lab_stats = lab_stats
self._lab_stat_mapping = {
lab_stat["concept_id"]: {
self._lab_stats_mapping = {
(lab_stat["concept_id"], lab_stat["unit"]): {
"unit": lab_stat["unit"],
"mean": lab_stat["mean"],
"std": lab_stat["std"],
Expand All @@ -71,6 +87,53 @@ def __init__(
}
for lab_stat in lab_stats
}
self._concept_prob_mapping, self._concept_unit_mapping = create_numeric_concept_unit_mapping(lab_stats)

def get_numeric_concept_ids(self) -> List[str]:
return [_["concept_id"] for _ in self._lab_stats]

def get_random_unit(self, concept_id: str) -> str:
if concept_id in self._concept_prob_mapping:
unit_probs = self._concept_prob_mapping[concept_id]
return np.random.choice(self._concept_unit_mapping[concept_id], p=unit_probs)
return NA

def normalize(self, concept_id: str, unit: str, concept_value: float) -> float:
if (concept_id, unit) in self._lab_stats_mapping:
concept_unit_stats = self._lab_stats_mapping[(concept_id, unit)]
mean_ = concept_value - concept_unit_stats["mean"]
std = concept_unit_stats["std"]
if std > 0:
value_outlier_std = concept_unit_stats["value_outlier_std"]
normalized_value = mean_ / std
# Clip the value between the lower and upper bounds of the corresponding lab
normalized_value = max(-value_outlier_std, min(value_outlier_std, normalized_value))
else:
# If there is not a valid standard deviation,
# we just the normalized value to the mean of the standard normal
normalized_value = 0.0
return normalized_value
return concept_value

def denormalize(self, concept_id: str, value: float) -> Tuple[float, str]:
unit = self.get_random_unit(concept_id)
if (concept_id, unit) in self._lab_stats_mapping:
stats = self._lab_stats_mapping[(concept_id, unit)]
value = value * stats["std"] + stats["mean"]
return value, unit


class CehrBertTokenizer(PushToHubMixin):

def __init__(
self,
tokenizer: Tokenizer,
lab_stats: List[Dict[str, Any]],
concept_name_mapping: Dict[str, str],
):
self._tokenizer = tokenizer
self._lab_stats = lab_stats
self._numeric_event_statistics = NumericEventStatistics(lab_stats)
self._concept_name_mapping = concept_name_mapping
self._oov_token_index = self._tokenizer.token_to_id(OUT_OF_VOCABULARY_TOKEN)
self._padding_token_index = self._tokenizer.token_to_id(PAD_TOKEN)
Expand Down Expand Up @@ -112,7 +175,13 @@ def lab_token_ids(self):
UNUSED_TOKEN,
OUT_OF_VOCABULARY_TOKEN,
]
return self.encode([_["concept_id"] for _ in self._lab_stats if _["concept_id"] not in reserved_tokens])
return self.encode(
[
concept_id
for concept_id in self._numeric_event_statistics.get_numeric_concept_ids()
if concept_id not in reserved_tokens
]
)

def encode(self, concept_ids: Sequence[str]) -> Sequence[int]:
encoded = self._tokenizer.encode(concept_ids, is_pretokenized=True)
Expand Down Expand Up @@ -351,18 +420,5 @@ def batched_generator():
def batch_concat_concepts(cls, records: Dict[str, List], feature_name) -> Dict[str, List]:
return {feature_name: [" ".join(map(str, _)) for _ in records[feature_name]]}

def normalize(self, concept_id, concept_value) -> float:
if concept_id in self._lab_stat_mapping:
mean_ = concept_value - self._lab_stat_mapping[concept_id]["mean"]
std = self._lab_stat_mapping[concept_id]["std"]
if std > 0:
value_outlier_std = self._lab_stat_mapping[concept_id]["value_outlier_std"]
normalized_value = mean_ / self._lab_stat_mapping[concept_id]["std"]
# Clip the value between the lower and upper bounds of the corresponding lab
normalized_value = max(-value_outlier_std, min(value_outlier_std, normalized_value))
else:
# If there is not a valid standard deviation,
# we just the normalized value to the mean of the standard normal
normalized_value = 0.0
return normalized_value
return concept_value
def normalize(self, concept_id: str, unit: str, concept_value: float) -> float:
return self._numeric_event_statistics.normalize(concept_id, unit, concept_value)
4 changes: 3 additions & 1 deletion src/cehrbert/models/hf_models/tokenization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from functools import partial
from typing import Any, Dict

from cehrbert_data.const.common import NA

from cehrbert.utils.stat_utils import TruncatedOnlineStatistics


Expand All @@ -26,7 +28,7 @@ def map_statistics(batch: Dict[str, Any], capacity=100, value_outlier_std=2.0) -
if "units" in batch:
concept_value_units = batch["units"]
else:
concept_value_units = [["default_unit" for _ in cons] for cons in batch["concept_ids"]]
concept_value_units = [[NA for _ in cons] for cons in batch["concept_ids"]]
numeric_stats_by_lab = collections.defaultdict(
partial(TruncatedOnlineStatistics, capacity=capacity, value_outlier_std=value_outlier_std)
)
Expand Down
2 changes: 1 addition & 1 deletion src/cehrbert/runners/hf_runner_argument_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Optional

from cehrbert_data.decorators.patient_event_decorator import AttType
from cehrbert_data.decorators.patient_event_decorator_base import AttType

from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import (
MedsToBertMimic4,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
from datetime import datetime

from cehrbert_data.decorators.patient_event_decorator import AttType
from cehrbert_data.decorators.patient_event_decorator_base import AttType

from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping
from cehrbert.med_extension.schema_extension import CehrBertPatient, Event, Visit
Expand Down
Loading

0 comments on commit 4ba50e2

Please sign in to comment.