From 4ba50e25619ae3c066d3cea47f3f1e2be7072216 Mon Sep 17 00:00:00 2001 From: Chao Pang Date: Wed, 9 Oct 2024 13:54:35 -0400 Subject: [PATCH] Added units to meds to cehrbert data transformation and added unit in normalizing numeric values (#60) * added units to data transformation and added unit in normalizing the numeric values * moved the constant NA from tokenization_hf_cehrbert.py to tokenization_utils.py * upgraded cehrbert_data to v0.0.3 --- pyproject.toml | 2 +- .../hf_data_generator/hf_dataset_mapping.py | 13 +- .../hf_data_generator/meds_utils.py | 3 + .../med_extension/schema_extension.py | 1 + .../hf_models/tokenization_hf_cehrbert.py | 112 ++++++++++++---- .../models/hf_models/tokenization_utils.py | 4 +- .../runners/hf_runner_argument_dataclass.py | 2 +- .../hf_med_to_cehrbert_mapping_test.py | 2 +- .../numeric_concept_statistics_test.py | 124 ++++++++++++++++++ 9 files changed, 229 insertions(+), 34 deletions(-) create mode 100644 tests/unit_tests/models/hf_models/numeric_concept_statistics_test.py diff --git a/pyproject.toml b/pyproject.toml index 778671c0..62859ee5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ dependencies = [ "Werkzeug==3.0.1", "wandb==0.17.8", "xgboost==2.0.3", - "cehrbert_data==0.0.1" + "cehrbert_data==0.0.3" ] [tool.setuptools_scm] diff --git a/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py b/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py index f3377e9c..2e2519ff 100644 --- a/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py +++ b/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py @@ -8,7 +8,8 @@ import numpy as np import pandas as pd -from cehrbert_data.decorators.patient_event_decorator import get_att_function +from cehrbert_data.const.common import NA +from cehrbert_data.decorators.patient_event_decorator_base import get_att_function from datasets.formatting.formatting import LazyBatch from dateutil.relativedelta import relativedelta from meds.schema import birth_code, death_code @@ -137,6 +138,7 @@ def _update_cehrbert_record( concept_value_mask: int = 0, concept_value: float = -1.0, mlm_skip_value: int = 0, + unit: str = NA, ) -> None: cehrbert_record["concept_ids"].append(code) cehrbert_record["visit_concept_orders"].append(visit_concept_order) @@ -146,6 +148,7 @@ def _update_cehrbert_record( cehrbert_record["visit_concept_ids"].append(visit_concept_id) cehrbert_record["concept_value_masks"].append(concept_value_mask) cehrbert_record["concept_values"].append(concept_value) + cehrbert_record["units"].append(unit) cehrbert_record["mlm_skip_values"].append(mlm_skip_value) def transform(self, record: Dict[str, Any]) -> Dict[str, Any]: @@ -160,6 +163,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]: "visit_concept_orders": [], "concept_value_masks": [], "concept_values": [], + "units": [], "mlm_skip_values": [], "visit_concept_ids": [], } @@ -274,6 +278,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]: # If numeric_value exists, this is a concept/value tuple, we indicate this using a concept_value_mask numeric_value = e.get("numeric_value", None) + unit = e.get("unit", NA) concept_value_mask = int(numeric_value is not None) concept_value = numeric_value if concept_value_mask == 1 else -1.0 code = replace_escape_chars(e["code"]) @@ -295,6 +300,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]: visit_concept_id=visit_type, concept_value_mask=concept_value_mask, concept_value=concept_value, + unit=unit, mlm_skip_value=concept_value_mask, ) @@ -419,22 +425,25 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]: # If any concept has a value associated with it, we normalize the value if np.any(np.asarray(concept_value_masks) > 0): + units = record["units"] normalized_concept_values = copy.deepcopy(concept_values) for i, ( concept_id, + unit, token_id, concept_value_mask, concept_value, ) in enumerate( zip( record["concept_ids"], + units, input_ids, concept_value_masks, concept_values, ) ): if token_id in self._lab_token_ids: - normalized_concept_value = self._concept_tokenizer.normalize(concept_id, concept_value) + normalized_concept_value = self._concept_tokenizer.normalize(concept_id, unit, concept_value) normalized_concept_values[i] = normalized_concept_value record["concept_values"] = normalized_concept_values diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py index f064fe08..4f975812 100644 --- a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py +++ b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py @@ -176,6 +176,7 @@ def _convert_event(self, event) -> List[Event]: time = getattr(event, "time", None) text_value = getattr(event, "text_value", None) numeric_value = getattr(event, "numeric_value", None) + unit = getattr(event, "unit", None) if numeric_value is None and text_value is not None: conversion_rule = self.conversion.get_text_event_to_numeric_events_rule(code) @@ -188,6 +189,7 @@ def _convert_event(self, event) -> List[Event]: code=label, time=time, numeric_value=float(value), + unit=unit, properties={"visit_id": self.visit_id, "table": "meds"}, ) for label, value in zip(conversion_rule.mapped_event_labels, match.groups()) @@ -200,6 +202,7 @@ def _convert_event(self, event) -> List[Event]: code=code, time=time, numeric_value=numeric_value, + unit=unit, text_value=text_value, properties={"visit_id": self.visit_id, "table": "meds"}, ) diff --git a/src/cehrbert/med_extension/schema_extension.py b/src/cehrbert/med_extension/schema_extension.py index 13c20c73..081e5345 100644 --- a/src/cehrbert/med_extension/schema_extension.py +++ b/src/cehrbert/med_extension/schema_extension.py @@ -10,6 +10,7 @@ "code": str, "text_value": NotRequired[Optional[str]], "numeric_value": NotRequired[Optional[float]], + "unit": NotRequired[Optional[str]], "datetime_value": NotRequired[datetime.datetime], "properties": NotRequired[Optional[Mapping[str, Any]]], }, diff --git a/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py b/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py index 27ab92df..2531191a 100644 --- a/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py +++ b/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py @@ -1,11 +1,14 @@ +import collections import json import os import pickle from functools import partial from itertools import islice -from typing import Any, Dict, List, Sequence, Union +from typing import Any, Dict, List, Sequence, Tuple, Union +import numpy as np import transformers +from cehrbert_data.const.common import NA from datasets import Dataset, DatasetDict from tokenizers import Tokenizer from tokenizers.models import WordLevel @@ -25,7 +28,7 @@ TOKENIZER_FILE_NAME = "tokenizer.json" CONCEPT_MAPPING_FILE_NAME = "concept_name_mapping.json" -LAB_STATS_FILE_NAME = "cehrgpt_lab_stats.json" +LAB_STATS_FILE_NAME = "cehrbert_lab_stats.json" def load_json_file(json_file) -> Union[List[Dict[str, Any]], Dict[str, Any]]: @@ -50,18 +53,31 @@ def load_json_file(json_file) -> Union[List[Dict[str, Any]], Dict[str, Any]]: raise RuntimeError(f"Can't load the json file at {json_file}") from e -class CehrBertTokenizer(PushToHubMixin): +def create_numeric_concept_unit_mapping( + lab_stats: List[Dict[str, Any]] +) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]: + numeric_concept_unit_mapping = collections.defaultdict(list) + for each_lab_stat in lab_stats: + numeric_concept_unit_mapping[each_lab_stat["concept_id"]].append( + (each_lab_stat["count"], each_lab_stat["unit"]) + ) - def __init__( - self, - tokenizer: Tokenizer, - lab_stats: List[Dict[str, Any]], - concept_name_mapping: Dict[str, str], - ): - self._tokenizer = tokenizer + concept_prob_mapping = dict() + concept_unit_mapping = dict() + for concept_id in numeric_concept_unit_mapping.keys(): + counts, units = zip(*numeric_concept_unit_mapping[concept_id]) + total_count = sum(counts) + probs = [float(c) / total_count for c in counts] + concept_prob_mapping[concept_id] = probs + concept_unit_mapping[concept_id] = units + return concept_prob_mapping, concept_unit_mapping + + +class NumericEventStatistics: + def __init__(self, lab_stats: List[Dict[str, Any]]): self._lab_stats = lab_stats - self._lab_stat_mapping = { - lab_stat["concept_id"]: { + self._lab_stats_mapping = { + (lab_stat["concept_id"], lab_stat["unit"]): { "unit": lab_stat["unit"], "mean": lab_stat["mean"], "std": lab_stat["std"], @@ -71,6 +87,53 @@ def __init__( } for lab_stat in lab_stats } + self._concept_prob_mapping, self._concept_unit_mapping = create_numeric_concept_unit_mapping(lab_stats) + + def get_numeric_concept_ids(self) -> List[str]: + return [_["concept_id"] for _ in self._lab_stats] + + def get_random_unit(self, concept_id: str) -> str: + if concept_id in self._concept_prob_mapping: + unit_probs = self._concept_prob_mapping[concept_id] + return np.random.choice(self._concept_unit_mapping[concept_id], p=unit_probs) + return NA + + def normalize(self, concept_id: str, unit: str, concept_value: float) -> float: + if (concept_id, unit) in self._lab_stats_mapping: + concept_unit_stats = self._lab_stats_mapping[(concept_id, unit)] + mean_ = concept_value - concept_unit_stats["mean"] + std = concept_unit_stats["std"] + if std > 0: + value_outlier_std = concept_unit_stats["value_outlier_std"] + normalized_value = mean_ / std + # Clip the value between the lower and upper bounds of the corresponding lab + normalized_value = max(-value_outlier_std, min(value_outlier_std, normalized_value)) + else: + # If there is not a valid standard deviation, + # we just the normalized value to the mean of the standard normal + normalized_value = 0.0 + return normalized_value + return concept_value + + def denormalize(self, concept_id: str, value: float) -> Tuple[float, str]: + unit = self.get_random_unit(concept_id) + if (concept_id, unit) in self._lab_stats_mapping: + stats = self._lab_stats_mapping[(concept_id, unit)] + value = value * stats["std"] + stats["mean"] + return value, unit + + +class CehrBertTokenizer(PushToHubMixin): + + def __init__( + self, + tokenizer: Tokenizer, + lab_stats: List[Dict[str, Any]], + concept_name_mapping: Dict[str, str], + ): + self._tokenizer = tokenizer + self._lab_stats = lab_stats + self._numeric_event_statistics = NumericEventStatistics(lab_stats) self._concept_name_mapping = concept_name_mapping self._oov_token_index = self._tokenizer.token_to_id(OUT_OF_VOCABULARY_TOKEN) self._padding_token_index = self._tokenizer.token_to_id(PAD_TOKEN) @@ -112,7 +175,13 @@ def lab_token_ids(self): UNUSED_TOKEN, OUT_OF_VOCABULARY_TOKEN, ] - return self.encode([_["concept_id"] for _ in self._lab_stats if _["concept_id"] not in reserved_tokens]) + return self.encode( + [ + concept_id + for concept_id in self._numeric_event_statistics.get_numeric_concept_ids() + if concept_id not in reserved_tokens + ] + ) def encode(self, concept_ids: Sequence[str]) -> Sequence[int]: encoded = self._tokenizer.encode(concept_ids, is_pretokenized=True) @@ -351,18 +420,5 @@ def batched_generator(): def batch_concat_concepts(cls, records: Dict[str, List], feature_name) -> Dict[str, List]: return {feature_name: [" ".join(map(str, _)) for _ in records[feature_name]]} - def normalize(self, concept_id, concept_value) -> float: - if concept_id in self._lab_stat_mapping: - mean_ = concept_value - self._lab_stat_mapping[concept_id]["mean"] - std = self._lab_stat_mapping[concept_id]["std"] - if std > 0: - value_outlier_std = self._lab_stat_mapping[concept_id]["value_outlier_std"] - normalized_value = mean_ / self._lab_stat_mapping[concept_id]["std"] - # Clip the value between the lower and upper bounds of the corresponding lab - normalized_value = max(-value_outlier_std, min(value_outlier_std, normalized_value)) - else: - # If there is not a valid standard deviation, - # we just the normalized value to the mean of the standard normal - normalized_value = 0.0 - return normalized_value - return concept_value + def normalize(self, concept_id: str, unit: str, concept_value: float) -> float: + return self._numeric_event_statistics.normalize(concept_id, unit, concept_value) diff --git a/src/cehrbert/models/hf_models/tokenization_utils.py b/src/cehrbert/models/hf_models/tokenization_utils.py index ee2515c0..72495a99 100644 --- a/src/cehrbert/models/hf_models/tokenization_utils.py +++ b/src/cehrbert/models/hf_models/tokenization_utils.py @@ -4,6 +4,8 @@ from functools import partial from typing import Any, Dict +from cehrbert_data.const.common import NA + from cehrbert.utils.stat_utils import TruncatedOnlineStatistics @@ -26,7 +28,7 @@ def map_statistics(batch: Dict[str, Any], capacity=100, value_outlier_std=2.0) - if "units" in batch: concept_value_units = batch["units"] else: - concept_value_units = [["default_unit" for _ in cons] for cons in batch["concept_ids"]] + concept_value_units = [[NA for _ in cons] for cons in batch["concept_ids"]] numeric_stats_by_lab = collections.defaultdict( partial(TruncatedOnlineStatistics, capacity=capacity, value_outlier_std=value_outlier_std) ) diff --git a/src/cehrbert/runners/hf_runner_argument_dataclass.py b/src/cehrbert/runners/hf_runner_argument_dataclass.py index 1cef3b56..ed0b38ca 100644 --- a/src/cehrbert/runners/hf_runner_argument_dataclass.py +++ b/src/cehrbert/runners/hf_runner_argument_dataclass.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional -from cehrbert_data.decorators.patient_event_decorator import AttType +from cehrbert_data.decorators.patient_event_decorator_base import AttType from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import ( MedsToBertMimic4, diff --git a/tests/unit_tests/data_generators/hf_data_generator/hf_med_to_cehrbert_mapping_test.py b/tests/unit_tests/data_generators/hf_data_generator/hf_med_to_cehrbert_mapping_test.py index 1f0af1f3..9a3a0164 100644 --- a/tests/unit_tests/data_generators/hf_data_generator/hf_med_to_cehrbert_mapping_test.py +++ b/tests/unit_tests/data_generators/hf_data_generator/hf_med_to_cehrbert_mapping_test.py @@ -1,7 +1,7 @@ import unittest from datetime import datetime -from cehrbert_data.decorators.patient_event_decorator import AttType +from cehrbert_data.decorators.patient_event_decorator_base import AttType from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping from cehrbert.med_extension.schema_extension import CehrBertPatient, Event, Visit diff --git a/tests/unit_tests/models/hf_models/numeric_concept_statistics_test.py b/tests/unit_tests/models/hf_models/numeric_concept_statistics_test.py new file mode 100644 index 00000000..84dcc772 --- /dev/null +++ b/tests/unit_tests/models/hf_models/numeric_concept_statistics_test.py @@ -0,0 +1,124 @@ +import unittest + +import numpy as np + +from cehrbert.models.hf_models.tokenization_hf_cehrbert import ( + NA, + NumericEventStatistics, + create_numeric_concept_unit_mapping, +) + + +class TestNumericEventStatistics(unittest.TestCase): + def setUp(self): + # Mock lab stats data + self.lab_stats = [ + { + "concept_id": "concept_1", + "unit": "unit_1", + "mean": 10.0, + "std": 2.0, + "value_outlier_std": 3.0, + "lower_bound": 7.0, + "upper_bound": 13.0, + "count": 100, + }, + { + "concept_id": "concept_2", + "unit": "unit_1", + "mean": 20.0, + "std": 5.0, + "value_outlier_std": 2.5, + "lower_bound": 10.0, + "upper_bound": 30.0, + "count": 100, + }, + { + "concept_id": "concept_2", + "unit": "unit_2", + "mean": 15.0, + "std": 3.0, + "value_outlier_std": 2.0, + "lower_bound": 8.0, + "upper_bound": 22.0, + "count": 200, + }, + ] + # Create an instance of NumericEventStatistics + self.numeric_event_statistics = NumericEventStatistics(self.lab_stats) + + def test_create_numeric_concept_unit_mapping(self): + # Call the function + concept_prob_mapping, concept_unit_mapping = create_numeric_concept_unit_mapping(self.lab_stats) + + # Check the concept_prob_mapping + # For concept_1: Only one unit, so probability = 1.0 + self.assertEqual(concept_prob_mapping["concept_1"], [1.0]) + + # For concept_2: Two units, unit_1 with count 100, unit_2 with count 200 + total_count_concept_2 = 100 + 200 + expected_probs_concept_2 = [ + 100 / total_count_concept_2, + 200 / total_count_concept_2, + ] + self.assertEqual(concept_prob_mapping["concept_2"], expected_probs_concept_2) + + # Check the concept_unit_mapping + self.assertEqual(concept_unit_mapping["concept_1"], ("unit_1",)) + self.assertEqual(concept_unit_mapping["concept_2"], ("unit_1", "unit_2")) + + def test_get_numeric_concept_ids(self): + # Test for correct concept IDs + expected_concept_ids = ["concept_1", "concept_2", "concept_2"] + result = self.numeric_event_statistics.get_numeric_concept_ids() + self.assertEqual(result, expected_concept_ids) + + def test_get_random_unit(self): + # Test get_random_unit method for concept_1 (single unit) + unit = self.numeric_event_statistics.get_random_unit("concept_1") + self.assertEqual(unit, "unit_1") + + # Test get_random_unit method for concept_2 (multiple units) + unit = self.numeric_event_statistics.get_random_unit("concept_2") + self.assertIn(unit, ["unit_1", "unit_2"]) + + # Test get_random_unit method for non-existent concept_3 + unit = self.numeric_event_statistics.get_random_unit("concept_3") + self.assertEqual(unit, NA) + + def test_normalize(self): + # Test normalization for concept_1, unit_1 + normalized_value = self.numeric_event_statistics.normalize("concept_1", "unit_1", 12.0) + # (12 - 10) / 2 = 1.0 + self.assertEqual(normalized_value, 1.0) + + # Test normalization with value beyond outlier bound + normalized_value = self.numeric_event_statistics.normalize("concept_1", "unit_1", 20.0) + # Since (20 - 10) / 2 = 5.0, but it's clipped at value_outlier_std = 3.0 + self.assertEqual(normalized_value, 3.0) + + # Test normalization with a concept_id/unit pair not found + normalized_value = self.numeric_event_statistics.normalize("concept_3", "unit_1", 15.0) + # Since the concept_id/unit doesn't exist, it should return the original value + self.assertEqual(normalized_value, 15.0) + + def test_denormalize(self): + # Mock np.random.choice to return a predictable unit + np.random.choice = lambda *args, **kwargs: "unit_1" + + # Test denormalization for concept_1, unit_1 + denormalized_value, unit = self.numeric_event_statistics.denormalize("concept_1", 1.0) + # value = 1.0 * 2.0 + 10.0 = 12.0 + self.assertEqual(denormalized_value, 12.0) + self.assertEqual(unit, "unit_1") + + # Test denormalization for concept_2, unit_1 + denormalized_value, unit = self.numeric_event_statistics.denormalize("concept_2", 0.5) + # value = 0.5 * 5.0 + 20.0 = 22.5 + self.assertEqual(denormalized_value, 22.5) + self.assertEqual(unit, "unit_1") + + +# Entry point for running the tests +if __name__ == "__main__": + unittest.main()