Skip to content

Commit

Permalink
Deleted all Spark-related code and included the new cehrbert_data dependency (#52)
Browse files Browse the repository at this point in the history

* Deleted all Spark-related code and included the new cehrbert_data dependency
* Fixed a bug in how MedsToCehrBertConversionType is compared when inferring the corresponding MedsToCehrBertConversion class; the lookup now accepts either the enum member or its string name
  • Loading branch information
ChaoPang authored Sep 8, 2024
1 parent c8410e0 commit e5724f5
Show file tree
Hide file tree
Showing 61 changed files with 69 additions and 6,597 deletions.
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ dependencies = [
"Pillow==10.3.0",
"pyarrow==15.0.0",
"pydantic==2.6.0",
"pyspark==3.2.2",
"python-dateutil==2.8.2",
"PyYAML==6.0.1",
"scikit-learn==1.4.0",
Expand All @@ -57,7 +56,8 @@ dependencies = [
"transformers==4.39.3",
"Werkzeug==3.0.1",
"wandb==0.17.8",
"xgboost==2.0.3"
"xgboost==2.0.3",
"cehrbert_data==0.0.1"
]

[tool.setuptools_scm]
Expand All @@ -66,8 +66,8 @@ dependencies = [
Homepage = "https://github.com/cumc-dbmi/cehr-bert"

[project.scripts]
cehrbert-pretraining = "cehrbert.runner.hf_cehrbert_pretrain_runner:main"
cehrbert-finetuning = "cehrbert.runner.hf_cehrbert_finetuning_runner:main"
cehrbert-pretraining = "cehrbert.runners.hf_cehrbert_pretrain_runner:main"
cehrbert-finetuning = "cehrbert.runners.hf_cehrbert_finetuning_runner:main"

[project.optional-dependencies]
dev = [
Expand Down
9 changes: 0 additions & 9 deletions src/cehrbert/config/output_names.py

This file was deleted.

Empty file removed src/cehrbert/const/__init__.py
Empty file.
28 changes: 0 additions & 28 deletions src/cehrbert/const/common.py

This file was deleted.

6 changes: 3 additions & 3 deletions src/cehrbert/data_generators/data_generator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import numpy as np
import pandas as pd

from .data_classes import RowSlicer
from .learning_objective import (
from cehrbert.data_generators.data_classes import RowSlicer
from cehrbert.data_generators.learning_objective import (
BertFineTuningLearningObjective,
DemographicsLearningObjective,
HierarchicalArtificialTokenPredictionLearningObjective,
Expand All @@ -24,7 +24,7 @@
TimeAttentionLearningObjective,
VisitPredictionLearningObjective,
)
from .tokenizer import ConceptTokenizer
from cehrbert.data_generators.tokenizer import ConceptTokenizer


def create_indexes_by_time_window(dates, cursor, max_seq_len, time_window_size):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

import numpy as np
import pandas as pd
from cehrbert_data.decorators.patient_event_decorator import get_att_function
from datasets.formatting.formatting import LazyBatch
from dateutil.relativedelta import relativedelta
from meds.schema import birth_code, death_code
from pandas import Series

from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
from cehrbert.spark_apps.decorators.patient_event_decorator import get_att_function

birth_codes = [birth_code, "MEDS_BIRTH"]
death_codes = [death_code, "MEDS_DEATH"]
Expand Down
10 changes: 7 additions & 3 deletions src/cehrbert/data_generators/hf_data_generator/meds_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@


def get_meds_to_cehrbert_conversion_cls(
meds_to_cehrbert_conversion_type: MedsToCehrBertConversionType,
meds_to_cehrbert_conversion_type: Union[MedsToCehrBertConversionType, str],
) -> MedsToCehrBertConversion:
for cls in MedsToCehrBertConversion.__subclasses__():
if meds_to_cehrbert_conversion_type.name == cls.__name__:
return cls()
if isinstance(meds_to_cehrbert_conversion_type, MedsToCehrBertConversionType):
if meds_to_cehrbert_conversion_type.name == cls.__name__:
return cls()
elif isinstance(meds_to_cehrbert_conversion_type, str):
if meds_to_cehrbert_conversion_type == cls.__name__:
return cls()
raise RuntimeError(f"{meds_to_cehrbert_conversion_type} is not a valid MedsToCehrBertConversionType")


Expand Down
3 changes: 1 addition & 2 deletions src/cehrbert/data_generators/tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Optional, Sequence, Union

from cehrbert_data.const.common import UNKNOWN_CONCEPT
from dask.dataframe import Series as dd_series
from pandas import Series as df_series
from tensorflow.keras.preprocessing.text import Tokenizer

from ..const.common import UNKNOWN_CONCEPT


class ConceptTokenizer:
unused_token = "[UNUSED]"
Expand Down
7 changes: 5 additions & 2 deletions src/cehrbert/evaluations/model_evaluators/model_evaluators.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import copy
import os
import pathlib
from abc import abstractmethod

from ...trainers.model_trainer import AbstractModel
from ...utils.model_utils import os, pathlib, tf
import tensorflow as tf

from cehrbert.trainers.model_trainer import AbstractModel


def get_metrics():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit, train_test_split
from tensorflow.python.keras.utils.generic_utils import get_custom_objects

from ...config.grid_search_config import GridSearchConfig
from ...data_generators.learning_objective import post_pad_pre_truncate
from ...models.evaluation_models import create_bi_lstm_model
from ...models.loss_schedulers import CosineLRSchedule
from ...utils.model_utils import compute_binary_metrics, multimode, np, os, pd, pickle, save_training_history, tf
from .model_evaluators import AbstractModelEvaluator, get_metrics
from cehrbert.config.grid_search_config import GridSearchConfig
from cehrbert.data_generators.learning_objective import post_pad_pre_truncate
from cehrbert.evaluations.model_evaluators.model_evaluators import AbstractModelEvaluator, get_metrics
from cehrbert.models.evaluation_models import create_bi_lstm_model
from cehrbert.models.loss_schedulers import CosineLRSchedule
from cehrbert.utils.model_utils import compute_binary_metrics, multimode, np, os, pd, pickle, save_training_history, tf

# Define a list of learning rates to fine-tune the model with
LEARNING_RATES = [0.5e-4, 0.8e-4, 1.0e-4, 1.2e-4]
Expand Down
6 changes: 3 additions & 3 deletions src/cehrbert/models/bert_models.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import tensorflow as tf

from ..keras_transformer.extras import ReusableEmbedding, TiedOutputEmbedding
from ..utils.model_utils import create_concept_mask
from .layers.custom_layers import (
from cehrbert.keras_transformer.extras import ReusableEmbedding, TiedOutputEmbedding
from cehrbert.models.layers.custom_layers import (
ConceptValueTransformationLayer,
Encoder,
PositionalEncodingLayer,
TimeEmbeddingLayer,
VisitEmbeddingLayer,
)
from cehrbert.utils.model_utils import create_concept_mask


def transformer_bert_model(
Expand Down
4 changes: 2 additions & 2 deletions src/cehrbert/models/bert_models_visit_prediction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import tensorflow as tf

from ..keras_transformer.extras import ReusableEmbedding, TiedOutputEmbedding
from .layers.custom_layers import (
from cehrbert.keras_transformer.extras import ReusableEmbedding, TiedOutputEmbedding
from cehrbert.models.layers.custom_layers import (
ConceptValueTransformationLayer,
DecoderLayer,
Encoder,
Expand Down
4 changes: 2 additions & 2 deletions src/cehrbert/models/evaluation_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from tensorflow.keras.initializers import Constant
from tensorflow.keras.models import Model

from .bert_models_visit_prediction import transformer_bert_model_visit_prediction
from .layers.custom_layers import ConvolutionBertLayer, get_custom_objects
from cehrbert.models.bert_models_visit_prediction import transformer_bert_model_visit_prediction
from cehrbert.models.layers.custom_layers import ConvolutionBertLayer, get_custom_objects


def create_bi_lstm_model(
Expand Down
5 changes: 3 additions & 2 deletions src/cehrbert/models/hierachical_bert_model_v2.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .layers.custom_layers import (
import tensorflow as tf

from cehrbert.models.layers.custom_layers import (
ConceptValueTransformationLayer,
Encoder,
ReusableEmbedding,
SimpleDecoderLayer,
TemporalTransformationLayer,
TiedOutputEmbedding,
tf,
)


Expand Down
7 changes: 4 additions & 3 deletions src/cehrbert/models/hierachical_phenotype_model_new.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from .hierachical_bert_model_v2 import create_att_concept_mask
from .layers.custom_layers import (
import tensorflow as tf

from cehrbert.models.hierachical_bert_model_v2 import create_att_concept_mask
from cehrbert.models.layers.custom_layers import (
ConceptValueTransformationLayer,
Encoder,
ReusableEmbedding,
SimpleDecoderLayer,
TemporalTransformationLayer,
TiedOutputEmbedding,
VisitPhenotypeLayer,
tf,
)


Expand Down
2 changes: 1 addition & 1 deletion src/cehrbert/models/parse_args.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
from sys import argv

from ..data_generators.graph_sample_method import SimilarityType
from cehrbert.data_generators.graph_sample_method import SimilarityType


def create_parse_args():
Expand Down
19 changes: 11 additions & 8 deletions src/cehrbert/runners/hf_runner_argument_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Optional

from ..data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import (
from cehrbert_data.decorators.patient_event_decorator import AttType

from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import (
MedsToBertMimic4,
MedsToCehrBertConversion,
)
from ..spark_apps.decorators.patient_event_decorator import AttType

# Create an enum dynamically from the list
MedsToCehrBertConversionType = Enum(
Expand Down Expand Up @@ -110,12 +111,14 @@ class DataTrainingArguments:
)
# TODO: Python 3.9/3.10 do not support dynamic unpacking, so we have to manually provide the
# entire list for now.
meds_to_cehrbert_conversion_type: Literal[MedsToBertMimic4.__name__] = dataclasses.field(
default=MedsToBertMimic4,
metadata={
"help": "The MEDS to CEHRBERT conversion type e.g. MedsToBertMimic4",
"choices": f"choices={[e for e in MedsToCehrBertConversionType.__members__]}",
},
meds_to_cehrbert_conversion_type: Literal[MedsToCehrBertConversionType[MedsToBertMimic4.__name__]] = (
dataclasses.field(
default=MedsToCehrBertConversionType[MedsToBertMimic4.__name__],
metadata={
"help": "The MEDS to CEHRBERT conversion type e.g. MedsToBertMimic4",
"choices": f"choices={[e for e in MedsToCehrBertConversionType.__members__]}",
},
)
)
include_auxiliary_token: Optional[bool] = dataclasses.field(
default=False,
Expand Down
2 changes: 1 addition & 1 deletion src/cehrbert/runners/runner_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import logging

from .hf_runner_argument_dataclass import DataTrainingArguments, ModelArguments
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments, ModelArguments

LOG = logging.get_logger("transformers")

Expand Down
Empty file.
Empty file.
44 changes: 0 additions & 44 deletions src/cehrbert/spark_apps/cohorts/atrial_fibrillation.py

This file was deleted.

Loading

0 comments on commit e5724f5

Please sign in to comment.