Skip to content

Commit

Permalink
Fixed a bug where MedsToCehrBertConversionType is compared to infer t…
Browse files Browse the repository at this point in the history
…he corresponding MedsToCehrBertConversion
  • Loading branch information
ChaoPang committed Sep 8, 2024
1 parent c08d70e commit b0fa68c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
10 changes: 7 additions & 3 deletions src/cehrbert/data_generators/hf_data_generator/meds_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@


def get_meds_to_cehrbert_conversion_cls(
meds_to_cehrbert_conversion_type: MedsToCehrBertConversionType,
meds_to_cehrbert_conversion_type: Union[MedsToCehrBertConversionType, str],
) -> MedsToCehrBertConversion:
for cls in MedsToCehrBertConversion.__subclasses__():
if meds_to_cehrbert_conversion_type.name == cls.__name__:
return cls()
if isinstance(meds_to_cehrbert_conversion_type, MedsToCehrBertConversionType):
if meds_to_cehrbert_conversion_type.name == cls.__name__:
return cls()
elif isinstance(meds_to_cehrbert_conversion_type, str):
if meds_to_cehrbert_conversion_type == cls.__name__:
return cls()
raise RuntimeError(f"{meds_to_cehrbert_conversion_type} is not a valid MedsToCehrBertConversionType")


Expand Down
16 changes: 9 additions & 7 deletions src/cehrbert/runners/hf_runner_argument_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from cehrbert_data.decorators.patient_event_decorator import AttType

from ..data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import (
from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import (
MedsToBertMimic4,
MedsToCehrBertConversion,
)
Expand Down Expand Up @@ -111,12 +111,14 @@ class DataTrainingArguments:
)
# TODO: Python 3.9/10 do not support dynamic unpacking, we have to manually provide the entire
# list right now.
meds_to_cehrbert_conversion_type: Literal[MedsToBertMimic4.__name__] = dataclasses.field(
default=MedsToBertMimic4,
metadata={
"help": "The MEDS to CEHRBERT conversion type e.g. MedsToBertMimic4",
"choices": f"choices={[e for e in MedsToCehrBertConversionType.__members__]}",
},
meds_to_cehrbert_conversion_type: Literal[MedsToCehrBertConversionType[MedsToBertMimic4.__name__]] = (
dataclasses.field(
default=MedsToCehrBertConversionType[MedsToBertMimic4.__name__],
metadata={
"help": "The MEDS to CEHRBERT conversion type e.g. MedsToBertMimic4",
"choices": f"choices={[e for e in MedsToCehrBertConversionType.__members__]}",
},
)
)
include_auxiliary_token: Optional[bool] = dataclasses.field(
default=False,
Expand Down

0 comments on commit b0fa68c

Please sign in to comment.