Skip to content

Commit

Permalink
completed the meds_to_cehrbert_omop.py logic for converting the OMOP …
Browse files Browse the repository at this point in the history
…MEDS data to CEHRBERT
  • Loading branch information
ChaoPang committed Oct 10, 2024
1 parent 82761e8 commit 6a89ae9
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 32 deletions.
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .meds_to_cehrbert_base import MedsToCehrBertConversion
from .meds_to_cehrbert_micmic4 import MedsToBertMimic4
from .meds_to_cehrbert_omop import MedsToCehrbertOMOP
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from datetime import datetime
from typing import List, Tuple

Expand All @@ -22,22 +23,9 @@ def generate_demographics_and_patient_blocks(
gender = None
ethnicity = None

current_visit_id = None
current_date = None
events_for_current_date = []
patient_blocks = []

visit_events = defaultdict(list)
unlinked_event_mapping = defaultdict(list)
for e in patient.events:

# Skip out of the loop if the events' time stamps are beyond the prediction time
if prediction_time is not None and e.time is not None:
if e.time > prediction_time:
break

# Try to set current_visit_id
if not current_visit_id:
current_visit_id = e.visit_id if hasattr(e, "visit_id") else None

# This indicates demographics features
if e.code in birth_codes:
birth_datetime = e.time
Expand All @@ -48,19 +36,43 @@ def generate_demographics_and_patient_blocks(
elif e.code.upper().startswith("ETHNICITY"):
ethnicity = e.code
elif e.time is not None:
if not current_date:
current_date = e.time
if current_date.date() == e.time.date():
events_for_current_date.append(e)
# Break out of the loop once an event's timestamp is beyond the prediction time
if prediction_time is not None:
if e.time > prediction_time:
break
if hasattr(e, "visit_id"):
visit_id = e.visit_id
visit_events[visit_id].append(e)
else:
patient_blocks.append(PatientBlock(events_for_current_date, current_visit_id, self))
events_for_current_date = [e]
current_date = e.time
unlinked_event_mapping[e.time.strftime("%Y-%m-%d")].append(e)

patient_block_mapping = {
visit_id: PatientBlock(events=events, visit_id=visit_id, conversion=self)
for visit_id, events in visit_events.items()
}

# Try to connect the unlinked events to existing visits
for current_date_str in list(unlinked_event_mapping.keys()):
current_date = datetime.strptime(current_date_str, "%Y-%m-%d")
for visit_id, patient_block in patient_block_mapping.items():
if patient_block.min_time.date() <= current_date <= patient_block.max_time.date():
patient_block.events.extend(unlinked_event_mapping.pop(current_date_str, []))
# The events need to be re-sorted after inserting new events into the patient block
patient_block.events = sorted(patient_block.events, key=lambda _: _.time)
break

if events_for_current_date:
patient_blocks.append(PatientBlock(events_for_current_date, current_visit_id, self))
max_visit_id = max(patient_block_mapping.keys()) + 1
for events in unlinked_event_mapping.values():
patient_block_mapping[max_visit_id] = PatientBlock(events, max_visit_id, self)
max_visit_id += 1

patient_blocks = list(patient_block_mapping.values())
demographics = PatientDemographics(birth_datetime=birth_datetime, race=race, gender=gender, ethnicity=ethnicity)

# If there are unlinked events, they are added as new patient blocks, so the patient blocks need to be re-sorted by their start time
if len(unlinked_event_mapping) > 0:
patient_blocks = sorted(patient_block_mapping.values(), key=lambda block: block.min_time)

return demographics, patient_blocks

def _create_ed_admission_matching_rules(self) -> List[str]:
Expand Down
17 changes: 9 additions & 8 deletions src/cehrbert/runners/hf_runner_argument_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import (
MedsToBertMimic4,
MedsToCehrBertConversion,
MedsToCehrbertOMOP,
)

# Create an enum dynamically from the list
Expand Down Expand Up @@ -111,14 +112,14 @@ class DataTrainingArguments:
)
# TODO: Python 3.9/3.10 do not support dynamic unpacking, so we have to manually provide the
# entire list for now.
meds_to_cehrbert_conversion_type: Literal[MedsToCehrBertConversionType[MedsToBertMimic4.__name__]] = (
dataclasses.field(
default=MedsToCehrBertConversionType[MedsToBertMimic4.__name__],
metadata={
"help": "The MEDS to CEHRBERT conversion type e.g. MedsToBertMimic4",
"choices": f"choices={[e for e in MedsToCehrBertConversionType.__members__]}",
},
)
meds_to_cehrbert_conversion_type: Literal[
MedsToCehrBertConversionType[MedsToBertMimic4.__name__, MedsToCehrbertOMOP.__name__]
] = dataclasses.field(
default=MedsToCehrBertConversionType[MedsToBertMimic4.__name__],
metadata={
"help": "The MEDS to CEHRBERT conversion type e.g. MedsToBertMimic4",
"choices": f"choices={[e for e in MedsToCehrBertConversionType.__members__]}",
},
)
include_auxiliary_token: Optional[bool] = dataclasses.field(
default=False,
Expand Down

0 comments on commit 6a89ae9

Please sign in to comment.