From 6a89ae9a33eea9dded720f482ed06086d3f714c3 Mon Sep 17 00:00:00 2001 From: Chao Pang Date: Thu, 10 Oct 2024 00:27:22 -0400 Subject: [PATCH] completed the meds_to_cehrbert_omop.py logic for converting the OMOP MEDS data to CEHRBERT --- .../__init__.py | 1 + .../meds_to_cehrbert_omop.py | 60 +++++++++++-------- .../runners/hf_runner_argument_dataclass.py | 17 +++--- 3 files changed, 46 insertions(+), 32 deletions(-) diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/__init__.py b/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/__init__.py index c8396a8b..a991e9c1 100644 --- a/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/__init__.py +++ b/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/__init__.py @@ -1,2 +1,3 @@ from .meds_to_cehrbert_base import MedsToCehrBertConversion from .meds_to_cehrbert_micmic4 import MedsToBertMimic4 +from .meds_to_cehrbert_omop import MedsToCehrbertOMOP diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_omop.py b/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_omop.py index 30ba1822..5b7f99de 100644 --- a/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_omop.py +++ b/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_omop.py @@ -1,3 +1,4 @@ +from collections import defaultdict from datetime import datetime from typing import List, Tuple @@ -22,22 +23,9 @@ def generate_demographics_and_patient_blocks( gender = None ethnicity = None - current_visit_id = None - current_date = None - events_for_current_date = [] - patient_blocks = [] - + visit_events = defaultdict(list) + unlinked_event_mapping = defaultdict(list) for e in patient.events: - - # Skip out of the loop if the events' time stamps are beyond the prediction time - if prediction_time is not None and e.time is not None: - if e.time > prediction_time: - break - - # Try to set current_visit_id - if not current_visit_id: - current_visit_id = e.visit_id if hasattr(e, "visit_id") else None - # This indicates demographics features if e.code in birth_codes: birth_datetime = e.time @@ -48,19 +36,43 @@ def generate_demographics_and_patient_blocks( elif e.code.upper().startswith("ETHNICITY"): ethnicity = e.code elif e.time is not None: - if not current_date: - current_date = e.time - if current_date.date() == e.time.date(): - events_for_current_date.append(e) + # Skip out of the loop if the events' time stamps are beyond the prediction time + if prediction_time is not None: + if e.time > prediction_time: + break + if hasattr(e, "visit_id"): + visit_id = e.visit_id + visit_events[visit_id].append(e) else: - patient_blocks.append(PatientBlock(events_for_current_date, current_visit_id, self)) - events_for_current_date = [e] - current_date = e.time + unlinked_event_mapping[e.time.strftime("%Y-%m-%d")].append(e) + + patient_block_mapping = { + visit_id: PatientBlock(events=events, visit_id=visit_id, conversion=self) + for visit_id, events in visit_events.items() + } + + # Try to connect the unlinked events to existing visits + for current_date_str in list(unlinked_event_mapping.keys()): + current_date = datetime.strptime(current_date_str, "%Y-%m-%d") + for visit_id, patient_block in patient_block_mapping.items(): + if patient_block.min_time.date() <= current_date <= patient_block.max_time.date(): + patient_block.events.extend(unlinked_event_mapping.pop(current_date_str, [])) + # Need to sort the events if we insert new events to the patient block + patient_block.events = sorted(patient_block.events, key=lambda _: _.time) + break - if events_for_current_date: - patient_blocks.append(PatientBlock(events_for_current_date, current_visit_id, self)) + max_visit_id = max(patient_block_mapping.keys()) + 1 + for events in unlinked_event_mapping.values(): + patient_block_mapping[max_visit_id] = PatientBlock(events, max_visit_id, self) + max_visit_id += 1 + patient_blocks = list(patient_block_mapping.values()) demographics = PatientDemographics(birth_datetime=birth_datetime, race=race, gender=gender, ethnicity=ethnicity) + + # If there are unlinked events, we need to add them as new patient blocks, therefore we need to re-order the patient block + if len(unlinked_event_mapping) > 0: + patient_blocks = sorted(patient_block_mapping.values(), key=lambda block: block.min_time) + return demographics, patient_blocks def _create_ed_admission_matching_rules(self) -> List[str]: diff --git a/src/cehrbert/runners/hf_runner_argument_dataclass.py b/src/cehrbert/runners/hf_runner_argument_dataclass.py index ed0b38ca..025d8f84 100644 --- a/src/cehrbert/runners/hf_runner_argument_dataclass.py +++ b/src/cehrbert/runners/hf_runner_argument_dataclass.py @@ -7,6 +7,7 @@ from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import ( MedsToBertMimic4, MedsToCehrBertConversion, + MedsToCehrbertOMOP, ) # Create an enum dynamically from the list @@ -111,14 +112,14 @@ class DataTrainingArguments: ) # TODO: Python 3.9/10 do not support dynamic unpacking, we have to manually provide the entire # list right now. - meds_to_cehrbert_conversion_type: Literal[MedsToCehrBertConversionType[MedsToBertMimic4.__name__]] = ( - dataclasses.field( - default=MedsToCehrBertConversionType[MedsToBertMimic4.__name__], - metadata={ - "help": "The MEDS to CEHRBERT conversion type e.g. MedsToBertMimic4", - "choices": f"choices={[e for e in MedsToCehrBertConversionType.__members__]}", - }, - ) + meds_to_cehrbert_conversion_type: Literal[ + MedsToCehrBertConversionType[MedsToBertMimic4.__name__, MedsToCehrbertOMOP.__name__] + ] = dataclasses.field( + default=MedsToCehrBertConversionType[MedsToBertMimic4.__name__], + metadata={ + "help": "The MEDS to CEHRBERT conversion type e.g. MedsToBertMimic4", + "choices": f"choices={[e for e in MedsToCehrBertConversionType.__members__]}", + }, ) include_auxiliary_token: Optional[bool] = dataclasses.field( default=False,