diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_base.py b/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_base.py index b6738b39..e6579b13 100644 --- a/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_base.py +++ b/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_base.py @@ -3,10 +3,6 @@ from dataclasses import dataclass from typing import List, Optional -import meds_reader - -from cehrbert.data_generators.hf_data_generator.meds_utils import PatientBlock, PatientDemographics - @dataclass class EventConversionRule: diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_omop.py b/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_omop.py index e50ce436..1a94b261 100644 --- a/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_omop.py +++ b/src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_omop.py @@ -1,8 +1,8 @@ from typing import List -from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import MedsToCehrBertConversion from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import ( EventConversionRule, + MedsToCehrBertConversion, ) diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py index 356fc59c..64880cce 100644 --- a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py +++ b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py @@ -13,9 +13,7 @@ from cehrbert.data_generators.hf_data_generator.hf_dataset import apply_cehrbert_dataset_mapping from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import MedsToCehrBertConversion -from cehrbert.data_generators.hf_data_generator.patient_block import ( - get_func_for_generate_demographics_and_patient_blocks, -) +from cehrbert.data_generators.hf_data_generator.patient_block import generate_demographics_and_patient_blocks from cehrbert.med_extension.schema_extension import CehrBertPatient, Visit from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments, MedsToCehrBertConversionType diff --git a/src/cehrbert/data_generators/hf_data_generator/patient_block.py b/src/cehrbert/data_generators/hf_data_generator/patient_block.py index a22b03e3..217e6251 100644 --- a/src/cehrbert/data_generators/hf_data_generator/patient_block.py +++ b/src/cehrbert/data_generators/hf_data_generator/patient_block.py @@ -17,7 +17,6 @@ MedsToCehrBertConversion, MedsToCehrbertOMOP, ) -from cehrbert.data_generators.hf_data_generator.meds_utils import PatientBlock, PatientDemographics from cehrbert.med_extension.schema_extension import Event diff --git a/src/cehrbert/runners/hf_runner_argument_dataclass.py b/src/cehrbert/runners/hf_runner_argument_dataclass.py index 025d8f84..5eb3165e 100644 --- a/src/cehrbert/runners/hf_runner_argument_dataclass.py +++ b/src/cehrbert/runners/hf_runner_argument_dataclass.py @@ -113,7 +113,8 @@ class DataTrainingArguments: # TODO: Python 3.9/10 do not support dynamic unpacking, we have to manually provide the entire # list right now. meds_to_cehrbert_conversion_type: Literal[ - MedsToCehrBertConversionType[MedsToBertMimic4.__name__, MedsToCehrbertOMOP.__name__] + MedsToCehrBertConversionType[MedsToBertMimic4.__name__], + MedsToCehrBertConversionType[MedsToCehrbertOMOP.__name__], ] = dataclasses.field( default=MedsToCehrBertConversionType[MedsToBertMimic4.__name__], metadata={ diff --git a/tests/unit_tests/data_generators/hf_data_generator/meds_utils_test.py b/tests/unit_tests/data_generators/hf_data_generator/meds_utils_test.py index 89379ef0..ed7135d6 100644 --- a/tests/unit_tests/data_generators/hf_data_generator/meds_utils_test.py +++ b/tests/unit_tests/data_generators/hf_data_generator/meds_utils_test.py @@ -1,15 +1,23 @@ import unittest -from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import MedsToBertMimic4 +from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import ( + MedsToBertMimic4, + MedsToCehrbertOMOP, +) from cehrbert.data_generators.hf_data_generator.meds_utils import get_meds_to_cehrbert_conversion_cls from cehrbert.runners.hf_runner_argument_dataclass import AttType, MedsToCehrBertConversionType class TestGetMedsToCehrBertConversionCls(unittest.TestCase): - def test_conversion(self): - conversion_type = MedsToCehrBertConversionType["MedsToBertMimic4"] + def test_meds_to_bert_omop_conversion(self): + conversion_type = MedsToCehrBertConversionType["MedsToCehrbertOMOP"] result = get_meds_to_cehrbert_conversion_cls(conversion_type) + self.assertIsInstance(result, MedsToCehrbertOMOP) + + def test_meds_to_bert_mimic4_conversion(self): + conversion_type = MedsToCehrBertConversionType["MedsToBertMimic4"] + result = get_meds_to_cehrbert_conversion_cls(conversion_type, default_visit_id=1) self.assertIsInstance(result, MedsToBertMimic4) def test_invalid_conversion(self):