diff --git a/pyproject.toml b/pyproject.toml index e687ea4b..42454c1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,6 @@ dependencies = [ "Pillow==10.3.0", "pyarrow==15.0.0", "pydantic==2.6.0", - "pyspark==3.2.2", "python-dateutil==2.8.2", "PyYAML==6.0.1", "scikit-learn==1.4.0", @@ -57,7 +56,8 @@ dependencies = [ "transformers==4.39.3", "Werkzeug==3.0.1", "wandb==0.17.8", - "xgboost==2.0.3" + "xgboost==2.0.3", + "cehrbert_data==0.0.1" ] [tool.setuptools_scm] @@ -66,8 +66,8 @@ dependencies = [ Homepage = "https://github.com/cumc-dbmi/cehr-bert" [project.scripts] -cehrbert-pretraining = "cehrbert.runner.hf_cehrbert_pretrain_runner:main" -cehrbert-finetuning = "cehrbert.runner.hf_cehrbert_finetuning_runner:main" +cehrbert-pretraining = "cehrbert.runners.hf_cehrbert_pretrain_runner:main" +cehrbert-finetuning = "cehrbert.runners.hf_cehrbert_finetuning_runner:main" [project.optional-dependencies] dev = [ diff --git a/src/cehrbert/config/output_names.py b/src/cehrbert/config/output_names.py deleted file mode 100644 index fe663dbe..00000000 --- a/src/cehrbert/config/output_names.py +++ /dev/null @@ -1,9 +0,0 @@ -PARQUET_DATA_PATH = "patient_sequence" -QUALIFIED_CONCEPT_LIST_PATH = "qualified_concept_list" -TIME_ATTENTION_MODEL_PATH = "time_aware_model.h5" -BERT_MODEL_VALIDATION_PATH = "bert_model.h5" -MORTALITY_DATA_PATH = "mortality" -HEART_FAILURE_DATA_PATH = "heart_failure" -HOSPITALIZATION_DATA_PATH = "hospitalization" -INFORMATION_CONTENT_DATA_PATH = "information_content" -CONCEPT_SIMILARITY_PATH = "concept_similarity" diff --git a/src/cehrbert/const/__init__.py b/src/cehrbert/const/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/cehrbert/const/common.py b/src/cehrbert/const/common.py deleted file mode 100644 index 2f9d12ec..00000000 --- a/src/cehrbert/const/common.py +++ /dev/null @@ -1,28 +0,0 @@ -PERSON = "person" -VISIT_OCCURRENCE = "visit_occurrence" -CONDITION_OCCURRENCE = "condition_occurrence" -PROCEDURE_OCCURRENCE = "procedure_occurrence" -DRUG_EXPOSURE = "drug_exposure" -DEVICE_EXPOSURE = "device_exposure" -OBSERVATION = "observation" -MEASUREMENT = "measurement" -CATEGORICAL_MEASUREMENT = "categorical_measurement" -OBSERVATION_PERIOD = "observation_period" -DEATH = "death" -CDM_TABLES = [ - PERSON, - VISIT_OCCURRENCE, - CONDITION_OCCURRENCE, - PROCEDURE_OCCURRENCE, - DRUG_EXPOSURE, - DEVICE_EXPOSURE, - OBSERVATION, - MEASUREMENT, - CATEGORICAL_MEASUREMENT, - OBSERVATION_PERIOD, - DEATH, -] -REQUIRED_MEASUREMENT = "required_measurement" -UNKNOWN_CONCEPT = "[UNKNOWN]" -CONCEPT = "concept" -CONCEPT_ANCESTOR = "concept_ancestor" diff --git a/src/cehrbert/data_generators/data_generator_base.py b/src/cehrbert/data_generators/data_generator_base.py index c91ea86e..27af9c6e 100644 --- a/src/cehrbert/data_generators/data_generator_base.py +++ b/src/cehrbert/data_generators/data_generator_base.py @@ -9,8 +9,8 @@ import numpy as np import pandas as pd -from .data_classes import RowSlicer -from .learning_objective import ( +from cehrbert.data_generators.data_classes import RowSlicer +from cehrbert.data_generators.learning_objective import ( BertFineTuningLearningObjective, DemographicsLearningObjective, HierarchicalArtificialTokenPredictionLearningObjective, @@ -24,7 +24,7 @@ TimeAttentionLearningObjective, VisitPredictionLearningObjective, ) -from .tokenizer import ConceptTokenizer +from cehrbert.data_generators.tokenizer import ConceptTokenizer def create_indexes_by_time_window(dates, cursor, max_seq_len, time_window_size): diff --git 
a/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py b/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py index fd509025..2a413bb6 100644 --- a/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py +++ b/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd +from cehrbert_data.decorators.patient_event_decorator import get_att_function from datasets.formatting.formatting import LazyBatch from dateutil.relativedelta import relativedelta from meds.schema import birth_code, death_code @@ -15,7 +16,6 @@ from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments -from cehrbert.spark_apps.decorators.patient_event_decorator import get_att_function birth_codes = [birth_code, "MEDS_BIRTH"] death_codes = [death_code, "MEDS_DEATH"] diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py index 0fe9ecc3..ac2654e2 100644 --- a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py +++ b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py @@ -31,11 +31,15 @@ def get_meds_to_cehrbert_conversion_cls( - meds_to_cehrbert_conversion_type: MedsToCehrBertConversionType, + meds_to_cehrbert_conversion_type: Union[MedsToCehrBertConversionType, str], ) -> MedsToCehrBertConversion: for cls in MedsToCehrBertConversion.__subclasses__(): - if meds_to_cehrbert_conversion_type.name == cls.__name__: - return cls() + if isinstance(meds_to_cehrbert_conversion_type, MedsToCehrBertConversionType): + if meds_to_cehrbert_conversion_type.name == cls.__name__: + return cls() + elif isinstance(meds_to_cehrbert_conversion_type, str): + if meds_to_cehrbert_conversion_type == cls.__name__: + return cls() raise RuntimeError(f"{meds_to_cehrbert_conversion_type} is not a valid MedsToCehrBertConversionType") diff --git a/src/cehrbert/data_generators/tokenizer.py b/src/cehrbert/data_generators/tokenizer.py index 487b1b75..19cb6919 100644 --- a/src/cehrbert/data_generators/tokenizer.py +++ b/src/cehrbert/data_generators/tokenizer.py @@ -1,11 +1,10 @@ from typing import Optional, Sequence, Union +from cehrbert_data.const.common import UNKNOWN_CONCEPT from dask.dataframe import Series as dd_series from pandas import Series as df_series from tensorflow.keras.preprocessing.text import Tokenizer -from ..const.common import UNKNOWN_CONCEPT - class ConceptTokenizer: unused_token = "[UNUSED]" diff --git a/src/cehrbert/evaluations/model_evaluators/model_evaluators.py b/src/cehrbert/evaluations/model_evaluators/model_evaluators.py index e9260c39..6c7758ac 100644 --- a/src/cehrbert/evaluations/model_evaluators/model_evaluators.py +++ b/src/cehrbert/evaluations/model_evaluators/model_evaluators.py @@ -1,8 +1,11 @@ import copy +import os +import pathlib from abc import abstractmethod -from ...trainers.model_trainer import AbstractModel -from ...utils.model_utils import os, pathlib, tf +import tensorflow as tf + +from cehrbert.trainers.model_trainer import AbstractModel def get_metrics(): diff --git a/src/cehrbert/evaluations/model_evaluators/sequence_model_evaluators.py b/src/cehrbert/evaluations/model_evaluators/sequence_model_evaluators.py index c6aa9967..6eae8c30 100644 --- a/src/cehrbert/evaluations/model_evaluators/sequence_model_evaluators.py +++ b/src/cehrbert/evaluations/model_evaluators/sequence_model_evaluators.py @@ -6,12 +6,12 
@@ from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit, train_test_split from tensorflow.python.keras.utils.generic_utils import get_custom_objects -from ...config.grid_search_config import GridSearchConfig -from ...data_generators.learning_objective import post_pad_pre_truncate -from ...models.evaluation_models import create_bi_lstm_model -from ...models.loss_schedulers import CosineLRSchedule -from ...utils.model_utils import compute_binary_metrics, multimode, np, os, pd, pickle, save_training_history, tf -from .model_evaluators import AbstractModelEvaluator, get_metrics +from cehrbert.config.grid_search_config import GridSearchConfig +from cehrbert.data_generators.learning_objective import post_pad_pre_truncate +from cehrbert.evaluations.model_evaluators.model_evaluators import AbstractModelEvaluator, get_metrics +from cehrbert.models.evaluation_models import create_bi_lstm_model +from cehrbert.models.loss_schedulers import CosineLRSchedule +from cehrbert.utils.model_utils import compute_binary_metrics, multimode, np, os, pd, pickle, save_training_history, tf # Define a list of learning rates to fine-tune the model with LEARNING_RATES = [0.5e-4, 0.8e-4, 1.0e-4, 1.2e-4] diff --git a/src/cehrbert/models/bert_models.py b/src/cehrbert/models/bert_models.py index 2df4e126..75620e45 100644 --- a/src/cehrbert/models/bert_models.py +++ b/src/cehrbert/models/bert_models.py @@ -1,14 +1,14 @@ import tensorflow as tf -from ..keras_transformer.extras import ReusableEmbedding, TiedOutputEmbedding -from ..utils.model_utils import create_concept_mask -from .layers.custom_layers import ( +from cehrbert.keras_transformer.extras import ReusableEmbedding, TiedOutputEmbedding +from cehrbert.models.layers.custom_layers import ( ConceptValueTransformationLayer, Encoder, PositionalEncodingLayer, TimeEmbeddingLayer, VisitEmbeddingLayer, ) +from cehrbert.utils.model_utils import create_concept_mask def transformer_bert_model( diff --git a/src/cehrbert/models/bert_models_visit_prediction.py b/src/cehrbert/models/bert_models_visit_prediction.py index be252ad8..82d7804f 100644 --- a/src/cehrbert/models/bert_models_visit_prediction.py +++ b/src/cehrbert/models/bert_models_visit_prediction.py @@ -1,7 +1,7 @@ import tensorflow as tf -from ..keras_transformer.extras import ReusableEmbedding, TiedOutputEmbedding -from .layers.custom_layers import ( +from cehrbert.keras_transformer.extras import ReusableEmbedding, TiedOutputEmbedding +from cehrbert.models.layers.custom_layers import ( ConceptValueTransformationLayer, DecoderLayer, Encoder, diff --git a/src/cehrbert/models/evaluation_models.py b/src/cehrbert/models/evaluation_models.py index 897aac97..18f230c2 100644 --- a/src/cehrbert/models/evaluation_models.py +++ b/src/cehrbert/models/evaluation_models.py @@ -2,8 +2,8 @@ from tensorflow.keras.initializers import Constant from tensorflow.keras.models import Model -from .bert_models_visit_prediction import transformer_bert_model_visit_prediction -from .layers.custom_layers import ConvolutionBertLayer, get_custom_objects +from cehrbert.models.bert_models_visit_prediction import transformer_bert_model_visit_prediction +from cehrbert.models.layers.custom_layers import ConvolutionBertLayer, get_custom_objects def create_bi_lstm_model( diff --git a/src/cehrbert/models/hierachical_bert_model_v2.py b/src/cehrbert/models/hierachical_bert_model_v2.py index 9fd87f3a..ee17e53a 100644 --- a/src/cehrbert/models/hierachical_bert_model_v2.py +++ b/src/cehrbert/models/hierachical_bert_model_v2.py @@ -1,11 +1,12 
@@ -from .layers.custom_layers import ( +import tensorflow as tf + +from cehrbert.models.layers.custom_layers import ( ConceptValueTransformationLayer, Encoder, ReusableEmbedding, SimpleDecoderLayer, TemporalTransformationLayer, TiedOutputEmbedding, - tf, ) diff --git a/src/cehrbert/models/hierachical_phenotype_model_new.py b/src/cehrbert/models/hierachical_phenotype_model_new.py index df73b197..585dca79 100644 --- a/src/cehrbert/models/hierachical_phenotype_model_new.py +++ b/src/cehrbert/models/hierachical_phenotype_model_new.py @@ -1,5 +1,7 @@ -from .hierachical_bert_model_v2 import create_att_concept_mask -from .layers.custom_layers import ( +import tensorflow as tf + +from cehrbert.models.hierachical_bert_model_v2 import create_att_concept_mask +from cehrbert.models.layers.custom_layers import ( ConceptValueTransformationLayer, Encoder, ReusableEmbedding, @@ -7,7 +9,6 @@ TemporalTransformationLayer, TiedOutputEmbedding, VisitPhenotypeLayer, - tf, ) diff --git a/src/cehrbert/models/parse_args.py b/src/cehrbert/models/parse_args.py index 64ad85d7..130650c9 100644 --- a/src/cehrbert/models/parse_args.py +++ b/src/cehrbert/models/parse_args.py @@ -1,7 +1,7 @@ import argparse from sys import argv -from ..data_generators.graph_sample_method import SimilarityType +from cehrbert.data_generators.graph_sample_method import SimilarityType def create_parse_args(): diff --git a/src/cehrbert/runners/hf_runner_argument_dataclass.py b/src/cehrbert/runners/hf_runner_argument_dataclass.py index b78c6da9..055fe411 100644 --- a/src/cehrbert/runners/hf_runner_argument_dataclass.py +++ b/src/cehrbert/runners/hf_runner_argument_dataclass.py @@ -2,11 +2,12 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional -from ..data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import ( +from cehrbert_data.decorators.patient_event_decorator import AttType + +from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import ( MedsToBertMimic4, MedsToCehrBertConversion, ) -from ..spark_apps.decorators.patient_event_decorator import AttType # Create an enum dynamically from the list MedsToCehrBertConversionType = Enum( @@ -110,12 +111,14 @@ class DataTrainingArguments: ) # TODO: Python 3.9/10 do not support dynamic unpacking, we have to manually provide the entire # list right now. - meds_to_cehrbert_conversion_type: Literal[MedsToBertMimic4.__name__] = dataclasses.field( - default=MedsToBertMimic4, - metadata={ - "help": "The MEDS to CEHRBERT conversion type e.g. MedsToBertMimic4", - "choices": f"choices={[e for e in MedsToCehrBertConversionType.__members__]}", - }, + meds_to_cehrbert_conversion_type: Literal[MedsToCehrBertConversionType[MedsToBertMimic4.__name__]] = ( + dataclasses.field( + default=MedsToCehrBertConversionType[MedsToBertMimic4.__name__], + metadata={ + "help": "The MEDS to CEHRBERT conversion type e.g. 
MedsToBertMimic4", + "choices": f"choices={[e for e in MedsToCehrBertConversionType.__members__]}", + }, + ) ) include_auxiliary_token: Optional[bool] = dataclasses.field( default=False, diff --git a/src/cehrbert/runners/runner_util.py b/src/cehrbert/runners/runner_util.py index 63876ca9..e6cf0f5e 100644 --- a/src/cehrbert/runners/runner_util.py +++ b/src/cehrbert/runners/runner_util.py @@ -13,7 +13,7 @@ from transformers.trainer_utils import get_last_checkpoint from transformers.utils import logging -from .hf_runner_argument_dataclass import DataTrainingArguments, ModelArguments +from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments, ModelArguments LOG = logging.get_logger("transformers") diff --git a/src/cehrbert/spark_apps/__init__.py b/src/cehrbert/spark_apps/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/cehrbert/spark_apps/cohorts/__init__.py b/src/cehrbert/spark_apps/cohorts/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/cehrbert/spark_apps/cohorts/atrial_fibrillation.py b/src/cehrbert/spark_apps/cohorts/atrial_fibrillation.py deleted file mode 100644 index 76395f2c..00000000 --- a/src/cehrbert/spark_apps/cohorts/atrial_fibrillation.py +++ /dev/null @@ -1,44 +0,0 @@ -from ..cohorts.query_builder import AncestorTableSpec, QueryBuilder, QuerySpec - -COHORT_QUERY_TEMPLATE = """ -SELECT - co.person_id, - FIRST(DATE(vo.visit_start_date)) OVER (PARTITION BY co.person_id - ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS index_date, - FIRST(vo.visit_occurrence_id) OVER (PARTITION BY co.person_id - ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS visit_occurrence_id -FROM global_temp.condition_occurrence AS co -JOIN global_temp.visit_occurrence AS vo - ON co.visit_occurrence_id = vo.visit_occurrence_id -JOIN global_temp.{atrial_fibrillation_concepts} AS c - ON co.condition_concept_id = c.concept_id -""" - -ATRIAL_FIBRILLATION_CONCEPT_ID = [313217] - -DEPENDENCY_LIST = ["person", "condition_occurrence", "visit_occurrence"] - -DEFAULT_COHORT_NAME = "atrial_fibrillation" -ATRIAL_FIBRILLATION_CONCEPTS = "atrial_fibrillation_concepts" - - -def query_builder(): - query = QuerySpec( - table_name=DEFAULT_COHORT_NAME, - query_template=COHORT_QUERY_TEMPLATE, - parameters={"atrial_fibrillation_concepts": ATRIAL_FIBRILLATION_CONCEPTS}, - ) - - ancestor_table_specs = [ - AncestorTableSpec( - table_name=ATRIAL_FIBRILLATION_CONCEPTS, - ancestor_concept_ids=ATRIAL_FIBRILLATION_CONCEPT_ID, - is_standard=True, - ) - ] - return QueryBuilder( - cohort_name=DEFAULT_COHORT_NAME, - dependency_list=DEPENDENCY_LIST, - query=query, - ancestor_table_specs=ancestor_table_specs, - ) diff --git a/src/cehrbert/spark_apps/cohorts/cabg.py b/src/cehrbert/spark_apps/cohorts/cabg.py deleted file mode 100644 index 8f4c7051..00000000 --- a/src/cehrbert/spark_apps/cohorts/cabg.py +++ /dev/null @@ -1,71 +0,0 @@ -from ..cohorts.query_builder import AncestorTableSpec, QueryBuilder, QuerySpec - -COHORT_QUERY_TEMPLATE = """ -SELECT DISTINCT - c.person_id, - c.index_date, - c.visit_occurrence_id -FROM -( - SELECT DISTINCT - vo.person_id, - FIRST(DATE(vo.visit_start_date)) OVER (PARTITION BY po.person_id - ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS index_date, - FIRST(vo.visit_occurrence_id) OVER (PARTITION BY po.person_id - ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS visit_occurrence_id - FROM global_temp.procedure_occurrence AS po - JOIN global_temp.visit_occurrence AS vo - ON 
po.visit_occurrence_id = vo.visit_occurrence_id - WHERE EXISTS ( - SELECT 1 - FROM global_temp.{cabg_concept_table} AS ie - WHERE po.procedure_concept_id = ie.concept_id - ) -) c -WHERE c.index_date >= '{date_lower_bound}' -""" - -DEFAULT_COHORT_NAME = "cabg" -DEPENDENCY_LIST = ["person", "procedure_occurrence", "visit_occurrence"] -CABG_INCLUSION_TABLE = "CABG" -CABG_CONCEPTS = [ - 43528001, - 43528003, - 43528004, - 43528002, - 4305852, - 4168831, - 2107250, - 2107216, - 2107222, - 2107231, - 4336464, - 4231998, - 4284104, - 2100873, -] - - -def query_builder(spark_args): - query = QuerySpec( - table_name=DEFAULT_COHORT_NAME, - query_template=COHORT_QUERY_TEMPLATE, - parameters={ - "cabg_concept_table": CABG_INCLUSION_TABLE, - "date_lower_bound": spark_args.date_lower_bound, - }, - ) - - ancestor_table_specs = [ - AncestorTableSpec( - table_name=CABG_INCLUSION_TABLE, - ancestor_concept_ids=CABG_CONCEPTS, - is_standard=True, - ) - ] - return QueryBuilder( - cohort_name=DEFAULT_COHORT_NAME, - dependency_list=DEPENDENCY_LIST, - query=query, - ancestor_table_specs=ancestor_table_specs, - ) diff --git a/src/cehrbert/spark_apps/cohorts/coronary_artery_disease.py b/src/cehrbert/spark_apps/cohorts/coronary_artery_disease.py deleted file mode 100644 index a8f7539f..00000000 --- a/src/cehrbert/spark_apps/cohorts/coronary_artery_disease.py +++ /dev/null @@ -1,88 +0,0 @@ -from ..cohorts.query_builder import AncestorTableSpec, QueryBuilder, QuerySpec - -COHORT_QUERY_TEMPLATE = """ -WITH prior_graft_stent AS ( - SELECT - po.person_id, - po.procedure_date - FROM global_temp.procedure_occurrence AS po - WHERE EXISTS ( - SELECT 1 - FROM global_temp.{graft_stent_table} AS gs - WHERE po.procedure_concept_id = gs.concept_id - ) -) -SELECT DISTINCT - c.person_id, - c.index_date, - c.visit_occurrence_id -FROM -( - SELECT DISTINCT - vo.person_id, - FIRST(DATE(vo.visit_start_date)) OVER (PARTITION BY co.person_id - ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS index_date, - FIRST(vo.visit_occurrence_id) OVER (PARTITION BY co.person_id - ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS visit_occurrence_id - FROM global_temp.condition_occurrence AS co - JOIN global_temp.visit_occurrence AS vo - ON co.visit_occurrence_id = vo.visit_occurrence_id - WHERE EXISTS ( - SELECT 1 - FROM global_temp.{cad_concept_table} AS ie - WHERE co.condition_concept_id = ie.concept_id - ) -) c -WHERE NOT EXISTS ( - -- The patients who had a graft or stent procedures before the index date - -- need to be removed from the cohort - SELECT 1 - FROM prior_graft_stent AS exclusion - WHERE exclusion.person_id = c.person_id - AND c.index_date > exclusion.procedure_date -) AND c.index_date >= '{date_lower_bound}' -""" - -DEFAULT_COHORT_NAME = "coronary_artery_disease" -DEPENDENCY_LIST = [ - "person", - "condition_occurrence", - "procedure_occurrence", - "visit_occurrence", -] -CAD_INCLUSION_TABLE = "CAD" -CAD_CONCEPTS = [317576] - -PRIOR_PROCEDURE_TABLE = "graft_stent" -PRIOR_PROCEDURES = [4296227, 42537730, 762043, 44782770, 42537729] - - -def query_builder(spark_args): - query = QuerySpec( - table_name=DEFAULT_COHORT_NAME, - query_template=COHORT_QUERY_TEMPLATE, - parameters={ - "cad_concept_table": CAD_INCLUSION_TABLE, - "graft_stent_table": PRIOR_PROCEDURE_TABLE, - "date_lower_bound": spark_args.date_lower_bound, - }, - ) - - ancestor_table_specs = [ - AncestorTableSpec( - table_name=CAD_INCLUSION_TABLE, - ancestor_concept_ids=CAD_CONCEPTS, - is_standard=True, - ), - AncestorTableSpec( - 
table_name=PRIOR_PROCEDURE_TABLE, - ancestor_concept_ids=PRIOR_PROCEDURES, - is_standard=True, - ), - ] - return QueryBuilder( - cohort_name=DEFAULT_COHORT_NAME, - dependency_list=DEPENDENCY_LIST, - query=query, - ancestor_table_specs=ancestor_table_specs, - ) diff --git a/src/cehrbert/spark_apps/cohorts/covid.py b/src/cehrbert/spark_apps/cohorts/covid.py deleted file mode 100644 index 89c7c8be..00000000 --- a/src/cehrbert/spark_apps/cohorts/covid.py +++ /dev/null @@ -1,42 +0,0 @@ -from ..cohorts.query_builder import QueryBuilder, QuerySpec - -COVID_COHORT_QUERY = """ -SELECT DISTINCT - c.person_id, - DATE_ADD(FIRST(index_date) OVER (PARTITION BY person_id ORDER BY index_date, visit_occurrence_id), 1) AS index_date, - FIRST(visit_occurrence_id) OVER (PARTITION BY person_id ORDER BY index_date, visit_occurrence_id) AS visit_occurrence_id -FROM -( - SELECT DISTINCT - m.person_id, - FIRST(visit_start_date) OVER (PARTITION BY v.person_id ORDER BY visit_start_date, v.visit_occurrence_id) AS index_date, - FIRST(v.visit_occurrence_id) OVER (PARTITION BY v.person_id ORDER BY visit_start_date, v.visit_occurrence_id) AS visit_occurrence_id - FROM global_temp.measurement AS m - JOIN global_temp.visit_occurrence AS v - ON m.visit_occurrence_id = v.visit_occurrence_id - JOIN global_temp.concept AS c - ON m.value_as_concept_id = c.concept_id - WHERE m.measurement_concept_id IN (723475,723479,706178,723473,723474,586515,706177,706163,706180,706181) - AND c.concept_name IN ('Detected', 'Positve') - - UNION - - SELECT - co.person_id, - FIRST(visit_start_date) OVER (PARTITION BY v.person_id ORDER BY visit_start_date, v.visit_occurrence_id) AS index_date, - FIRST(v.visit_occurrence_id) OVER (PARTITION BY v.person_id ORDER BY visit_start_date, v.visit_occurrence_id) AS visit_occurrence_id - FROM global_temp.condition_occurrence AS co - JOIN global_temp.visit_occurrence AS v - ON co.visit_occurrence_id = v.visit_occurrence_id - WHERE co.condition_concept_id = 37311061 -) c -""" - -DEFAULT_COHORT_NAME = "covid19" -DEPENDENCY_LIST = ["person", "visit_occurrence", "measurement", "condition_occurrence"] - - -def query_builder(): - query = QuerySpec(table_name=DEFAULT_COHORT_NAME, query_template=COVID_COHORT_QUERY, parameters={}) - - return QueryBuilder(cohort_name=DEFAULT_COHORT_NAME, dependency_list=DEPENDENCY_LIST, query=query) diff --git a/src/cehrbert/spark_apps/cohorts/covid_inpatient.py b/src/cehrbert/spark_apps/cohorts/covid_inpatient.py deleted file mode 100644 index 33ee1eca..00000000 --- a/src/cehrbert/spark_apps/cohorts/covid_inpatient.py +++ /dev/null @@ -1,83 +0,0 @@ -from ..cohorts.query_builder import QueryBuilder, QuerySpec - -COVID_COHORT_QUERY = """ -WITH covid_positive AS -( - - SELECT DISTINCT - ROW_NUMBER() OVER(ORDER BY c.person_id, c.index_date) AS test_row_number, - c.* - FROM - ( - SELECT DISTINCT - m.person_id, - COALESCE(vo.visit_start_datetime, m.measurement_datetime) AS index_date, - vo.visit_occurrence_id, - vo.visit_concept_id - FROM global_temp.measurement AS m - LEFT JOIN global_temp.visit_occurrence AS vo - ON m.visit_occurrence_id = vo.visit_occurrence_id - WHERE measurement_concept_id IN (723475,723479,706178,723473,723474,586515,706177,706163,706180,706181) - AND value_source_value = 'Detected' - - UNION - - SELECT DISTINCT - co.person_id, - COALESCE(vo.visit_start_datetime, co.condition_start_datetime) AS index_date, - vo.visit_occurrence_id, - vo.visit_concept_id - FROM global_temp.condition_occurrence AS co - LEFT JOIN global_temp.visit_occurrence AS vo - ON 
co.visit_occurrence_id = vo.visit_occurrence_id - WHERE condition_concept_id = 37311061 - ) c -), - -covid_test_with_no_visit AS -( - SELECT DISTINCT - c.test_row_number, - c.person_id, - FIRST_VALUE(vo.visit_start_datetime) OVER(PARTITION BY c.person_id ORDER BY vo.visit_start_datetime DESC) AS index_date, - FIRST_VALUE(vo.visit_occurrence_id) OVER(PARTITION BY c.person_id ORDER BY vo.visit_start_datetime DESC) AS visit_occurrence_id, - FIRST_VALUE(vo.visit_concept_id) OVER(PARTITION BY c.person_id ORDER BY vo.visit_start_datetime DESC) AS visit_concept_id - FROM covid_positive AS c - JOIN global_temp.visit_occurrence AS vo - ON c.person_id = vo.person_id AND c.index_date BETWEEN DATE_ADD(vo.visit_start_date, -7) AND vo.visit_start_date - WHERE c.visit_occurrence_id IS NULL -), - -all_covid_tests AS -( - SELECT DISTINCT - c.person_id, - COALESCE(c.index_date, cn.index_date) AS index_date, - COALESCE(c.visit_occurrence_id, cn.visit_occurrence_id) AS visit_occurrence_id, - COALESCE(c.visit_concept_id, cn.visit_concept_id) AS visit_concept_id - FROM covid_positive AS c - LEFT JOIN covid_test_with_no_visit AS cn - ON c.test_row_number = cn.test_row_number -) - -SELECT DISTINCT - person_id, - FIRST_VALUE(vo.index_date) OVER(PARTITION BY vo.person_id ORDER BY vo.index_date) AS index_date, - FIRST_VALUE(vo.visit_occurrence_id) OVER(PARTITION BY vo.person_id ORDER BY vo.index_date) AS visit_occurrence_id -FROM -( - SELECT - co.* - FROM all_covid_tests AS co - WHERE visit_concept_id IN (262, 9203, 9201) -) vo -""" - -DEFAULT_COHORT_NAME = "covid19" -DEPENDENCY_LIST = ["person", "visit_occurrence", "measurement", "condition_occurrence"] - - -def query_builder(): - query = QuerySpec(table_name=DEFAULT_COHORT_NAME, query_template=COVID_COHORT_QUERY, parameters={}) - - return QueryBuilder(cohort_name=DEFAULT_COHORT_NAME, dependency_list=DEPENDENCY_LIST, query=query) diff --git a/src/cehrbert/spark_apps/cohorts/death.py b/src/cehrbert/spark_apps/cohorts/death.py deleted file mode 100644 index 32ab339a..00000000 --- a/src/cehrbert/spark_apps/cohorts/death.py +++ /dev/null @@ -1,45 +0,0 @@ -from ..cohorts.query_builder import QueryBuilder, QuerySpec, create_cohort_entry_query_spec - -DEATH_COHORT_QUERY = """ -WITH max_death_date_cte AS -( - SELECT - person_id, - MAX(death_date) AS death_date - FROM global_temp.death - GROUP BY person_id -), -last_visit_start_date AS -( - SELECT - person_id, - MAX(visit_start_date) AS last_visit_start_date - FROM global_temp.visit_occurrence - GROUP BY person_id -) - -SELECT - d.person_id, - d.death_date AS index_date, - CAST(null AS INT) AS visit_occurrence_id -FROM max_death_date_cte AS d -JOIN last_visit_start_date AS v - ON d.person_id = v.person_id - AND v.last_visit_start_date <= d.death_date -""" - -DEFAULT_COHORT_NAME = "mortality" -DEPENDENCY_LIST = ["person", "death", "visit_occurrence"] - - -def query_builder(): - query = QuerySpec(table_name=DEFAULT_COHORT_NAME, query_template=DEATH_COHORT_QUERY, parameters={}) - - entry_cohort_query = create_cohort_entry_query_spec(entry_query_template=DEATH_COHORT_QUERY, parameters={}) - - return QueryBuilder( - cohort_name=DEFAULT_COHORT_NAME, - dependency_list=DEPENDENCY_LIST, - query=query, - entry_cohort_query=entry_cohort_query, - ) diff --git a/src/cehrbert/spark_apps/cohorts/heart_failure.py b/src/cehrbert/spark_apps/cohorts/heart_failure.py deleted file mode 100644 index 1362edef..00000000 --- a/src/cehrbert/spark_apps/cohorts/heart_failure.py +++ /dev/null @@ -1,422 +0,0 @@ -from ..cohorts.query_builder import ( - 
AncestorTableSpec, - QueryBuilder, - QuerySpec, - create_cohort_entry_query_spec, - create_negative_query_spec, -) - -# 1. Incidens of Heart Failure -HEART_FAILURE_CONCEPT = [316139] - -# 2. At least one new or worsening symptoms due to HF -WORSEN_HF_DIAGNOSIS_CONCEPT = [312437, 4263848, 46272935, 4223659, 315361] - -# 3. At least TWO physical examination findings OR one physical examination finding and at least -# ONE laboratory criterion -PHYSICAL_EXAM_CONCEPT = [433595, 200528, 4117930, 4329988, 4289004, 4285133] -## Lab result concept -# https://labtestsonline.org/tests/bnp-and-nt-probnp -BNP_CONCEPT = [ - 4307029, - 3031569, - 3011960, - 3052295, -] # High B-type Natriuretic Peptide (BNP) > 500 pg/mL -NT_PRO_BNP_CONCEPT = [3029187, 42529224, 3029435, 42529225] -PWP_CONCEPT = [ - 1002721, - 4040920, - 21490776, -] # Pulmonary artery wedge pressure >= 18 no patient in cumc -CVP_CONCEPT = [ - 21490675, - 4323687, - 3000333, - 1003995, -] # Central venous pressure >= 12 no patient in cumc -CI_CONCEPT = 21490712 # Cardiac index < 2.2 no patient in cumc - -# 4. At least ONE of the treatments specifically for HF -DRUG_CONCEPT = [ - 956874, - 942350, - 987406, - 932745, - 1309799, - 970250, - 992590, - 907013, - 1942960, -] -MECHANICAL_CIRCULATORY_SUPPORT_CONCEPT = [ - 45888564, - 4052536, - 4337306, - 2107514, - 45889695, - 2107500, - 45887675, - 43527920, - 2107501, - 45890116, - 40756954, - 4338594, - 43527923, - 40757060, - 2100812, -] -DIALYSIS_CONCEPT = [4032243, 45889365] -ARTIFICIAL_HEART_ASSOCIATED_PROCEDURE_CONCEPT = [ - 4144390, - 4150347, - 4281764, - 725038, - 725037, - 2100816, - 2100822, - 725039, - 2100828, - 4337306, - 4140024, - 4146121, - 4060257, - 4309033, - 4222272, - 4243758, - 4241906, - 4080968, - 4224193, - 4052537, - 4050864, -] - -DIURETIC_CONCEPT_ID = [4186999] - -ROLL_UP_DIURETICS_TO_INGREDIENT_TEMPLATE = """ -SELECT DISTINCT - c.* -FROM global_temp.diuretics_ancestor_table AS a -JOIN global_temp.concept_relationship AS cr - ON a.descendant_concept_id = cr.concept_id_1 AND cr.relationship_id = 'Maps to' -JOIN global_temp.concept_ancestor AS ca - ON cr.concept_id_2 = ca.descendant_concept_id -JOIN global_temp.concept AS c - ON ca.ancestor_concept_id = c.concept_id -WHERE c.concept_class_id = 'Ingredient' -""" - -HEART_FAILURE_ENTRY_COHORT = """ -WITH hf_conditions AS ( -SELECT - * -FROM global_temp.condition_occurrence AS co -JOIN global_temp.{hf_concept} AS hf -ON co.condition_concept_id = hf.concept_id -) - -SELECT - c.person_id, - c.earliest_visit_start_date AS index_date, - c.earliest_visit_occurrence_id AS visit_occurrence_id, - COUNT(c.visit_occurrence_id) OVER(PARTITION BY c.person_id) AS num_of_diagnosis -FROM -( - SELECT DISTINCT - v.person_id, - v.visit_occurrence_id, - first(DATE(c.condition_start_date)) OVER (PARTITION BY v.person_id - ORDER BY DATE(c.condition_start_date)) AS earliest_condition_start_date, - first(DATE(v.visit_start_date)) OVER (PARTITION BY v.person_id - ORDER BY DATE(v.visit_start_date)) AS earliest_visit_start_date, - first(v.visit_occurrence_id) OVER (PARTITION BY v.person_id - ORDER BY DATE(v.visit_start_date)) AS earliest_visit_occurrence_id - FROM global_temp.visit_occurrence AS v - JOIN hf_conditions AS c - ON v.visit_occurrence_id = c.visit_occurrence_id -) c -WHERE c.earliest_visit_start_date <= c.earliest_condition_start_date -""" - -HEART_FAILURE_INTERMEDIATE_COHORT_QUERY = """ -WITH hf_conditions AS ( - SELECT - * - FROM global_temp.condition_occurrence AS co - JOIN global_temp.{hf_concept} AS hf - ON 
co.condition_concept_id = hf.concept_id -), - -worsen_hf_diagnosis AS ( - SELECT DISTINCT person_id, visit_occurrence_id - FROM global_temp.condition_occurrence AS co - JOIN global_temp.{worsen_hf_dx_concepts} AS w_hf - ON co.condition_concept_id = w_hf.concept_id -), - -phy_exam_cohort AS ( - SELECT DISTINCT person_id, visit_occurrence_id - FROM global_temp.condition_occurrence AS co - JOIN global_temp.{phy_exam_concepts} AS phy - ON co.condition_concept_id = phy.concept_id -), - -bnp_cohort AS ( - SELECT DISTINCT person_id, visit_occurrence_id - FROM global_temp.measurement AS m - JOIN global_temp.{bnp_concepts} AS bnp - ON m.measurement_concept_id = bnp.concept_id - AND m.value_source_value > 500 - UNION ALL - SELECT DISTINCT person_id, visit_occurrence_id - FROM global_temp.measurement AS m - JOIN global_temp.{nt_pro_bnp_concepts} AS nt_bnp - ON m.measurement_concept_id = nt_bnp.concept_id - AND m.value_source_value > 2000 -), - -drug_concepts AS ( - SELECT DISTINCT - * - FROM - ( - SELECT * - FROM global_temp.{drug_concepts} - - UNION - - SELECT * - FROM global_temp.diuretics_concepts - ) d -), - -drug_cohort AS ( - SELECT DISTINCT person_id, visit_occurrence_id - FROM global_temp.drug_exposure AS d - JOIN drug_concepts AS dc - ON d.drug_concept_id = dc.concept_id -), - -mechanical_support_cohort AS ( - SELECT DISTINCT person_id, visit_occurrence_id - FROM global_temp.procedure_occurrence AS p - JOIN global_temp.{mechanical_support_concepts} AS msc - ON p.procedure_concept_id = msc.concept_id -), - -dialysis_cohort AS ( - SELECT DISTINCT person_id, visit_occurrence_id - FROM global_temp.procedure_occurrence AS p - JOIN global_temp.{dialysis_concepts} AS dc - ON p.procedure_concept_id = dc.concept_id -), - -artificial_heart_cohort AS ( - SELECT DISTINCT person_id, visit_occurrence_id - FROM global_temp.procedure_occurrence AS p - JOIN global_temp.{artificial_heart_concepts} AS ahc - ON p.procedure_concept_id = ahc.concept_id -), - -treatment_cohort AS ( --- SELECT * FROM drug_cohort --- UNION ALL - SELECT * FROM mechanical_support_cohort - UNION ALL - SELECT * FROM dialysis_cohort - UNION ALL - SELECT * FROM artificial_heart_cohort -), - -entry_cohort AS ( - SELECT - c.person_id, - c.earliest_visit_start_date AS index_date, - c.earliest_visit_occurrence_id AS visit_occurrence_id, - COUNT(c.visit_occurrence_id) OVER(PARTITION BY c.person_id) AS num_of_diagnosis - FROM - ( - SELECT DISTINCT - v.person_id, - v.visit_occurrence_id, - first(DATE(c.condition_start_date)) OVER (PARTITION BY v.person_id - ORDER BY DATE(c.condition_start_date)) AS earliest_condition_start_date, - first(DATE(v.visit_start_date)) OVER (PARTITION BY v.person_id - ORDER BY DATE(v.visit_start_date)) AS earliest_visit_start_date, - first(v.visit_occurrence_id) OVER (PARTITION BY v.person_id - ORDER BY DATE(v.visit_start_date)) AS earliest_visit_occurrence_id - FROM global_temp.visit_occurrence AS v - JOIN hf_conditions AS c - ON v.visit_occurrence_id = c.visit_occurrence_id - ) c - WHERE c.earliest_visit_start_date <= c.earliest_condition_start_date -) - -SELECT - c.*, - CAST(COALESCE(bnp.person_id, tc.person_id, dc.person_id) IS NOT NULL AS INT) AS inclusion -FROM entry_cohort AS c -LEFT JOIN ( -SELECT DISTINCT person_id FROM bnp_cohort -) AS bnp - ON c.person_id = bnp.person_id -LEFT JOIN ( - SELECT DISTINCT - person_id - FROM treatment_cohort -) AS tc - ON c.person_id = tc.person_id -LEFT JOIN ( - SELECT DISTINCT - hf.person_id - FROM hf_conditions hf - JOIN drug_cohort dc - ON hf.visit_occurrence_id = 
dc.visit_occurrence_id -) AS dc - ON c.person_id = dc.person_id -""" - -HEART_FAILURE_COHORT_QUERY = """ -SELECT - person_id, - index_date, - visit_occurrence_id -FROM global_temp.{intermediate_heart_failure} -WHERE inclusion = {inclusion} -""" - -DEPENDENCY_LIST = [ - "person", - "condition_occurrence", - "visit_occurrence", - "drug_exposure", - "measurement", - "procedure_occurrence", -] -HEART_FAILURE_CONCEPT_TABLE = "hf_concept" -WORSEN_HF_DX_CONCEPT_TABLE = "worsen_hf_dx_concepts" -PHYSICAL_EXAM_COHORT_TABLE = "phy_exam_concepts" -BNP_CONCEPT_TABLE = "bnp_concepts" -NT_PRO_BNP_CONCEPT_TABLE = "nt_pro_bnp_concepts" -DRUG_CONCEPT_TABLE = "drug_concepts" -MECHANICAL_SUPPORT_CONCEPT_TABLE = "mechanical_support_concepts" -DIALYSIS_CONCEPT_TABLE = "dialysis_concepts" -ARTIFICIAL_HEART_CONCEPT_TABLE = "artificial_heart_concepts" - -DIURETICS_ANCESTOR_TABLE = "diuretics_ancestor_table" -DIURETICS_INGREDIENT_CONCEPTS = "diuretics_concepts" - -INTERMEDIATE_COHORT_NAME = "intermediate_heart_failure" -DEFAULT_COHORT_NAME = "heart_failure" -NEGATIVE_COHORT_NAME = "negative_heart_failure" - - -def query_builder(): - query = QuerySpec( - table_name=DEFAULT_COHORT_NAME, - query_template=HEART_FAILURE_COHORT_QUERY, - parameters={ - "intermediate_heart_failure": INTERMEDIATE_COHORT_NAME, - "inclusion": 1, - }, - ) - - ancestor_table_specs = [ - AncestorTableSpec( - table_name=HEART_FAILURE_CONCEPT_TABLE, - ancestor_concept_ids=HEART_FAILURE_CONCEPT, - is_standard=True, - ), - AncestorTableSpec( - table_name=WORSEN_HF_DX_CONCEPT_TABLE, - ancestor_concept_ids=WORSEN_HF_DIAGNOSIS_CONCEPT, - is_standard=True, - ), - AncestorTableSpec( - table_name=PHYSICAL_EXAM_COHORT_TABLE, - ancestor_concept_ids=PHYSICAL_EXAM_CONCEPT, - is_standard=True, - ), - AncestorTableSpec( - table_name=BNP_CONCEPT_TABLE, - ancestor_concept_ids=BNP_CONCEPT, - is_standard=True, - ), - AncestorTableSpec( - table_name=NT_PRO_BNP_CONCEPT_TABLE, - ancestor_concept_ids=NT_PRO_BNP_CONCEPT, - is_standard=True, - ), - AncestorTableSpec( - table_name=DRUG_CONCEPT_TABLE, - ancestor_concept_ids=DRUG_CONCEPT, - is_standard=True, - ), - AncestorTableSpec( - table_name=MECHANICAL_SUPPORT_CONCEPT_TABLE, - ancestor_concept_ids=MECHANICAL_CIRCULATORY_SUPPORT_CONCEPT, - is_standard=True, - ), - AncestorTableSpec( - table_name=DIALYSIS_CONCEPT_TABLE, - ancestor_concept_ids=DIALYSIS_CONCEPT, - is_standard=True, - ), - AncestorTableSpec( - table_name=ARTIFICIAL_HEART_CONCEPT_TABLE, - ancestor_concept_ids=ARTIFICIAL_HEART_ASSOCIATED_PROCEDURE_CONCEPT, - is_standard=True, - ), - AncestorTableSpec( - table_name=DIURETICS_ANCESTOR_TABLE, - ancestor_concept_ids=DIURETIC_CONCEPT_ID, - is_standard=False, - ), - ] - - dependency_queries = [ - QuerySpec( - table_name=DIURETICS_INGREDIENT_CONCEPTS, - query_template=ROLL_UP_DIURETICS_TO_INGREDIENT_TEMPLATE, - parameters={}, - ), - QuerySpec( - table_name=INTERMEDIATE_COHORT_NAME, - query_template=HEART_FAILURE_INTERMEDIATE_COHORT_QUERY, - parameters={ - "hf_concept": HEART_FAILURE_CONCEPT_TABLE, - "worsen_hf_dx_concepts": WORSEN_HF_DX_CONCEPT_TABLE, - "phy_exam_concepts": PHYSICAL_EXAM_COHORT_TABLE, - "bnp_concepts": BNP_CONCEPT_TABLE, - "nt_pro_bnp_concepts": NT_PRO_BNP_CONCEPT_TABLE, - "drug_concepts": DRUG_CONCEPT_TABLE, - "mechanical_support_concepts": MECHANICAL_SUPPORT_CONCEPT_TABLE, - "dialysis_concepts": DIALYSIS_CONCEPT_TABLE, - "artificial_heart_concepts": ARTIFICIAL_HEART_CONCEPT_TABLE, - }, - ), - ] - - entry_cohort_query = create_cohort_entry_query_spec( - 
entry_query_template=HEART_FAILURE_ENTRY_COHORT, - parameters={"hf_concept": HEART_FAILURE_CONCEPT_TABLE}, - ) - - negative_query = create_negative_query_spec( - entry_query_template=HEART_FAILURE_COHORT_QUERY, - parameters={ - "intermediate_heart_failure": INTERMEDIATE_COHORT_NAME, - "inclusion": 0, - }, - ) - - return QueryBuilder( - cohort_name=DEFAULT_COHORT_NAME, - query=query, - negative_query=negative_query, - entry_cohort_query=entry_cohort_query, - dependency_list=DEPENDENCY_LIST, - dependency_queries=dependency_queries, - post_queries=[], - ancestor_table_specs=ancestor_table_specs, - ) diff --git a/src/cehrbert/spark_apps/cohorts/ischemic_stroke.py b/src/cehrbert/spark_apps/cohorts/ischemic_stroke.py deleted file mode 100644 index f24f3cfc..00000000 --- a/src/cehrbert/spark_apps/cohorts/ischemic_stroke.py +++ /dev/null @@ -1,44 +0,0 @@ -from ..cohorts.query_builder import AncestorTableSpec, QueryBuilder, QuerySpec - -COHORT_QUERY_TEMPLATE = """ -SELECT - co.person_id, - FIRST(DATE(vo.visit_start_date)) OVER (PARTITION BY co.person_id - ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS index_date, - FIRST(vo.visit_occurrence_id) OVER (PARTITION BY co.person_id - ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS visit_occurrence_id -FROM global_temp.condition_occurrence AS co -JOIN global_temp.visit_occurrence AS vo - ON co.visit_occurrence_id = vo.visit_occurrence_id -JOIN global_temp.{ischemic_stroke_concepts} AS c - ON co.condition_concept_id = c.concept_id -""" - -ISCHEMIC_STROKE_CONCEPT_ID = [443454] - -DEPENDENCY_LIST = ["person", "condition_occurrence", "visit_occurrence"] - -DEFAULT_COHORT_NAME = "ischemic_stroke" -ISCHEMIC_STROKE_CONCEPTS = "ischemic_stroke_concepts" - - -def query_builder(): - query = QuerySpec( - table_name=DEFAULT_COHORT_NAME, - query_template=COHORT_QUERY_TEMPLATE, - parameters={"ischemic_stroke_concepts": ISCHEMIC_STROKE_CONCEPTS}, - ) - - ancestor_table_specs = [ - AncestorTableSpec( - table_name=ISCHEMIC_STROKE_CONCEPTS, - ancestor_concept_ids=ISCHEMIC_STROKE_CONCEPT_ID, - is_standard=True, - ) - ] - return QueryBuilder( - cohort_name=DEFAULT_COHORT_NAME, - dependency_list=DEPENDENCY_LIST, - query=query, - ancestor_table_specs=ancestor_table_specs, - ) diff --git a/src/cehrbert/spark_apps/cohorts/last_visit_discharged_home.py b/src/cehrbert/spark_apps/cohorts/last_visit_discharged_home.py deleted file mode 100644 index 5d32112e..00000000 --- a/src/cehrbert/spark_apps/cohorts/last_visit_discharged_home.py +++ /dev/null @@ -1,34 +0,0 @@ -from ..cohorts.query_builder import QueryBuilder, QuerySpec - -COHORT_QUERY = """ -SELECT DISTINCT - v.person_id, - v.visit_occurrence_id, - v.index_date -FROM -( - SELECT - v.person_id, - v.visit_occurrence_id, - v.visit_end_date AS index_date, - v.discharged_to_concept_id, - ROW_NUMBER() OVER(PARTITION BY v.person_id ORDER BY DATE(v.visit_end_date) DESC) AS rn - FROM global_temp.visit_occurrence AS v - WHERE v.visit_concept_id IN (9201, 262) --inpatient, er-inpatient - AND v.visit_end_date IS NOT NULL - AND v.discharged_to_concept_id = 8536 --discharge to home -) AS v - WHERE v.rn = 1 AND v.index_date >= '{date_lower_bound}' -""" - -DEPENDENCY_LIST = ["person", "visit_occurrence"] -DEFAULT_COHORT_NAME = "last_visit_discharge_home" - - -def query_builder(spark_args): - query = QuerySpec( - table_name=DEFAULT_COHORT_NAME, - query_template=COHORT_QUERY, - parameters={"date_lower_bound": spark_args.date_lower_bound}, - ) - return QueryBuilder(cohort_name=DEFAULT_COHORT_NAME, 
dependency_list=DEPENDENCY_LIST, query=query) diff --git a/src/cehrbert/spark_apps/cohorts/query_builder.py b/src/cehrbert/spark_apps/cohorts/query_builder.py deleted file mode 100644 index ec125e0f..00000000 --- a/src/cehrbert/spark_apps/cohorts/query_builder.py +++ /dev/null @@ -1,153 +0,0 @@ -import logging -from abc import ABC -from typing import List, NamedTuple - -ENTRY_COHORT = "entry_cohort" -NEGATIVE_COHORT = "negative_cohort" - - -def create_cohort_entry_query_spec(entry_query_template, parameters): - return QuerySpec( - table_name=ENTRY_COHORT, - query_template=entry_query_template, - parameters=parameters, - ) - - -def create_negative_query_spec(entry_query_template, parameters): - return QuerySpec( - table_name=NEGATIVE_COHORT, - query_template=entry_query_template, - parameters=parameters, - ) - - -class QuerySpec(NamedTuple): - query_template: str - parameters: dict - table_name: str - - def __str__(self): - return f"table={self.table_name}\n" f"query={self.query_template.format(**self.parameters)}\n" - - -class AncestorTableSpec(NamedTuple): - ancestor_concept_ids: List[int] - table_name: str - is_standard: bool - - def __str__(self): - return ( - f"table_name={self.table_name}\n" - f"ancestor_concept_ids={self.ancestor_concept_ids}\n" - f"is_standard={self.is_standard}\n" - ) - - -class QueryBuilder(ABC): - - def __init__( - self, - cohort_name: str, - dependency_list: List[str], - query: QuerySpec, - negative_query: QuerySpec = None, - entry_cohort_query: QuerySpec = None, - dependency_queries: List[QuerySpec] = None, - post_queries: List[QuerySpec] = None, - ancestor_table_specs: List[AncestorTableSpec] = None, - ): - """ - :param cohort_name: - - :param query: - :param dependency_queries: - :param post_queries: - :param dependency_list: - :param ancestor_table_specs: - """ - self._cohort_name = cohort_name - self._query = query - self._negative_query = negative_query - self._entry_cohort_query = entry_cohort_query - self._dependency_queries = dependency_queries - self._post_queries = post_queries - self._dependency_list = dependency_list - self._ancestor_table_specs = ancestor_table_specs - - self.get_logger().info( - f"cohort_name: {cohort_name}\n" - f"post_queries: {post_queries}\n" - f"entry_cohort: {entry_cohort_query}\n" - f"dependency_queries: {dependency_queries}\n" - f"dependency_list: {dependency_list}\n" - f"ancestor_table_specs: {ancestor_table_specs}\n" - f"query: {query}\n" - f"negative_query: {negative_query}\n" - ) - - def get_dependency_queries(self): - """ - Instantiate table dependencies in spark for. - - :return: - """ - return self._dependency_queries - - def get_entry_cohort_query(self): - """ - Queryspec for Instantiating the entry cohort in spark context. - - :return: - """ - return self._entry_cohort_query - - def get_query(self): - """ - Create a query that can be executed by spark.sql. - - :return: - """ - return self._query - - def get_negative_query(self): - """ - Return the negative query that can be executed by spark.sql. - - :return: - """ - return self._negative_query - - def get_post_process_queries(self): - """ - Get a list of post process queries to process the cohort. - - :return: - """ - return self._post_queries - - def get_dependency_list(self): - """ - Get a list of tables that are required for this cohort. - - :return: - """ - return self._dependency_list - - def get_cohort_name(self): - return self._cohort_name - - def get_ancestor_table_specs(self): - """ - Create the descendant table for the provided ancestor_table_specs. 
- - :return: - """ - return self._ancestor_table_specs - - def __str__(self): - return f"{str(self.__class__.__name__)} for {self.get_cohort_name()}" - - @classmethod - def get_logger(cls): - return logging.getLogger(cls.__name__) diff --git a/src/cehrbert/spark_apps/cohorts/spark_app_base.py b/src/cehrbert/spark_apps/cohorts/spark_app_base.py deleted file mode 100644 index 9c920411..00000000 --- a/src/cehrbert/spark_apps/cohorts/spark_app_base.py +++ /dev/null @@ -1,745 +0,0 @@ -import os -import re -import shutil -from abc import ABC - -from pandas import to_datetime -from pyspark.sql import DataFrame, SparkSession -from pyspark.sql.window import Window - -from ...utils.spark_utils import ( - VISIT_OCCURRENCE, - AttType, - F, - List, - W, - build_ancestry_table_for, - create_concept_frequency_data, - create_hierarchical_sequence_data, - create_sequence_data, - create_sequence_data_with_att, - extract_ehr_records, - get_descendant_concept_ids, - logging, - preprocess_domain_table, -) -from ..cohorts.query_builder import ENTRY_COHORT, NEGATIVE_COHORT, QueryBuilder - -COHORT_TABLE_NAME = "cohort" -PERSON = "person" -OBSERVATION_PERIOD = "observation_period" -DEFAULT_DEPENDENCY = [ - "person", - "visit_occurrence", - "observation_period", - "concept", - "concept_ancestor", - "concept_relationship", -] - - -def cohort_validator(required_columns_attribute): - """ - Decorator for validating the cohort dataframe returned by build function in. - - AbstractCohortBuilderBase - :param required_columns_attribute: attribute for storing cohort_required_columns - in :class:`spark_apps.spark_app_base.AbstractCohortBuilderBase` - :return: - """ - - def cohort_validator_decorator(function): - def wrapper(self, *args, **kwargs): - cohort = function(self, *args, **kwargs) - required_columns = getattr(self, required_columns_attribute) - for required_column in required_columns: - if required_column not in cohort.columns: - raise AssertionError(f"{required_column} is a required column in the cohort") - return cohort - - return wrapper - - return cohort_validator_decorator - - -def instantiate_dependencies(spark, input_folder, dependency_list): - dependency_dict = dict() - for domain_table_name in dependency_list + DEFAULT_DEPENDENCY: - table = preprocess_domain_table(spark, input_folder, domain_table_name) - table.createOrReplaceGlobalTempView(domain_table_name) - dependency_dict[domain_table_name] = table - return dependency_dict - - -def validate_date_folder(input_folder, table_list): - for domain_table_name in table_list: - parquet_file_path = os.path.join(input_folder, domain_table_name) - if not os.path.exists(parquet_file_path): - raise FileExistsError(f"{parquet_file_path} does not exist") - - -def validate_folder(folder_path): - if not os.path.exists(folder_path): - raise FileExistsError(f"{folder_path} does not exist") - - -class BaseCohortBuilder(ABC): - cohort_required_columns = ["person_id", "index_date", "visit_occurrence_id"] - - def __init__( - self, - query_builder: QueryBuilder, - input_folder: str, - output_folder: str, - date_lower_bound: str, - date_upper_bound: str, - age_lower_bound: int, - age_upper_bound: int, - prior_observation_period: int, - post_observation_period: int, - ): - - self._query_builder = query_builder - self._input_folder = input_folder - self._output_folder = output_folder - self._date_lower_bound = date_lower_bound - self._date_upper_bound = date_upper_bound - self._age_lower_bound = age_lower_bound - self._age_upper_bound = age_upper_bound - 
self._prior_observation_period = prior_observation_period - self._post_observation_period = post_observation_period - cohort_name = re.sub("[^a-z0-9]+", "_", self._query_builder.get_cohort_name().lower()) - self._output_data_folder = os.path.join(self._output_folder, cohort_name) - - self.get_logger().info( - f"query_builder: {query_builder}\n" - f"input_folder: {input_folder}\n" - f"output_folder: {output_folder}\n" - f"date_lower_bound: {date_lower_bound}\n" - f"date_upper_bound: {date_upper_bound}\n" - f"age_lower_bound: {age_lower_bound}\n" - f"age_upper_bound: {age_upper_bound}\n" - f"prior_observation_period: {prior_observation_period}\n" - f"post_observation_period: {post_observation_period}\n" - ) - - # Validate the age range, observation_window and prediction_window - self._validate_integer_inputs() - # Validate the input and output folders - validate_folder(self._input_folder) - validate_folder(self._output_folder) - # Validate if the data folders exist - validate_date_folder(self._input_folder, self._query_builder.get_dependency_list()) - - self.spark = SparkSession.builder.appName(f"Generate {self._query_builder.get_cohort_name()}").getOrCreate() - - self._dependency_dict = instantiate_dependencies( - self.spark, self._input_folder, self._query_builder.get_dependency_list() - ) - - @cohort_validator("cohort_required_columns") - def create_cohort(self): - """ - Create cohort. - - :return: - """ - # Build the ancestor tables for the main query to use if the ancestor_table_specs are - # available - if self._query_builder.get_ancestor_table_specs(): - for ancestor_table_spec in self._query_builder.get_ancestor_table_specs(): - func = get_descendant_concept_ids if ancestor_table_spec.is_standard else build_ancestry_table_for - ancestor_table = func(self.spark, ancestor_table_spec.ancestor_concept_ids) - ancestor_table.createOrReplaceGlobalTempView(ancestor_table_spec.table_name) - - # Build the dependencies for the main query to use if the dependency_queries are available - if self._query_builder.get_dependency_queries(): - for dependency_query in self._query_builder.get_dependency_queries(): - query = dependency_query.query_template.format(**dependency_query.parameters) - dependency_table = self.spark.sql(query) - dependency_table.createOrReplaceGlobalTempView(dependency_query.table_name) - - # Build the dependency for the entry cohort if exists - if self._query_builder.get_entry_cohort_query(): - entry_cohort_query = self._query_builder.get_entry_cohort_query() - query = entry_cohort_query.query_template.format(**entry_cohort_query.parameters) - dependency_table = self.spark.sql(query) - dependency_table.createOrReplaceGlobalTempView(entry_cohort_query.table_name) - - # Build the negative cohort if exists - if self._query_builder.get_negative_query(): - negative_cohort_query = self._query_builder.get_negative_query() - query = negative_cohort_query.query_template.format(**negative_cohort_query.parameters) - dependency_table = self.spark.sql(query) - dependency_table.createOrReplaceGlobalTempView(negative_cohort_query.table_name) - - main_query = self._query_builder.get_query() - cohort = self.spark.sql(main_query.query_template.format(**main_query.parameters)) - cohort.createOrReplaceGlobalTempView(main_query.table_name) - - # Post process the cohort if the post_process_queries are available - if self._query_builder.get_post_process_queries(): - for post_query in self._query_builder.get_post_process_queries(): - cohort = 
self.spark.sql(post_query.query_template.format(**post_query.parameters)) - cohort.createOrReplaceGlobalTempView(main_query.table_name) - - return cohort - - def build(self): - """Build the cohort and write the dataframe as parquet files to _output_data_folder.""" - cohort = self.create_cohort() - - cohort = self._apply_observation_period(cohort) - - cohort = self._add_demographics(cohort) - - cohort = cohort.where(F.col("age").between(self._age_lower_bound, self._age_upper_bound)).where( - F.col("index_date").between(to_datetime(self._date_lower_bound), to_datetime(self._date_upper_bound)) - ) - - cohort.write.mode("overwrite").parquet(self._output_data_folder) - - return self - - def load_cohort(self): - return self.spark.read.parquet(self._output_data_folder) - - @cohort_validator("cohort_required_columns") - def _apply_observation_period(self, cohort: DataFrame): - cohort.createOrReplaceGlobalTempView("cohort") - - qualified_cohort = self.spark.sql( - """ - SELECT - c.* - FROM global_temp.cohort AS c - JOIN global_temp.observation_period AS p - ON c.person_id = p.person_id - AND DATE_ADD(c.index_date, -{prior_observation_period}) >= p.observation_period_start_date - AND DATE_ADD(c.index_date, {post_observation_period}) <= p.observation_period_end_date - """.format( - prior_observation_period=self._prior_observation_period, - post_observation_period=self._post_observation_period, - ) - ) - - self.spark.sql(f"DROP VIEW global_temp.cohort") - return qualified_cohort - - @cohort_validator("cohort_required_columns") - def _add_demographics(self, cohort: DataFrame): - return ( - cohort.join(self._dependency_dict[PERSON], "person_id") - .withColumn("age", F.year("index_date") - F.col("year_of_birth")) - .select( - F.col("person_id"), - F.col("age"), - F.col("gender_concept_id"), - F.col("race_concept_id"), - F.col("index_date"), - F.col("visit_occurrence_id"), - ) - .distinct() - ) - - def _validate_integer_inputs(self): - assert self._age_lower_bound >= 0 - assert self._age_upper_bound > 0 - assert self._age_lower_bound < self._age_upper_bound - assert self._prior_observation_period >= 0 - assert self._post_observation_period >= 0 - - @classmethod - def get_logger(cls): - return logging.getLogger(cls.__name__) - - -class NestedCohortBuilder: - def __init__( - self, - cohort_name: str, - input_folder: str, - output_folder: str, - target_cohort: DataFrame, - outcome_cohort: DataFrame, - ehr_table_list: List[str], - observation_window: int, - hold_off_window: int, - prediction_start_days: int, - prediction_window: int, - num_of_visits: int, - num_of_concepts: int, - patient_splits_folder: str = None, - is_window_post_index: bool = False, - include_visit_type: bool = True, - allow_measurement_only: bool = False, - exclude_visit_tokens: bool = False, - is_feature_concept_frequency: bool = False, - is_roll_up_concept: bool = False, - include_concept_list: bool = True, - is_new_patient_representation: bool = False, - gpt_patient_sequence: bool = False, - is_hierarchical_bert: bool = False, - classic_bert_seq: bool = False, - is_first_time_outcome: bool = False, - is_questionable_outcome_existed: bool = False, - is_remove_index_prediction_starts: bool = False, - is_prediction_window_unbounded: bool = False, - is_observation_window_unbounded: bool = False, - is_population_estimation: bool = False, - att_type: AttType = AttType.CEHR_BERT, - exclude_demographic: bool = True, - use_age_group: bool = False, - single_contribution: bool = False, - ): - self._cohort_name = cohort_name - self._input_folder = 
input_folder - self._output_folder = output_folder - self._patient_splits_folder = patient_splits_folder - self._target_cohort = target_cohort - self._outcome_cohort = outcome_cohort - self._ehr_table_list = ehr_table_list - self._observation_window = observation_window - self._hold_off_window = hold_off_window - self._prediction_start_days = prediction_start_days - self._prediction_window = prediction_window - self._num_of_visits = num_of_visits - self._num_of_concepts = num_of_concepts - self._is_observation_post_index = is_window_post_index - self._is_observation_window_unbounded = is_observation_window_unbounded - self._include_visit_type = include_visit_type - self._exclude_visit_tokens = exclude_visit_tokens - self._classic_bert_seq = classic_bert_seq - self._is_feature_concept_frequency = is_feature_concept_frequency - self._is_roll_up_concept = is_roll_up_concept - self._is_new_patient_representation = is_new_patient_representation - self._gpt_patient_sequence = gpt_patient_sequence - self._is_hierarchical_bert = is_hierarchical_bert - self._is_first_time_outcome = is_first_time_outcome - self._is_remove_index_prediction_starts = is_remove_index_prediction_starts - self._is_questionable_outcome_existed = is_questionable_outcome_existed - self._is_prediction_window_unbounded = is_prediction_window_unbounded - self._include_concept_list = include_concept_list - self._allow_measurement_only = allow_measurement_only - self._output_data_folder = os.path.join( - self._output_folder, re.sub("[^a-z0-9]+", "_", self._cohort_name.lower()) - ) - self._is_population_estimation = is_population_estimation - self._att_type = att_type - self._exclude_demographic = exclude_demographic - self._use_age_group = use_age_group - self._single_contribution = single_contribution - - self.get_logger().info( - f"cohort_name: {cohort_name}\n" - f"input_folder: {input_folder}\n" - f"output_folder: {output_folder}\n" - f"ehr_table_list: {ehr_table_list}\n" - f"observation_window: {observation_window}\n" - f"prediction_start_days: {prediction_start_days}\n" - f"prediction_window: {prediction_window}\n" - f"hold_off_window: {hold_off_window}\n" - f"num_of_visits: {num_of_visits}\n" - f"num_of_concepts: {num_of_concepts}\n" - f"is_window_post_index: {is_window_post_index}\n" - f"include_visit_type: {include_visit_type}\n" - f"exclude_visit_tokens: {exclude_visit_tokens}\n" - f"allow_measurement_only: {allow_measurement_only}\n" - f"is_feature_concept_frequency: {is_feature_concept_frequency}\n" - f"is_roll_up_concept: {is_roll_up_concept}\n" - f"is_new_patient_representation: {is_new_patient_representation}\n" - f"gpt_patient_sequence: {gpt_patient_sequence}\n" - f"is_hierarchical_bert: {is_hierarchical_bert}\n" - f"is_first_time_outcome: {is_first_time_outcome}\n" - f"is_questionable_outcome_existed: {is_questionable_outcome_existed}\n" - f"is_remove_index_prediction_starts: {is_remove_index_prediction_starts}\n" - f"is_prediction_window_unbounded: {is_prediction_window_unbounded}\n" - f"include_concept_list: {include_concept_list}\n" - f"is_observation_window_unbounded: {is_observation_window_unbounded}\n" - f"is_population_estimation: {is_population_estimation}\n" - f"att_type: {att_type}\n" - f"exclude_demographic: {exclude_demographic}\n" - f"use_age_group: {use_age_group}\n" - f"single_contribution: {single_contribution}\n" - ) - - self.spark = SparkSession.builder.appName(f"Generate {self._cohort_name}").getOrCreate() - self._dependency_dict = instantiate_dependencies(self.spark, self._input_folder, 
DEFAULT_DEPENDENCY) - - # Validate the input and output folders - validate_folder(self._input_folder) - validate_folder(self._output_folder) - # Validate if the data folders exist - validate_date_folder(self._input_folder, self._ehr_table_list) - - def build(self): - self._target_cohort.createOrReplaceGlobalTempView("target_cohort") - self._outcome_cohort.createOrReplaceGlobalTempView("outcome_cohort") - - prediction_start_days = self._prediction_start_days - prediction_window = self._prediction_window - - if self._is_observation_post_index: - prediction_start_days += self._observation_window + self._hold_off_window - prediction_window += self._observation_window + self._hold_off_window - - if self._is_first_time_outcome: - target_cohort = self.spark.sql( - """ - SELECT - t.person_id AS cohort_member_id, - t.* - FROM global_temp.target_cohort AS t - LEFT JOIN global_temp.{entry_cohort} AS o - ON t.person_id = o.person_id - AND DATE_ADD(t.index_date, {prediction_start_days}) > o.index_date - WHERE o.person_id IS NULL - """.format( - entry_cohort=ENTRY_COHORT, - prediction_start_days=prediction_start_days, - ) - ) - target_cohort.createOrReplaceGlobalTempView("target_cohort") - - if self._is_questionable_outcome_existed: - target_cohort = self.spark.sql( - """ - SELECT - t.* - FROM global_temp.target_cohort AS t - LEFT JOIN global_temp.{questionnation_outcome_cohort} AS o - ON t.person_id = o.person_id - WHERE o.person_id IS NULL - """.format( - questionnation_outcome_cohort=NEGATIVE_COHORT - ) - ) - target_cohort.createOrReplaceGlobalTempView("target_cohort") - if self._is_remove_index_prediction_starts: - # Remove the patients whose outcome date lies between index_date and index_date + - # prediction_start_days - target_cohort = self.spark.sql( - """ - SELECT DISTINCT - t.* - FROM global_temp.target_cohort AS t - LEFT JOIN global_temp.outcome_cohort AS exclusion - ON t.person_id = exclusion.person_id - AND exclusion.index_date BETWEEN t.index_date - AND DATE_ADD(t.index_date, {prediction_start_days}) - WHERE exclusion.person_id IS NULL - """.format( - prediction_start_days=max(prediction_start_days - 1, 0) - ) - ) - target_cohort.createOrReplaceGlobalTempView("target_cohort") - - if self._is_prediction_window_unbounded: - query_template = """ - SELECT DISTINCT - t.*, - o.index_date as outcome_date, - CAST(ISNOTNULL(o.person_id) AS INT) AS label - FROM global_temp.target_cohort AS t - LEFT JOIN global_temp.outcome_cohort AS o - ON t.person_id = o.person_id - AND o.index_date >= DATE_ADD(t.index_date, {prediction_start_days}) - """ - else: - query_template = """ - SELECT DISTINCT - t.*, - o.index_date as outcome_date, - CAST(ISNOTNULL(o.person_id) AS INT) AS label - FROM global_temp.target_cohort AS t - LEFT JOIN global_temp.observation_period AS op - ON t.person_id = op.person_id - AND DATE_ADD(t.index_date, {prediction_window}) <= op.observation_period_end_date - LEFT JOIN global_temp.outcome_cohort AS o - ON t.person_id = o.person_id - AND o.index_date BETWEEN DATE_ADD(t.index_date, {prediction_start_days}) - AND DATE_ADD(t.index_date, {prediction_window}) - WHERE op.person_id IS NOT NULL OR o.person_id IS NOT NULL - """ - - cohort_member_id_udf = F.dense_rank().over(W.orderBy("person_id", "index_date", "visit_occurrence_id")) - cohort = self.spark.sql( - query_template.format( - prediction_start_days=prediction_start_days, - prediction_window=prediction_window, - ) - ).withColumn("cohort_member_id", cohort_member_id_udf) - - # Keep one record in case that there are multiple samples 
generated for the same index_date. - # This should not happen in theory, this is really just a safeguard - row_rank = F.row_number().over( - Window.partitionBy("person_id", "cohort_member_id", "index_date").orderBy(F.desc("label")) - ) - cohort = cohort.withColumn("row_rank", row_rank).where("row_rank == 1").drop("row_rank") - - # We only allow the patient to contribute once to the dataset - # If the patient has any positive outcomes, we will take the most recent positive outcome, - # otherwise we will take the most recent negative outcome - if self._single_contribution: - record_rank = F.row_number().over( - Window.partitionBy("person_id").orderBy(F.desc("label"), F.desc("index_date")) - ) - cohort = cohort.withColumn("record_rank", record_rank).where("record_rank == 1").drop("record_rank") - - ehr_records_for_cohorts = self.extract_ehr_records_for_cohort(cohort) - # ehr_records_for_cohorts.show() - cohort = ( - cohort.join(ehr_records_for_cohorts, ["person_id", "cohort_member_id"]) - .where(F.col("num_of_visits") >= self._num_of_visits) - .where(F.col("num_of_concepts") >= self._num_of_concepts) - ) - - # if patient_splits is provided, we will - if self._patient_splits_folder: - patient_splits = self.spark.read.parquet(self._patient_splits_folder) - cohort.join(patient_splits, "person_id").orderBy("person_id", "cohort_member_id").write.mode( - "overwrite" - ).parquet(os.path.join(self._output_data_folder, "temp")) - # Reload the data from the disk - cohort = self.spark.read.parquet(os.path.join(self._output_data_folder, "temp")) - cohort.where('split="train"').write.mode("overwrite").parquet( - os.path.join(self._output_data_folder, "train") - ) - cohort.where('split="test"').write.mode("overwrite").parquet(os.path.join(self._output_data_folder, "test")) - shutil.rmtree(os.path.join(self._output_data_folder, "temp")) - else: - cohort.orderBy("person_id", "cohort_member_id").write.mode("overwrite").parquet(self._output_data_folder) - - def extract_ehr_records_for_cohort(self, cohort: DataFrame): - """ - Create the patient sequence based on the observation window for the given cohort. 
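A pure-Python restatement of the labeling rule in the target/outcome queries above (example dates are hypothetical; the bounded query additionally requires negative patients to remain under observation through the prediction window):

from datetime import date, timedelta

def label_outcome(index_date, outcome_date, prediction_start_days, prediction_window, unbounded=False):
    # Returns 1 if the outcome falls in the prediction window, otherwise 0 (no outcome -> 0).
    if outcome_date is None:
        return 0
    start = index_date + timedelta(days=prediction_start_days)
    if unbounded:
        return int(outcome_date >= start)
    end = index_date + timedelta(days=prediction_window)
    return int(start <= outcome_date <= end)

print(label_outcome(date(2020, 1, 1), date(2020, 6, 1), 30, 365))  # 1
print(label_outcome(date(2020, 1, 1), date(2021, 6, 1), 30, 365))  # 0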
- - :param cohort: - :return: - """ - # Extract all ehr records for the patients - ehr_records = extract_ehr_records( - self.spark, - self._input_folder, - self._ehr_table_list, - self._include_visit_type, - self._is_roll_up_concept, - self._include_concept_list, - ) - - # Duplicate the records for cohorts that allow multiple entries - ehr_records = ehr_records.join(cohort, "person_id").select( - [ehr_records[field_name] for field_name in ehr_records.schema.fieldNames()] + ["cohort_member_id"] - ) - - # Only allow the data records that occurred between the index date and the prediction window - if self._is_population_estimation: - if self._is_prediction_window_unbounded: - record_window_filter = ehr_records["date"] <= F.current_date() - else: - record_window_filter = ehr_records["date"] <= F.date_add(cohort["index_date"], self._prediction_window) - else: - # For patient level prediction, we remove all records post index date - if self._is_observation_post_index: - record_window_filter = ehr_records["date"].between( - cohort["index_date"], - F.date_add(cohort["index_date"], self._observation_window), - ) - else: - if self._is_observation_window_unbounded: - record_window_filter = ehr_records["date"] <= F.date_sub( - cohort["index_date"], self._hold_off_window - ) - else: - record_window_filter = ehr_records["date"].between( - F.date_sub( - cohort["index_date"], - self._observation_window + self._hold_off_window, - ), - F.date_sub(cohort["index_date"], self._hold_off_window), - ) - - cohort_ehr_records = ( - ehr_records.join( - cohort, - (ehr_records.person_id == cohort.person_id) & (ehr_records.cohort_member_id == cohort.cohort_member_id), - ) - .where(record_window_filter) - .select([ehr_records[field_name] for field_name in ehr_records.schema.fieldNames()]) - ) - - if self._is_hierarchical_bert: - return create_hierarchical_sequence_data( - person=self._dependency_dict[PERSON], - visit_occurrence=self._dependency_dict[VISIT_OCCURRENCE], - patient_events=cohort_ehr_records, - allow_measurement_only=self._allow_measurement_only, - ) - - if self._is_feature_concept_frequency: - return create_concept_frequency_data(cohort_ehr_records, None) - - if self._is_new_patient_representation: - birthdate_udf = F.coalesce( - "birth_datetime", - F.concat("year_of_birth", F.lit("-01-01")).cast("timestamp"), - ) - patient_demographic = self._dependency_dict[PERSON].select( - "person_id", - birthdate_udf.alias("birth_datetime"), - "race_concept_id", - "gender_concept_id", - ) - - age_udf = F.ceil(F.months_between(F.col("visit_start_date"), F.col("birth_datetime")) / F.lit(12)) - visit_occurrence_person = ( - self._dependency_dict[VISIT_OCCURRENCE] - .join(patient_demographic, "person_id") - .withColumn("age", age_udf) - .drop("birth_datetime") - ) - - return create_sequence_data_with_att( - cohort_ehr_records, - visit_occurrence=visit_occurrence_person, - include_visit_type=self._include_visit_type, - exclude_visit_tokens=self._exclude_visit_tokens, - patient_demographic=(patient_demographic if self._gpt_patient_sequence else None), - att_type=self._att_type, - exclude_demographic=self._exclude_demographic, - use_age_group=self._use_age_group, - ) - - return create_sequence_data( - cohort_ehr_records, - date_filter=None, - include_visit_type=self._include_visit_type, - classic_bert_seq=self._classic_bert_seq, - ) - - @classmethod - def get_logger(cls): - return logging.getLogger(cls.__name__) - - -def create_prediction_cohort( - spark_args, - target_query_builder: QueryBuilder, - outcome_query_builder: 
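The patient-level observation-window filter above reduces to a simple date-range check; a standalone sketch with hypothetical window sizes:

from datetime import date, timedelta

def in_observation_window(event_date, index_date, observation_window, hold_off_window):
    # Keep events inside [index_date - (observation_window + hold_off), index_date - hold_off].
    upper = index_date - timedelta(days=hold_off_window)
    lower = upper - timedelta(days=observation_window)
    return lower <= event_date <= upper

print(in_observation_window(date(2019, 12, 1), date(2020, 1, 1), 365, 0))  # True
print(in_observation_window(date(2018, 12, 1), date(2020, 1, 1), 365, 0))  # False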
QueryBuilder, - ehr_table_list, -): - """ - TODO. - - :param spark_args: - :param target_query_builder: - :param outcome_query_builder: - :param ehr_table_list: - :return: - """ - cohort_name = spark_args.cohort_name - input_folder = spark_args.input_folder - output_folder = spark_args.output_folder - date_lower_bound = spark_args.date_lower_bound - date_upper_bound = spark_args.date_upper_bound - age_lower_bound = spark_args.age_lower_bound - age_upper_bound = spark_args.age_upper_bound - observation_window = spark_args.observation_window - prediction_start_days = spark_args.prediction_start_days - prediction_window = spark_args.prediction_window - hold_off_window = spark_args.hold_off_window - num_of_visits = spark_args.num_of_visits - num_of_concepts = spark_args.num_of_concepts - include_visit_type = spark_args.include_visit_type - exclude_visit_tokens = spark_args.exclude_visit_tokens - is_feature_concept_frequency = spark_args.is_feature_concept_frequency - is_roll_up_concept = spark_args.is_roll_up_concept - is_window_post_index = spark_args.is_window_post_index - is_new_patient_representation = spark_args.is_new_patient_representation - is_hierarchical_bert = spark_args.is_hierarchical_bert - classic_bert_seq = spark_args.classic_bert_seq - is_first_time_outcome = spark_args.is_first_time_outcome - is_prediction_window_unbounded = spark_args.is_prediction_window_unbounded - is_observation_window_unbounded = spark_args.is_observation_window_unbounded - # If the outcome negative query exists, that means we need to remove those questionable - # outcomes from the target cohort - is_questionable_outcome_existed = outcome_query_builder.get_negative_query() is not None - - # Do we want to remove those records whose outcome occur between index_date and the start of - # the prediction window - is_remove_index_prediction_starts = spark_args.is_remove_index_prediction_starts - - # Toggle the prior/post observation_period depending on the is_window_post_index flag - prior_observation_period = 0 if is_window_post_index else observation_window + hold_off_window - post_observation_period = observation_window + hold_off_window if is_window_post_index else 0 - - # Generate the target cohort - target_cohort = ( - BaseCohortBuilder( - query_builder=target_query_builder, - input_folder=input_folder, - output_folder=output_folder, - date_lower_bound=date_lower_bound, - date_upper_bound=date_upper_bound, - age_lower_bound=age_lower_bound, - age_upper_bound=age_upper_bound, - prior_observation_period=prior_observation_period, - post_observation_period=post_observation_period, - ) - .build() - .load_cohort() - ) - - # Generate the outcome cohort - outcome_cohort = ( - BaseCohortBuilder( - query_builder=outcome_query_builder, - input_folder=input_folder, - output_folder=output_folder, - date_lower_bound=date_lower_bound, - date_upper_bound=date_upper_bound, - age_lower_bound=age_lower_bound, - age_upper_bound=age_upper_bound, - prior_observation_period=0, - post_observation_period=0, - ) - .build() - .load_cohort() - ) - - NestedCohortBuilder( - cohort_name=cohort_name, - input_folder=input_folder, - output_folder=output_folder, - patient_splits_folder=spark_args.patient_splits_folder, - target_cohort=target_cohort, - outcome_cohort=outcome_cohort, - ehr_table_list=ehr_table_list, - observation_window=observation_window, - hold_off_window=hold_off_window, - prediction_start_days=prediction_start_days, - prediction_window=prediction_window, - num_of_visits=num_of_visits, - num_of_concepts=num_of_concepts, - 
is_window_post_index=is_window_post_index, - include_visit_type=include_visit_type, - exclude_visit_tokens=exclude_visit_tokens, - allow_measurement_only=spark_args.allow_measurement_only, - is_feature_concept_frequency=is_feature_concept_frequency, - is_roll_up_concept=is_roll_up_concept, - include_concept_list=spark_args.include_concept_list, - is_new_patient_representation=is_new_patient_representation, - gpt_patient_sequence=spark_args.gpt_patient_sequence, - is_hierarchical_bert=is_hierarchical_bert, - classic_bert_seq=classic_bert_seq, - is_first_time_outcome=is_first_time_outcome, - is_questionable_outcome_existed=is_questionable_outcome_existed, - is_prediction_window_unbounded=is_prediction_window_unbounded, - is_remove_index_prediction_starts=is_remove_index_prediction_starts, - is_observation_window_unbounded=is_observation_window_unbounded, - is_population_estimation=spark_args.is_population_estimation, - att_type=AttType(spark_args.att_type), - exclude_demographic=spark_args.exclude_demographic, - use_age_group=spark_args.use_age_group, - single_contribution=spark_args.single_contribution, - ).build() diff --git a/src/cehrbert/spark_apps/cohorts/type_two_diabietes.py b/src/cehrbert/spark_apps/cohorts/type_two_diabietes.py deleted file mode 100644 index 3da47290..00000000 --- a/src/cehrbert/spark_apps/cohorts/type_two_diabietes.py +++ /dev/null @@ -1,170 +0,0 @@ -from ..cohorts.query_builder import AncestorTableSpec, QueryBuilder, QuerySpec - -COHORT_QUERY_TEMPLATE = """ -WITH person_ids_to_include_drug AS -( - SELECT DISTINCT - d.person_id - FROM global_temp.drug_exposure AS d - JOIN global_temp.{drug_inclusion_concepts} AS e - ON d.drug_concept_id = e.concept_id -), -person_ids_to_exclude_observation AS -( - - SELECT DISTINCT - o.person_id, - o.observation_date - FROM global_temp.observation AS o - JOIN global_temp.{observation_exclusion_concepts} AS oec - ON o.observation_concept_id = oec.concept_id -) -SELECT - distinct - c.person_id, - c.index_date, - c.visit_occurrence_id -FROM -( - SELECT DISTINCT - vo.person_id, - FIRST(DATE(vo.visit_start_date)) OVER (PARTITION BY co.person_id - ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS index_date, - FIRST(vo.visit_occurrence_id) OVER (PARTITION BY co.person_id - ORDER BY DATE(vo.visit_start_date), vo.visit_occurrence_id) AS visit_occurrence_id - FROM global_temp.condition_occurrence AS co - JOIN global_temp.{diabetes_inclusion_concepts} AS ie - ON co.condition_concept_id = ie.concept_id - JOIN global_temp.visit_occurrence AS vo - ON co.visit_occurrence_id = vo.visit_occurrence_id -) c -JOIN person_ids_to_include_drug AS d - ON c.person_id = d.person_id -LEFT JOIN person_ids_to_exclude_observation AS eo - ON c.person_id = eo.person_id AND c.index_date > eo.observation_date -WHERE eo.person_id IS NULL AND c.index_date >= '{date_lower_bound}' -""" - -DIABETES_INCLUSION = [443238, 201820, 442793, 4016045] -DIABETES_EXCLUSION = [ - 40484648, - 201254, - 435216, - 4058243, - 30968, - 438476, - 195771, - 193323, - 4019513, - 40484649, -] -DRUG_INCLUSION = [ - 1503297, - 1594973, - 1597756, - 1559684, - 1560171, - 1502855, - 1502809, - 1525215, - 1547504, - 1580747, - 40166035, - 43013884, - 40239216, - 1516766, - 1502826, - 1510202, - 1529331, - 35605670, - 35602717, - 1516976, - 1502905, - 46221581, - 1550023, - 35198096, - 42899447, - 1544838, - 1567198, - 35884381, - 1531601, - 1588986, - 1513876, - 19013951, - 1590165, - 1596977, - 1586346, - 19090204, - 1513843, - 1513849, - 1562586, - 19090226, - 19090221, - 
1586369, - 19090244, - 19090229, - 19090247, - 19090249, - 19090180, - 19013926, - 19091621, - 19090187, -] -OBSERVATION_EXCLUSION = [40769338, 43021173, 42539022, 46270562] -DEPENDENCY_LIST = [ - "person", - "condition_occurrence", - "visit_occurrence", - "drug_exposure", - "observation", -] - -DIABETES_INCLUSION_TABLE = "diabetes_inclusion_concepts" -DIABETES_EXCLUSION_TABLE = "diabetes_exclusion_concepts" -DRUG_INCLUSION_TABLE = "drug_inclusion_concepts" -OBSERVATION_EXCLUSION_TABLE = "observation_exclusion_concepts" - -DEFAULT_COHORT_NAME = "type_two_diabetes" - - -def query_builder(spark_args): - query = QuerySpec( - table_name=DEFAULT_COHORT_NAME, - query_template=COHORT_QUERY_TEMPLATE, - parameters={ - "diabetes_exclusion_concepts": DIABETES_EXCLUSION_TABLE, - "diabetes_inclusion_concepts": DIABETES_INCLUSION_TABLE, - "drug_inclusion_concepts": DRUG_INCLUSION_TABLE, - "observation_exclusion_concepts": OBSERVATION_EXCLUSION_TABLE, - "date_lower_bound": spark_args.date_lower_bound, - }, - ) - - ancestor_table_specs = [ - AncestorTableSpec( - table_name=DIABETES_INCLUSION_TABLE, - ancestor_concept_ids=DIABETES_INCLUSION, - is_standard=True, - ), - AncestorTableSpec( - table_name=DIABETES_EXCLUSION_TABLE, - ancestor_concept_ids=DIABETES_EXCLUSION, - is_standard=True, - ), - AncestorTableSpec( - table_name=OBSERVATION_EXCLUSION_TABLE, - ancestor_concept_ids=OBSERVATION_EXCLUSION, - is_standard=True, - ), - AncestorTableSpec( - table_name=DRUG_INCLUSION_TABLE, - ancestor_concept_ids=DRUG_INCLUSION, - is_standard=True, - ), - ] - return QueryBuilder( - cohort_name=DEFAULT_COHORT_NAME, - dependency_list=DEPENDENCY_LIST, - query=query, - ancestor_table_specs=ancestor_table_specs, - ) diff --git a/src/cehrbert/spark_apps/cohorts/ventilation.py b/src/cehrbert/spark_apps/cohorts/ventilation.py deleted file mode 100644 index d3af84d0..00000000 --- a/src/cehrbert/spark_apps/cohorts/ventilation.py +++ /dev/null @@ -1,22 +0,0 @@ -from ..cohorts.query_builder import QueryBuilder, QuerySpec - -VENTILATION_COHORT_QUERY = """ -SELECT DISTINCT - vent.person_id, - vent.earliest_placement_instant AS index_date, - CAST(NULL AS INT) AS visit_occurrence_id -FROM global_temp.vent AS vent -""" - -DEFAULT_COHORT_NAME = "ventilation" -DEPENDENCY_LIST = ["vent"] - - -def query_builder(): - query = QuerySpec( - table_name=DEFAULT_COHORT_NAME, - query_template=VENTILATION_COHORT_QUERY, - parameters={}, - ) - - return QueryBuilder(cohort_name=DEFAULT_COHORT_NAME, dependency_list=DEPENDENCY_LIST, query=query) diff --git a/src/cehrbert/spark_apps/decorators/__init__.py b/src/cehrbert/spark_apps/decorators/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/cehrbert/spark_apps/decorators/patient_event_decorator.py b/src/cehrbert/spark_apps/decorators/patient_event_decorator.py deleted file mode 100644 index de49ddd6..00000000 --- a/src/cehrbert/spark_apps/decorators/patient_event_decorator.py +++ /dev/null @@ -1,759 +0,0 @@ -import math -from abc import ABC, abstractmethod -from enum import Enum -from typing import Optional, Union - -import numpy as np -from pyspark.sql import DataFrame -from pyspark.sql import Window as W -from pyspark.sql import functions as F -from pyspark.sql import types as T - -from ...const.common import CATEGORICAL_MEASUREMENT, MEASUREMENT - - -class AttType(Enum): - DAY = "day" - WEEK = "week" - MONTH = "month" - CEHR_BERT = "cehr_bert" - MIX = "mix" - NONE = "none" - - -class PatientEventDecorator(ABC): - @abstractmethod - def _decorate(self, patient_events): - pass 
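The AttType values above select one of the interval-bucketing helpers defined further down in this module; as a quick reference, a standalone restatement of the default cehr_bert scheme (mirroring time_token_func), applied to hypothetical day deltas:

import math

def cehr_bert_att_token(day_delta):
    # Week tokens under 28 days, month tokens under 360 days, long-term otherwise.
    if day_delta < 0:
        return "W-1"
    if day_delta < 28:
        return f"W{math.floor(day_delta / 7)}"
    if day_delta < 360:
        return f"M{math.floor(day_delta / 30)}"
    return "LT"

print([cehr_bert_att_token(d) for d in (3, 10, 45, 400)])  # ['W0', 'W1', 'M1', 'LT']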
- - def decorate(self, patient_events): - decorated_patient_events = self._decorate(patient_events) - self.validate(decorated_patient_events) - return decorated_patient_events - - @classmethod - def get_required_columns(cls): - return set( - [ - "cohort_member_id", - "person_id", - "standard_concept_id", - "date", - "datetime", - "visit_occurrence_id", - "domain", - "concept_value", - "visit_rank_order", - "visit_segment", - "priority", - "date_in_week", - "concept_value_mask", - "mlm_skip_value", - "age", - "visit_concept_id", - "visit_start_date", - "visit_start_datetime", - "visit_concept_order", - "concept_order", - ] - ) - - def validate(self, patient_events: DataFrame): - actual_column_set = set(patient_events.columns) - expected_column_set = set(self.get_required_columns()) - if actual_column_set != expected_column_set: - diff_left = actual_column_set - expected_column_set - diff_right = expected_column_set - actual_column_set - raise RuntimeError( - f"{self}\n" - f"actual_column_set - expected_column_set: {diff_left}\n" - f"expected_column_set - actual_column_set: {diff_right}" - ) - - -class PatientEventBaseDecorator(PatientEventDecorator): - # output_columns = [ - # 'cohort_member_id', 'person_id', 'concept_ids', 'visit_segments', 'orders', - # 'dates', 'ages', 'visit_concept_orders', 'num_of_visits', 'num_of_concepts', - # 'concept_value_masks', 'concept_values', 'mlm_skip_values', - # 'visit_concept_ids' - # ] - def __init__(self, visit_occurrence): - self._visit_occurrence = visit_occurrence - - def _decorate(self, patient_events: DataFrame): - """ - Patient_events contains the following columns (cohort_member_id, person_id,. - - standard_concept_id, date, visit_occurrence_id, domain, concept_value) - - :param patient_events: - :return: - """ - - # todo: create an assertion the dataframe contains the above columns - - valid_visit_ids = patient_events.select("visit_occurrence_id", "cohort_member_id").distinct() - - # Add visit_start_date to the patient_events dataframe and create the visit rank - visit_rank_udf = F.row_number().over( - W.partitionBy("person_id", "cohort_member_id").orderBy( - "visit_start_datetime", "is_inpatient", "expired", "visit_occurrence_id" - ) - ) - visit_segment_udf = F.col("visit_rank_order") % F.lit(2) + 1 - - # The visit records are joined to the cohort members (there could be multiple entries for the same patient) - # if multiple entries are present, we duplicate the visit records for those. - visits = ( - self._visit_occurrence.join(valid_visit_ids, "visit_occurrence_id") - .select( - "person_id", - "cohort_member_id", - "visit_occurrence_id", - "visit_end_date", - F.col("visit_start_date").cast(T.DateType()).alias("visit_start_date"), - F.to_timestamp("visit_start_datetime").alias("visit_start_datetime"), - F.col("visit_concept_id").cast("int").isin([9201, 262, 8971, 8920]).cast("int").alias("is_inpatient"), - F.when(F.col("discharged_to_concept_id").cast("int") == 4216643, F.lit(1)) - .otherwise(F.lit(0)) - .alias("expired"), - ) - .withColumn("visit_rank_order", visit_rank_udf) - .withColumn("visit_segment", visit_segment_udf) - .drop("person_id", "expired") - ) - - # Determine the concept order depending on the visit type. For outpatient visits, we assume the concepts to - # have the same order, whereas for inpatient visits, the concept order is determined by the time stamp. 
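The visit_segment expression above reduces to a parity rule over the chronological visit rank; a toy illustration:

def visit_segments(num_visits):
    # visit_segment = visit_rank_order % 2 + 1, so segments alternate 2, 1, 2, ...
    return [rank % 2 + 1 for rank in range(1, num_visits + 1)]

print(visit_segments(5))  # [2, 1, 2, 1, 2]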
- # the concept order needs to be generated per each cohort member because the same visit could be used - # in multiple cohort's histories of the same patient - concept_order_udf = F.when( - F.col("is_inpatient") == 1, - F.dense_rank().over(W.partitionBy("cohort_member_id", "visit_occurrence_id").orderBy("datetime")), - ).otherwise(F.lit(1)) - - # Determine the global visit concept order for each patient, this takes both visit_rank_order and concept_order - # into account when assigning this new order. - # e.g. visit_rank_order = [1, 1, 2, 2], concept_order = [1, 1, 1, 2] -> visit_concept_order = [1, 1, 2, 3] - visit_concept_order_udf = F.dense_rank().over( - W.partitionBy("person_id", "cohort_member_id").orderBy("visit_rank_order", "concept_order") - ) - - # We need to set the visit_end_date as the visit_start_date for outpatient visits - # For inpatient visits, we use the original visit_end_date if available, otherwise - # we will infer the visit_end_date using the max(date) of the current visit - visit_end_date_udf = F.when( - F.col("is_inpatient") == 1, - F.coalesce( - F.col("visit_end_date"), - F.max("date").over(W.partitionBy("cohort_member_id", "visit_occurrence_id")), - ), - ).otherwise(F.col("visit_start_date")) - - # We need to bound the medical event dates between visit_start_date and visit_end_date - bound_medical_event_date = F.when( - F.col("date") < F.col("visit_start_date"), F.col("visit_start_date") - ).otherwise(F.when(F.col("date") > F.col("visit_end_date"), F.col("visit_end_date")).otherwise(F.col("date"))) - - # We need to bound the medical event dates between visit_start_date and visit_end_date - bound_medical_event_datetime = F.when( - F.col("datetime") < F.col("visit_start_datetime"), - F.col("visit_start_datetime"), - ).otherwise( - F.when( - F.col("datetime") > F.col("visit_end_datetime"), - F.col("visit_end_datetime"), - ).otherwise(F.col("datetime")) - ) - - patient_events = ( - patient_events.join(visits, ["cohort_member_id", "visit_occurrence_id"]) - .withColumn("visit_end_date", visit_end_date_udf) - .withColumn("visit_end_datetime", F.date_add("visit_end_date", 1)) - .withColumn("visit_end_datetime", F.expr("visit_end_datetime - INTERVAL 1 MINUTE")) - .withColumn("date", bound_medical_event_date) - .withColumn("datetime", bound_medical_event_datetime) - .withColumn("concept_order", concept_order_udf) - .withColumn("visit_concept_order", visit_concept_order_udf) - .drop("is_inpatient", "visit_end_date", "visit_end_datetime") - .distinct() - ) - - # Set the priority for the events. - # Create the week since epoch UDF - weeks_since_epoch_udf = (F.unix_timestamp("date") / F.lit(24 * 60 * 60 * 7)).cast("int") - patient_events = patient_events.withColumn("priority", F.lit(0)).withColumn( - "date_in_week", weeks_since_epoch_udf - ) - - # Create the concept_value_mask field to indicate whether domain values should be skipped - # As of now only measurement has values, so other domains would be skipped. 
- patient_events = patient_events.withColumn( - "concept_value_mask", (F.col("domain") == MEASUREMENT).cast("int") - ).withColumn( - "mlm_skip_value", - (F.col("domain").isin([MEASUREMENT, CATEGORICAL_MEASUREMENT])).cast("int"), - ) - - if "concept_value" not in patient_events.schema.fieldNames(): - patient_events = patient_events.withColumn("concept_value", F.lit(0.0)) - - # (cohort_member_id, person_id, standard_concept_id, date, datetime, visit_occurrence_id, domain, - # concept_value, visit_rank_order, visit_segment, priority, date_in_week, - # concept_value_mask, mlm_skip_value, age) - return patient_events - - -class PatientEventAttDecorator(PatientEventDecorator): - def __init__( - self, - visit_occurrence, - include_visit_type, - exclude_visit_tokens, - att_type: AttType, - include_inpatient_hour_token: bool = False, - ): - self._visit_occurrence = visit_occurrence - self._include_visit_type = include_visit_type - self._exclude_visit_tokens = exclude_visit_tokens - self._att_type = att_type - self._include_inpatient_hour_token = include_inpatient_hour_token - - def _decorate(self, patient_events: DataFrame): - if self._att_type == AttType.NONE: - return patient_events - - # visits should the following columns (person_id, - # visit_concept_id, visit_start_date, visit_occurrence_id, domain, concept_value) - cohort_member_person_pair = patient_events.select("person_id", "cohort_member_id").distinct() - valid_visit_ids = patient_events.groupby( - "cohort_member_id", - "visit_occurrence_id", - "visit_segment", - "visit_rank_order", - ).agg( - F.min("visit_concept_order").alias("min_visit_concept_order"), - F.max("visit_concept_order").alias("max_visit_concept_order"), - F.min("concept_order").alias("min_concept_order"), - F.max("concept_order").alias("max_concept_order"), - ) - - visit_occurrence = ( - self._visit_occurrence.select( - "person_id", - F.col("visit_start_date").cast(T.DateType()).alias("date"), - F.col("visit_start_date").cast(T.DateType()).alias("visit_start_date"), - F.col("visit_start_datetime").cast(T.TimestampType()).alias("visit_start_datetime"), - F.coalesce("visit_end_date", "visit_start_date").cast(T.DateType()).alias("visit_end_date"), - "visit_concept_id", - "visit_occurrence_id", - F.lit("visit").alias("domain"), - F.lit(0.0).alias("concept_value"), - F.lit(0).alias("concept_value_mask"), - F.lit(0).alias("mlm_skip_value"), - "age", - "discharged_to_concept_id", - ) - .join(valid_visit_ids, "visit_occurrence_id") - .join(cohort_member_person_pair, ["person_id", "cohort_member_id"]) - ) - - # We assume outpatient visits end on the same day, therefore we start visit_end_date to visit_start_date due - # to bad date - visit_occurrence = visit_occurrence.withColumn( - "visit_end_date", - F.when( - F.col("visit_concept_id").isin([9201, 262, 8971, 8920]), - F.col("visit_end_date"), - ).otherwise(F.col("visit_start_date")), - ) - - weeks_since_epoch_udf = (F.unix_timestamp("date") / F.lit(24 * 60 * 60 * 7)).cast("int") - visit_occurrence = visit_occurrence.withColumn("date_in_week", weeks_since_epoch_udf) - - # Cache visit for faster processing - visit_occurrence.cache() - - visits = visit_occurrence.drop("discharged_to_concept_id") - - # (cohort_member_id, person_id, standard_concept_id, date, visit_occurrence_id, domain, - # concept_value, visit_rank_order, visit_segment, priority, date_in_week, - # concept_value_mask, mlm_skip_value, visit_end_date) - visit_start_events = ( - visits.withColumn("date", F.col("visit_start_date")) - .withColumn("datetime", 
F.to_timestamp("visit_start_date")) - .withColumn("standard_concept_id", F.lit("VS")) - .withColumn("visit_concept_order", F.col("min_visit_concept_order")) - .withColumn("concept_order", F.col("min_concept_order") - 1) - .withColumn("priority", F.lit(-2)) - .drop("min_visit_concept_order", "max_visit_concept_order") - .drop("min_concept_order", "max_concept_order") - ) - - visit_end_events = ( - visits.withColumn("date", F.col("visit_end_date")) - .withColumn("datetime", F.date_add(F.to_timestamp("visit_end_date"), 1)) - .withColumn("datetime", F.expr("datetime - INTERVAL 1 MINUTE")) - .withColumn("standard_concept_id", F.lit("VE")) - .withColumn("visit_concept_order", F.col("max_visit_concept_order")) - .withColumn("concept_order", F.col("max_concept_order") + 1) - .withColumn("priority", F.lit(200)) - .drop("min_visit_concept_order", "max_visit_concept_order") - .drop("min_concept_order", "max_concept_order") - ) - - # Get the prev days_since_epoch - prev_visit_end_date_udf = F.lag("visit_end_date").over( - W.partitionBy("person_id", "cohort_member_id").orderBy("visit_rank_order") - ) - - # Compute the time difference between the current record and the previous record - time_delta_udf = F.when(F.col("prev_visit_end_date").isNull(), 0).otherwise( - F.datediff("visit_start_date", "prev_visit_end_date") - ) - - # Udf for calculating the time token - if self._att_type == AttType.DAY: - att_func = time_day_token - elif self._att_type == AttType.WEEK: - att_func = time_week_token - elif self._att_type == AttType.MONTH: - att_func = time_month_token - elif self._att_type == AttType.MIX: - att_func = time_mix_token - else: - att_func = time_token_func - - time_token_udf = F.udf(att_func, T.StringType()) - - att_tokens = ( - visits.withColumn("datetime", F.to_timestamp("date")) - .withColumn("prev_visit_end_date", prev_visit_end_date_udf) - .where(F.col("prev_visit_end_date").isNotNull()) - .withColumn("time_delta", time_delta_udf) - .withColumn( - "time_delta", - F.when(F.col("time_delta") < 0, F.lit(0)).otherwise(F.col("time_delta")), - ) - .withColumn("standard_concept_id", time_token_udf("time_delta")) - .withColumn("priority", F.lit(-3)) - .withColumn("visit_rank_order", F.col("visit_rank_order")) - .withColumn("visit_concept_order", F.col("min_visit_concept_order")) - .withColumn("concept_order", F.lit(0)) - .drop("prev_visit_end_date", "time_delta") - .drop("min_visit_concept_order", "max_visit_concept_order") - .drop("min_concept_order", "max_concept_order") - ) - - if self._exclude_visit_tokens: - artificial_tokens = att_tokens - else: - artificial_tokens = visit_start_events.unionByName(att_tokens).unionByName(visit_end_events) - - if self._include_visit_type: - # insert visit type after the VS token - visit_type_tokens = ( - visits.withColumn("standard_concept_id", F.col("visit_concept_id")) - .withColumn("datetime", F.to_timestamp("date")) - .withColumn("visit_concept_order", F.col("min_visit_concept_order")) - .withColumn("concept_order", F.lit(0)) - .withColumn("priority", F.lit(-1)) - .drop("min_visit_concept_order", "max_visit_concept_order") - .drop("min_concept_order", "max_concept_order") - ) - - artificial_tokens = artificial_tokens.unionByName(visit_type_tokens) - - artificial_tokens = artificial_tokens.drop("visit_end_date") - - # Retrieving the events that are ONLY linked to inpatient visits - inpatient_visits = visit_occurrence.where(F.col("visit_concept_id").isin([9201, 262, 8971, 8920])).select( - "visit_occurrence_id", "visit_end_date", "cohort_member_id" - ) - 
inpatient_events = patient_events.join(inpatient_visits, ["visit_occurrence_id", "cohort_member_id"]) - - # Fill in the visit_end_date if null (because some visits are still ongoing at the time of data extraction) - # Bound the event dates within visit_start_date and visit_end_date - # Generate a span rank to indicate the position of the group of events - # Update the priority for each span - inpatient_events = ( - inpatient_events.withColumn( - "visit_end_date", - F.coalesce( - "visit_end_date", - F.max("date").over(W.partitionBy("cohort_member_id", "visit_occurrence_id")), - ), - ) - .withColumn( - "date", - F.when(F.col("date") < F.col("visit_start_date"), F.col("visit_start_date")).otherwise( - F.when(F.col("date") > F.col("visit_end_date"), F.col("visit_end_date")).otherwise(F.col("date")) - ), - ) - .withColumn("priority", F.col("priority") + F.col("concept_order") * 0.1) - .drop("visit_end_date") - ) - - discharge_events = ( - visit_occurrence.where(F.col("visit_concept_id").isin([9201, 262, 8971, 8920])) - .withColumn( - "standard_concept_id", - F.coalesce(F.col("discharged_to_concept_id"), F.lit(0)), - ) - .withColumn("visit_concept_order", F.col("max_visit_concept_order")) - .withColumn("concept_order", F.col("max_concept_order") + 1) - .withColumn("date", F.col("visit_end_date")) - .withColumn("datetime", F.date_add(F.to_timestamp("visit_end_date"), 1)) - .withColumn("datetime", F.expr("datetime - INTERVAL 1 MINUTE")) - .withColumn("priority", F.lit(100)) - .drop("discharged_to_concept_id", "visit_end_date") - .drop("min_visit_concept_order", "max_visit_concept_order") - .drop("min_concept_order", "max_concept_order") - ) - - # Add discharge events to the inpatient visits - inpatient_events = inpatient_events.unionByName(discharge_events) - - # Get the prev days_since_epoch - inpatient_prev_date_udf = F.lag("date").over( - W.partitionBy("cohort_member_id", "visit_occurrence_id").orderBy("concept_order") - ) - - # Compute the time difference between the current record and the previous record - inpatient_time_delta_udf = F.when(F.col("prev_date").isNull(), 0).otherwise(F.datediff("date", "prev_date")) - - if self._include_inpatient_hour_token: - # Create ATT tokens within the inpatient visits - inpatient_prev_datetime_udf = F.lag("datetime").over( - W.partitionBy("cohort_member_id", "visit_occurrence_id").orderBy("concept_order") - ) - # Compute the time difference between the current record and the previous record - inpatient_hour_delta_udf = F.when(F.col("prev_datetime").isNull(), 0).otherwise( - F.floor((F.unix_timestamp("datetime") - F.unix_timestamp("prev_datetime")) / 3600) - ) - inpatient_att_token = F.when( - F.col("hour_delta") < 24, F.concat(F.lit("i-H"), F.col("hour_delta")) - ).otherwise(F.concat(F.lit("i-"), time_token_udf("time_delta"))) - # Create ATT tokens within the inpatient visits - inpatient_att_events = ( - inpatient_events.withColumn( - "is_span_boundary", - F.row_number().over( - W.partitionBy("cohort_member_id", "visit_occurrence_id", "concept_order").orderBy("priority") - ), - ) - .where(F.col("is_span_boundary") == 1) - .withColumn("prev_date", inpatient_prev_date_udf) - .withColumn("time_delta", inpatient_time_delta_udf) - .withColumn("prev_datetime", inpatient_prev_datetime_udf) - .withColumn("hour_delta", inpatient_hour_delta_udf) - .where(F.col("prev_date").isNotNull()) - .where(F.col("hour_delta") > 0) - .withColumn("standard_concept_id", inpatient_att_token) - .withColumn("visit_concept_order", F.col("visit_concept_order")) - 
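A standalone restatement of the inpatient ATT rule above, with hypothetical timestamps: gaps under 24 hours become an i-H<hours> token, longer gaps reuse the day-level bucketing.

import math
from datetime import datetime

def day_bucket(day_delta):
    # Same bucketing as the default ATT scheme: weeks under 28 days, months under 360, else LT.
    if day_delta < 28:
        return f"W{math.floor(day_delta / 7)}"
    if day_delta < 360:
        return f"M{math.floor(day_delta / 30)}"
    return "LT"

def inpatient_att_token(prev_dt, curr_dt):
    # Hour-granular token inside an inpatient stay; fall back to day buckets at 24 hours or more.
    hour_delta = math.floor((curr_dt - prev_dt).total_seconds() / 3600)
    if hour_delta < 24:
        return f"i-H{hour_delta}"
    return "i-" + day_bucket((curr_dt.date() - prev_dt.date()).days)

print(inpatient_att_token(datetime(2020, 1, 1, 8), datetime(2020, 1, 1, 20)))  # i-H12
print(inpatient_att_token(datetime(2020, 1, 1, 8), datetime(2020, 1, 6, 9)))   # i-W0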
.withColumn("priority", F.col("priority") - 0.01) - .withColumn("concept_value_mask", F.lit(0)) - .withColumn("concept_value", F.lit(0.0)) - .drop("prev_date", "time_delta", "is_span_boundary") - .drop("prev_datetime", "hour_delta") - ) - else: - # Create ATT tokens within the inpatient visits - inpatient_att_events = ( - inpatient_events.withColumn( - "is_span_boundary", - F.row_number().over( - W.partitionBy("cohort_member_id", "visit_occurrence_id", "concept_order").orderBy("priority") - ), - ) - .where(F.col("is_span_boundary") == 1) - .withColumn("prev_date", inpatient_prev_date_udf) - .withColumn("time_delta", inpatient_time_delta_udf) - .where(F.col("time_delta") != 0) - .where(F.col("prev_date").isNotNull()) - .withColumn( - "standard_concept_id", - F.concat(F.lit("i-"), time_token_udf("time_delta")), - ) - .withColumn("visit_concept_order", F.col("visit_concept_order")) - .withColumn("priority", F.col("priority") - 0.01) - .withColumn("concept_value_mask", F.lit(0)) - .withColumn("concept_value", F.lit(0.0)) - .drop("prev_date", "time_delta", "is_span_boundary") - ) - - self.validate(inpatient_events) - self.validate(inpatient_att_events) - - # Retrieving the events that are NOT linked to inpatient visits - other_events = patient_events.join( - inpatient_visits.select("visit_occurrence_id", "cohort_member_id"), - ["visit_occurrence_id", "cohort_member_id"], - how="left_anti", - ) - - patient_events = inpatient_events.unionByName(inpatient_att_events).unionByName(other_events) - - self.validate(patient_events) - self.validate(artificial_tokens) - - # artificial_tokens = artificial_tokens.select(sorted(artificial_tokens.columns)) - return patient_events.unionByName(artificial_tokens) - - -class DemographicPromptDecorator(PatientEventDecorator): - def __init__(self, patient_demographic, use_age_group: bool = False): - self._patient_demographic = patient_demographic - self._use_age_group = use_age_group - - def _decorate(self, patient_events: DataFrame): - if self._patient_demographic is None: - return patient_events - - # set(['cohort_member_id', 'person_id', 'standard_concept_id', 'date', - # 'visit_occurrence_id', 'domain', 'concept_value', 'visit_rank_order', - # 'visit_segment', 'priority', 'date_in_week', 'concept_value_mask', - # 'mlm_skip_value', 'age', 'visit_concept_id']) - - # Get the first token of the patient history - first_token_udf = F.row_number().over( - W.partitionBy("cohort_member_id", "person_id").orderBy( - "visit_start_datetime", - "visit_occurrence_id", - "priority", - "standard_concept_id", - ) - ) - - # Identify the first token of each patient history - patient_first_token = ( - patient_events.withColumn("token_order", first_token_udf) - .withColumn("concept_value_mask", F.lit(0)) - .withColumn("concept_value", F.lit(0.0)) - .where("token_order = 1") - .drop("token_order") - ) - - # Udf for identifying the earliest date associated with a visit_occurrence_id - sequence_start_year_token = ( - patient_first_token.withColumn( - "standard_concept_id", - F.concat(F.lit("year:"), F.year("date").cast(T.StringType())), - ) - .withColumn("priority", F.lit(-10)) - .withColumn("visit_segment", F.lit(0)) - .withColumn("date_in_week", F.lit(0)) - .withColumn("age", F.lit(-1)) - .withColumn("visit_rank_order", F.lit(0)) - .withColumn("visit_concept_order", F.lit(0)) - .withColumn("concept_order", F.lit(0)) - ) - - sequence_start_year_token.cache() - - if self._use_age_group: - calculate_age_group_at_first_visit_udf = F.ceil( - F.floor(F.months_between(F.col("date"), 
F.col("birth_datetime")) / F.lit(12) / 10) - ) - age_at_first_visit_udf = F.concat( - F.lit("age:"), - (calculate_age_group_at_first_visit_udf * 10).cast(T.StringType()), - F.lit("-"), - ((calculate_age_group_at_first_visit_udf + 1) * 10).cast(T.StringType()), - ) - else: - calculate_age_at_first_visit_udf = F.ceil( - F.months_between(F.col("date"), F.col("birth_datetime")) / F.lit(12) - ) - age_at_first_visit_udf = F.concat(F.lit("age:"), calculate_age_at_first_visit_udf.cast(T.StringType())) - - sequence_age_token = ( - self._patient_demographic.select(F.col("person_id"), F.col("birth_datetime")) - .join(sequence_start_year_token, "person_id") - .withColumn("standard_concept_id", age_at_first_visit_udf) - .withColumn("priority", F.lit(-9)) - .drop("birth_datetime") - ) - - sequence_gender_token = ( - self._patient_demographic.select(F.col("person_id"), F.col("gender_concept_id")) - .join(sequence_start_year_token, "person_id") - .withColumn("standard_concept_id", F.col("gender_concept_id").cast(T.StringType())) - .withColumn("priority", F.lit(-8)) - .drop("gender_concept_id") - ) - - sequence_race_token = ( - self._patient_demographic.select(F.col("person_id"), F.col("race_concept_id")) - .join(sequence_start_year_token, "person_id") - .withColumn("standard_concept_id", F.col("race_concept_id").cast(T.StringType())) - .withColumn("priority", F.lit(-7)) - .drop("race_concept_id") - ) - - patient_events = patient_events.unionByName(sequence_start_year_token) - patient_events = patient_events.unionByName(sequence_age_token) - patient_events = patient_events.unionByName(sequence_gender_token) - patient_events = patient_events.unionByName(sequence_race_token) - - return patient_events - - -class DeathEventDecorator(PatientEventDecorator): - def __init__(self, death, att_type): - self._death = death - self._att_type = att_type - - def _decorate(self, patient_events: DataFrame): - if self._death is None: - return patient_events - - death_records = patient_events.join(self._death.select("person_id", "death_date"), "person_id") - - max_visit_occurrence_id = death_records.select(F.max("visit_occurrence_id").alias("max_visit_occurrence_id")) - - last_ve_record = ( - death_records.where(F.col("standard_concept_id") == "VE") - .withColumn( - "record_rank", - F.row_number().over(W.partitionBy("person_id", "cohort_member_id").orderBy(F.desc("date"))), - ) - .where(F.col("record_rank") == 1) - .drop("record_rank") - ) - - last_ve_record.cache() - last_ve_record.show() - # set(['cohort_member_id', 'person_id', 'standard_concept_id', 'date', - # 'visit_occurrence_id', 'domain', 'concept_value', 'visit_rank_order', - # 'visit_segment', 'priority', 'date_in_week', 'concept_value_mask', - # 'mlm_skip_value', 'age', 'visit_concept_id']) - - artificial_visit_id = F.row_number().over( - W.partitionBy(F.lit(0)).orderBy("person_id", "cohort_member_id") - ) + F.col("max_visit_occurrence_id") - death_records = ( - last_ve_record.crossJoin(max_visit_occurrence_id) - .withColumn("visit_occurrence_id", artificial_visit_id) - .withColumn("standard_concept_id", F.lit("[DEATH]")) - .withColumn("domain", F.lit("death")) - .withColumn("visit_rank_order", F.lit(1) + F.col("visit_rank_order")) - .withColumn("priority", F.lit(20)) - .drop("max_visit_occurrence_id") - ) - - vs_records = death_records.withColumn("standard_concept_id", F.lit("VS")).withColumn("priority", F.lit(15)) - - ve_records = death_records.withColumn("standard_concept_id", F.lit("VE")).withColumn("priority", F.lit(30)) - - # Udf for calculating the time 
token - if self._att_type == AttType.DAY: - att_func = time_day_token - elif self._att_type == AttType.WEEK: - att_func = time_week_token - elif self._att_type == AttType.MONTH: - att_func = time_month_token - elif self._att_type == AttType.MIX: - att_func = time_mix_token - else: - att_func = time_token_func - - time_token_udf = F.udf(att_func, T.StringType()) - - att_records = death_records.withColumn( - "death_date", - F.when(F.col("death_date") < F.col("date"), F.col("date")).otherwise(F.col("death_date")), - ) - att_records = ( - att_records.withColumn("time_delta", F.datediff("death_date", "date")) - .withColumn("standard_concept_id", time_token_udf("time_delta")) - .withColumn("priority", F.lit(10)) - .drop("time_delta") - ) - - new_tokens = att_records.unionByName(vs_records).unionByName(death_records).unionByName(ve_records) - new_tokens = new_tokens.drop("death_date") - self.validate(new_tokens) - - return patient_events.unionByName(new_tokens) - - -def time_token_func(time_delta) -> Optional[str]: - if time_delta is None or np.isnan(time_delta): - return None - if time_delta < 0: - return "W-1" - if time_delta < 28: - return f"W{str(math.floor(time_delta / 7))}" - if time_delta < 360: - return f"M{str(math.floor(time_delta / 30))}" - return "LT" - - -def time_day_token(time_delta): - if time_delta is None or np.isnan(time_delta): - return None - if time_delta < 1080: - return f"D{str(time_delta)}" - return "LT" - - -def time_week_token(time_delta): - if time_delta is None or np.isnan(time_delta): - return None - if time_delta < 1080: - return f"W{str(math.floor(time_delta / 7))}" - return "LT" - - -def time_month_token(time_delta): - if time_delta is None or np.isnan(time_delta): - return None - if time_delta < 1080: - return f"M{str(math.floor(time_delta / 30))}" - return "LT" - - -def time_mix_token(time_delta): - # WHEN day_diff <= 7 THEN CONCAT('D', day_diff) - # WHEN day_diff <= 30 THEN CONCAT('W', ceil(day_diff / 7)) - # WHEN day_diff <= 360 THEN CONCAT('M', ceil(day_diff / 30)) - # WHEN day_diff <= 720 THEN CONCAT('Q', ceil(day_diff / 90)) - # WHEN day_diff <= 1440 THEN CONCAT('Y', ceil(day_diff / 360)) - # ELSE 'LT' - if time_delta is None or np.isnan(time_delta): - return None - if time_delta <= 7: - return f"D{str(time_delta)}" - if time_delta <= 30: - # e.g. 8 -> W2 - return f"W{str(math.ceil(time_delta / 7))}" - if time_delta <= 360: - # e.g. 31 -> M2 - return f"M{str(math.ceil(time_delta / 30))}" - # if time_delta <= 720: - # # e.g. 361 -> Q5 - # return f'Q{str(math.ceil(time_delta / 90))}' - # if time_delta <= 1080: - # # e.g. 1081 -> Y2 - # return f'Y{str(math.ceil(time_delta / 360))}' - return "LT" - - -def get_att_function(att_type: Union[AttType, str]): - # Convert the att_type str to the corresponding enum type - if isinstance(att_type, str): - att_type = AttType(att_type) - - if att_type == AttType.DAY: - return time_day_token - elif att_type == AttType.WEEK: - return time_week_token - elif att_type == AttType.MONTH: - return time_month_token - elif att_type == AttType.MIX: - return time_mix_token - elif att_type == AttType.CEHR_BERT: - return time_token_func - return None diff --git a/src/cehrbert/spark_apps/generate_concept_similarity_table.py b/src/cehrbert/spark_apps/generate_concept_similarity_table.py deleted file mode 100644 index 339d4880..00000000 --- a/src/cehrbert/spark_apps/generate_concept_similarity_table.py +++ /dev/null @@ -1,423 +0,0 @@ -"""This module provides functionality to extract patient event data from domain tables,. 
- -compute information content and semantic similarity for concepts, and calculate concept -similarity scores. - -Functions: extract_data: Extract data from specified domain tables. compute_information_content: -Compute the information content for concepts based on frequency. -compute_information_content_similarity: Compute the similarity between concepts based on -information content. compute_semantic_similarity: Compute the semantic similarity between concept -pairs. main: Main function to orchestrate the extraction, processing, and saving of concept -similarity data. -""" - -import datetime -import logging -import os -from typing import List - -from pyspark.sql import DataFrame, SparkSession -from pyspark.sql import functions as F - -from ..config.output_names import CONCEPT_SIMILARITY_PATH, QUALIFIED_CONCEPT_LIST_PATH -from ..const.common import CONCEPT, CONCEPT_ANCESTOR -from ..utils.spark_utils import join_domain_tables, preprocess_domain_table, validate_table_names - - -def extract_data(spark: SparkSession, input_folder: str, domain_table_list: List[str]): - """ - Extract patient event data from the specified domain tables. - - Args: - spark (SparkSession): The Spark session to use for processing. - input_folder (str): Path to the input folder containing domain tables. - domain_table_list (List[str]): List of domain table names to extract data from. - - Returns: - DataFrame: A DataFrame containing extracted and processed patient event data. - """ - domain_tables = [] - for domain_table_name in domain_table_list: - domain_tables.append(preprocess_domain_table(spark, input_folder, domain_table_name)) - patient_event = join_domain_tables(domain_tables) - # Remove all concept_id records - patient_event = patient_event.where("standard_concept_id <> 0") - - return patient_event - - -def compute_information_content(patient_event: DataFrame, concept_ancestor: DataFrame): - """ - Calculate the information content using the frequency of each concept and the graph. - - :param patient_event: - :param concept_ancestor: - :return: - """ - # Get the total count - total_count = patient_event.distinct().count() - # Count the frequency of each concept - concept_frequency = patient_event.distinct().groupBy("standard_concept_id").count() - # left join b/w descendent_concept_id and the standard_concept_id in the concept freq table - freq_df = ( - concept_frequency.join( - concept_ancestor, - F.col("descendant_concept_id") == F.col("standard_concept_id"), - ) - .groupBy("ancestor_concept_id") - .sum("count") - .withColumnRenamed("ancestor_concept_id", "concept_id") - .withColumnRenamed("sum(count)", "count") - ) - # Calculate information content for each concept - information_content = freq_df.withColumn("information_content", (-F.log(F.col("count") / total_count))).withColumn( - "probability", F.col("count") / total_count - ) - - return information_content - - -def compute_information_content_similarity( - concept_pair: DataFrame, information_content: DataFrame, concept_ancestor: DataFrame -): - """ - Compute the similarity between concept pairs based on their information content. - - Args: - concept_pair (DataFrame): A DataFrame containing pairs of concepts. - information_content (DataFrame): A DataFrame with information content for concepts. - concept_ancestor (DataFrame): A DataFrame containing concept ancestor relationships. - - Returns: - DataFrame: A DataFrame containing various similarity measures for concept pairs. 
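A toy, pure-Python restatement of the information-content calculation above (the removed Spark job counts distinct patient events and rolls frequencies up to ancestors via concept_ancestor; the tiny hierarchy below is hypothetical):

import math
from collections import Counter

def information_content(event_concepts, ancestors):
    # IC(c) = -log(freq(c) / total), where a concept's frequency includes all of its descendants.
    total = len(event_concepts)
    leaf_counts = Counter(event_concepts)
    ic = {}
    for concept, descendant_set in ancestors.items():
        count = sum(leaf_counts[d] for d in descendant_set)
        if count:
            ic[concept] = -math.log(count / total)
    return ic

# Toy hierarchy: concept 1 is an ancestor of 2 and 3 (each descendant set includes the concept itself).
events = [2, 2, 3, 3, 3, 2]
ancestors = {1: {1, 2, 3}, 2: {2}, 3: {3}}
print(information_content(events, ancestors))  # concept 1 covers every event -> 0.0; concepts 2 and 3 -> ~0.69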
- """ - # Extract the pairs of concepts from the training data and join to the information content table - information_content_concept_pair = ( - concept_pair.select("concept_id_1", "concept_id_2") - .join( - information_content, - F.col("concept_id_1") == F.col("concept_id"), - "left_outer", - ) - .select( - F.col("concept_id_1"), - F.col("concept_id_2"), - F.col("information_content").alias("information_content_1"), - ) - .join( - information_content, - F.col("concept_id_2") == F.col("concept_id"), - "left_outer", - ) - .select( - F.col("concept_id_1"), - F.col("concept_id_2"), - F.col("information_content_1"), - F.col("information_content").alias("information_content_2"), - ) - ) - - # Join to get all the ancestors of concept_id_1 - concept_id_1_ancestor = information_content_concept_pair.join( - concept_ancestor, F.col("concept_id_1") == F.col("descendant_concept_id") - ).select("concept_id_1", "concept_id_2", "ancestor_concept_id") - - # Join to get all the ancestors of concept_id_2 - concept_id_2_ancestor = concept_pair.join( - concept_ancestor, F.col("concept_id_2") == F.col("descendant_concept_id") - ).select("concept_id_1", "concept_id_2", "ancestor_concept_id") - - # Compute the summed information content of all ancestors of concept_id_1 and concept_id_2 - union_sum = ( - concept_id_1_ancestor.union(concept_id_2_ancestor) - .distinct() - .join(information_content, F.col("ancestor_concept_id") == F.col("concept_id")) - .groupBy("concept_id_1", "concept_id_2") - .agg(F.sum("information_content").alias("ancestor_union_ic")) - ) - - # Compute the summed information content of common ancestors of concept_id_1 and concept_id_2 - intersection_sum = ( - concept_id_1_ancestor.intersect(concept_id_2_ancestor) - .join(information_content, F.col("ancestor_concept_id") == F.col("concept_id")) - .groupBy("concept_id_1", "concept_id_2") - .agg(F.sum("information_content").alias("ancestor_intersection_ic")) - ) - - # Compute the information content and probability of the most informative common ancestor (MICA) - mica_ancestor = ( - concept_id_1_ancestor.intersect(concept_id_2_ancestor) - .join(information_content, F.col("ancestor_concept_id") == F.col("concept_id")) - .groupBy("concept_id_1", "concept_id_2") - .agg( - F.max("information_content").alias("mica_information_content"), - F.max("probability").alias("mica_probability"), - ) - ) - - # Join the MICA to pairs of concepts - features = information_content_concept_pair.join( - mica_ancestor, - (information_content_concept_pair["concept_id_1"] == mica_ancestor["concept_id_1"]) - & (information_content_concept_pair["concept_id_2"] == mica_ancestor["concept_id_2"]), - "left_outer", - ).select( - [information_content_concept_pair[f] for f in information_content_concept_pair.schema.fieldNames()] - + [F.col("mica_information_content"), F.col("mica_probability")] - ) - - # Compute the lin measure - features = features.withColumn( - "lin_measure", - 2 * F.col("mica_information_content") / (F.col("information_content_1") * F.col("information_content_2")), - ) - - # Compute the jiang measure - features = features.withColumn( - "jiang_measure", - 1 - (F.col("information_content_1") + F.col("information_content_2") - 2 * F.col("mica_information_content")), - ) - - # Compute the information coefficient - features = features.withColumn( - "information_coefficient", - F.col("lin_measure") * (1 - 1 / (1 + F.col("mica_information_content"))), - ) - - # Compute the relevance_measure - features = features.withColumn("relevance_measure", F.col("lin_measure") * 
(1 - F.col("mica_probability"))) - - # Join to get the summed information content of the common ancestors of concept_id_1 and - # concept_id_2 - features = features.join( - intersection_sum, - (features["concept_id_1"] == intersection_sum["concept_id_1"]) - & (features["concept_id_2"] == intersection_sum["concept_id_2"]), - "left_outer", - ).select([features[f] for f in features.schema.fieldNames()] + [F.col("ancestor_intersection_ic")]) - - # Join to get the summed information content of the common ancestors of concept_id_1 and - # concept_id_2 - features = features.join( - union_sum, - (features["concept_id_1"] == union_sum["concept_id_1"]) - & (features["concept_id_2"] == union_sum["concept_id_2"]), - "left_outer", - ).select([features[f] for f in features.schema.fieldNames()] + [F.col("ancestor_union_ic")]) - - # Compute the graph information content measure - features = features.withColumn( - "graph_ic_measure", - F.col("ancestor_intersection_ic") / F.col("ancestor_union_ic"), - ) - - return features.select( - [ - F.col("concept_id_1"), - F.col("concept_id_2"), - F.col("mica_information_content"), - F.col("lin_measure"), - F.col("jiang_measure"), - F.col("information_coefficient"), - F.col("relevance_measure"), - F.col("graph_ic_measure"), - ] - ) - - -def compute_semantic_similarity(spark, patient_event, concept, concept_ancestor): - required_concept = ( - patient_event.distinct() - .select("standard_concept_id") - .join(concept, F.col("standard_concept_id") == F.col("concept_id")) - .select("standard_concept_id", "domain_id") - ) - - concept_ancestor.createOrReplaceTempView("concept_ancestor") - required_concept.createOrReplaceTempView("required_concept") - - concept_pair = spark.sql( - """ - WITH concept_pair AS ( - SELECT - c1.standard_concept_id AS concept_id_1, - c2.standard_concept_id AS concept_id_2, - c1.domain_id - FROM required_concept AS c1 - JOIN required_concept AS c2 - ON c1.domain_id = c2.domain_id - WHERE c1.standard_concept_id <> c2.standard_concept_id - ) - SELECT - cp.concept_id_1, - cp.concept_id_2, - ca_1.ancestor_concept_id AS common_ancestor_concept_id, - ca_1.min_levels_of_separation AS distance_1, - ca_2.min_levels_of_separation AS distance_2 - FROM concept_pair AS cp - JOIN concept_ancestor AS ca_1 - ON cp.concept_id_1 = ca_1.descendant_concept_id - JOIN concept_ancestor AS ca_2 - ON cp.concept_id_2 = ca_2.descendant_concept_id - WHERE ca_1.ancestor_concept_id = ca_2.ancestor_concept_id - """ - ) - - # Find the root concepts - root_concept = ( - concept_ancestor.groupBy("descendant_concept_id") - .count() - .where("count = 1") - .withColumnRenamed("descendant_concept_id", "root_concept_id") - ) - # Retrieve all ancestor descendant relationships for the root concepts - root_concept_relationship = ( - root_concept.join( - concept_ancestor, - root_concept["root_concept_id"] == concept_ancestor["ancestor_concept_id"], - ) - .select( - concept_ancestor["ancestor_concept_id"], - concept_ancestor["descendant_concept_id"], - concept_ancestor["max_levels_of_separation"].alias("root_distance"), - ) - .where("ancestor_concept_id <> descendant_concept_id") - ) - - # Join to get all root concepts and their corresponding root_distance - concept_pair = concept_pair.join( - root_concept_relationship, - F.col("common_ancestor_concept_id") == F.col("descendant_concept_id"), - ).select("concept_id_1", "concept_id_2", "distance_1", "distance_2", "root_distance") - - # Compute the semantic similarity - concept_pair_similarity = concept_pair.withColumn( - "semantic_similarity", - 
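For reference, the standard forms of the IC-based measures computed above, evaluated on toy IC values; note that lin_measure in the removed code divides by the product of the two concepts' information contents, whereas Lin's original formulation divides by their sum:

def ic_similarities(ic1, ic2, ic_mica):
    # Standard IC-based similarity measures for a concept pair, given the MICA's information content.
    lin = 2 * ic_mica / (ic1 + ic2)        # Lin (1998); the removed code uses ic1 * ic2 in the denominator
    jiang = 1 - (ic1 + ic2 - 2 * ic_mica)  # Jiang-Conrath distance turned into a similarity
    info_coefficient = lin * (1 - 1 / (1 + ic_mica))
    return {"lin": lin, "jiang": jiang, "information_coefficient": info_coefficient}

print(ic_similarities(ic1=2.0, ic2=3.0, ic_mica=1.5))  # lin = 0.6, jiang = -1.0, information coefficient ~0.36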
2 * F.col("root_distance") / (2 * F.col("root_distance") + F.col("distance_1") + F.col("distance_2")), - ) - # Find the maximum semantic similarity - concept_pair_similarity = concept_pair_similarity.groupBy("concept_id_1", "concept_id_2").agg( - F.max("semantic_similarity").alias("semantic_similarity") - ) - - return concept_pair_similarity - - -def main( - input_folder: str, - output_folder: str, - domain_table_list: List[str], - date_filter: str, - include_concept_list: bool, -): - """ - Main function to generate the concept similarity table. - - Args: - input_folder (str): The path to the input folder containing raw data. - output_folder (str): The path to the output folder to store the results. - domain_table_list (List[str]): List of domain tables to process. - date_filter (str): Date filter to apply to the data. - include_concept_list (bool): Whether to include a filtered concept list. - """ - - spark = SparkSession.builder.appName("Generate the concept similarity table").getOrCreate() - - logger = logging.getLogger(__name__) - logger.info( - "input_folder: %s\noutput_folder: %s\ndomain_table_list: %s\ndate_filter: " "%s\ninclude_concept_list: %s", - input_folder, - output_folder, - domain_table_list, - date_filter, - include_concept_list, - ) - - concept = preprocess_domain_table(spark, input_folder, CONCEPT) - concept_ancestor = preprocess_domain_table(spark, input_folder, CONCEPT_ANCESTOR) - - # Extract all data points from specified domains - patient_event = extract_data(spark, input_folder, domain_table_list) - - # Calculate information content using unfiltered the patient event dataframe - information_content = compute_information_content(patient_event, concept_ancestor) - - # Filter out concepts that are not required in the required concept_list - if include_concept_list and patient_event: - # Filter out concepts - qualified_concepts = F.broadcast(preprocess_domain_table(spark, input_folder, QUALIFIED_CONCEPT_LIST_PATH)) - - patient_event = patient_event.join(qualified_concepts, "standard_concept_id").select("standard_concept_id") - - concept_pair_similarity = compute_semantic_similarity(spark, patient_event, concept, concept_ancestor) - - # Compute the information content based similarity scores - concept_pair_ic_similarity = compute_information_content_similarity( - concept_pair_similarity, information_content, concept_ancestor - ) - - concept_pair_similarity_columns = [concept_pair_similarity[f] for f in concept_pair_similarity.schema.fieldNames()] - concept_pair_ic_similarity_columns = [ - f for f in concept_pair_ic_similarity.schema.fieldNames() if "concept_id" not in f - ] - - # Join two dataframes to get the final result - concept_pair_similarity = concept_pair_similarity.join( - concept_pair_ic_similarity, - (concept_pair_similarity["concept_id_1"] == concept_pair_ic_similarity["concept_id_1"]) - & (concept_pair_similarity["concept_id_2"] == concept_pair_ic_similarity["concept_id_2"]), - ).select(concept_pair_similarity_columns + concept_pair_ic_similarity_columns) - - concept_pair_similarity.write.mode("overwrite").parquet(os.path.join(output_folder, CONCEPT_SIMILARITY_PATH)) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="Arguments for generate Concept Similarity Table") - parser.add_argument( - "-i", - "--input_folder", - dest="input_folder", - action="store", - help="The path for your input_folder where the raw data is", - required=True, - ) - parser.add_argument( - "-o", - "--output_folder", - 
dest="output_folder", - action="store", - help="The path for your output_folder", - required=True, - ) - parser.add_argument( - "-tc", - "--domain_table_list", - dest="domain_table_list", - nargs="+", - action="store", - help="The list of domain tables you want to download", - type=validate_table_names, - required=True, - ) - parser.add_argument( - "-d", - "--date_filter", - dest="date_filter", - type=lambda s: datetime.datetime.strptime(s, "%Y-%m-%d"), - action="store", - required=False, - default="2018-01-01", - ) - parser.add_argument("--include_concept_list", dest="include_concept_list", action="store_true") - - ARGS = parser.parse_args() - - main( - ARGS.input_folder, - ARGS.output_folder, - ARGS.domain_table_list, - ARGS.date_filter, - ARGS.include_concept_list, - ) diff --git a/src/cehrbert/spark_apps/generate_hierarchical_bert_training_data.py b/src/cehrbert/spark_apps/generate_hierarchical_bert_training_data.py deleted file mode 100644 index 82c875bd..00000000 --- a/src/cehrbert/spark_apps/generate_hierarchical_bert_training_data.py +++ /dev/null @@ -1,238 +0,0 @@ -""" -This module generates hierarchical BERT training data based on domain tables from OMOP EHR data. - -It processes patient event data, joins multiple domain tables, filters concepts based on a -minimum number of patients, and creates hierarchical sequence data for BERT training. - -Key Functions: - - preprocess_domain_table: Preprocesses domain tables for data extraction. - - process_measurement: Handles special processing for measurement data. - - join_domain_tables: Joins multiple domain tables into a unified DataFrame. - - create_hierarchical_sequence_data: Generates hierarchical sequence data for training. - -Command-line Arguments: - - input_folder: Path to the directory containing input data. - - output_folder: Path to the directory where the output will be saved. - - domain_table_list: List of domain tables to process. - - date_filter: Optional filter for processing the data based on date. - - max_num_of_visits_per_person: Maximum number of visits per patient to include. - - min_observation_period: Minimum observation period in days for patients to be included. - - include_concept_list: Whether to apply a filter to retain certain concepts. - - include_incomplete_visit: Whether to include incomplete visit records in the training data. -""" - -import datetime -import logging -import os - -from pyspark.sql import SparkSession -from pyspark.sql import functions as F - -from ..config.output_names import PARQUET_DATA_PATH, QUALIFIED_CONCEPT_LIST_PATH -from ..const.common import MEASUREMENT, OBSERVATION_PERIOD, PERSON, REQUIRED_MEASUREMENT, VISIT_OCCURRENCE -from ..utils.spark_utils import ( - create_hierarchical_sequence_data, - join_domain_tables, - preprocess_domain_table, - process_measurement, - validate_table_names, -) - - -def main( - input_folder, - output_folder, - domain_table_list, - date_filter, - max_num_of_visits_per_person, - min_observation_period: int = 360, - include_concept_list: bool = True, - include_incomplete_visit: bool = True, -): - """ - Main function to generate hierarchical BERT training data from domain tables. - - Args: - input_folder (str): The path to the input folder containing raw data. - output_folder (str): The path to the output folder for storing the training data. - domain_table_list (list): A list of domain tables to process (e.g., condition_occurrence). - date_filter (str): Date filter for processing data, default is '2018-01-01'. 
- max_num_of_visits_per_person (int): The maximum number of visits to include per person. - min_observation_period (int, optional): Minimum observation period in days. Default is 360. - include_concept_list (bool, optional): Whether to filter by concept list. Default is True. - include_incomplete_visit (bool, optional): Whether to include incomplete visits. Default is - True. - - This function preprocesses domain tables, filters and processes measurement data, - and generates hierarchical sequence data for training BERT models on EHR records. - """ - spark = SparkSession.builder.appName("Generate Hierarchical Bert Training Data").getOrCreate() - - logger = logging.getLogger(__name__) - logger.info( - "input_folder: %s\n" - "output_folder: %s\n" - "domain_table_list: %s\n" - "date_filter: %s\n" - "max_num_of_visits_per_person: %s\n" - "min_observation_period: %s\n" - "include_concept_list: %s\n" - "include_incomplete_visit: %s", - input_folder, - output_folder, - domain_table_list, - date_filter, - max_num_of_visits_per_person, - min_observation_period, - include_concept_list, - include_incomplete_visit, - ) - - domain_tables = [] - # Exclude measurement from domain_table_list if exists because we need to process measurement - # in a different way - for domain_table_name in domain_table_list: - if domain_table_name != MEASUREMENT: - domain_tables.append(preprocess_domain_table(spark, input_folder, domain_table_name)) - - observation_period = ( - preprocess_domain_table(spark, input_folder, OBSERVATION_PERIOD) - .withColumn( - "observation_period_start_date", - F.col("observation_period_start_date").cast("date"), - ) - .withColumn( - "observation_period_end_date", - F.col("observation_period_end_date").cast("date"), - ) - .withColumn( - "period", - F.datediff("observation_period_end_date", "observation_period_start_date"), - ) - .where(F.col("period") >= min_observation_period) - .select("person_id") - ) - - visit_occurrence = preprocess_domain_table(spark, input_folder, VISIT_OCCURRENCE) - person = preprocess_domain_table(spark, input_folder, PERSON) - - # Filter for the persons that have enough observation period - person = person.join(observation_period, "person_id").select([person[f] for f in person.schema.fieldNames()]) - - # Union all domain table records - patient_events = join_domain_tables(domain_tables) - - column_names = patient_events.schema.fieldNames() - - if include_concept_list and patient_events: - # Filter out concepts - qualified_concepts = F.broadcast(preprocess_domain_table(spark, input_folder, QUALIFIED_CONCEPT_LIST_PATH)) - # The select is necessary to make sure the order of the columns is the same as the - # original dataframe - patient_events = patient_events.join(qualified_concepts, "standard_concept_id").select(column_names) - - # Process the measurement table if exists - if MEASUREMENT in domain_table_list: - measurement = preprocess_domain_table(spark, input_folder, MEASUREMENT) - required_measurement = preprocess_domain_table(spark, input_folder, REQUIRED_MEASUREMENT) - # The select is necessary to make sure the order of the columns is the same as the - # original dataframe, otherwise the union might use the wrong columns - scaled_measurement = process_measurement(spark, measurement, required_measurement).select(column_names) - - if patient_events: - # Union all measurement records together with other domain records - patient_events = patient_events.union(scaled_measurement) - else: - patient_events = scaled_measurement - - # cohort_member_id is the same as the 
person_id - patient_events = patient_events.withColumn("cohort_member_id", F.col("person_id")) - - sequence_data = create_hierarchical_sequence_data( - person, - visit_occurrence, - patient_events, - date_filter=date_filter, - max_num_of_visits_per_person=max_num_of_visits_per_person, - include_incomplete_visit=include_incomplete_visit, - ) - - sequence_data.write.mode("overwrite").parquet(os.path.join(output_folder, PARQUET_DATA_PATH)) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="Arguments for generate training data for Hierarchical Bert") - parser.add_argument( - "-i", - "--input_folder", - dest="input_folder", - action="store", - help="The path for your input_folder where the raw data is", - required=True, - ) - parser.add_argument( - "-o", - "--output_folder", - dest="output_folder", - action="store", - help="The path for your output_folder", - required=True, - ) - parser.add_argument( - "-tc", - "--domain_table_list", - dest="domain_table_list", - nargs="+", - action="store", - help="The list of domain tables you want to download", - type=validate_table_names, - required=True, - ) - parser.add_argument( - "-d", - "--date_filter", - dest="date_filter", - type=lambda s: datetime.datetime.strptime(s, "%Y-%m-%d"), - action="store", - required=False, - default="2018-01-01", - ) - parser.add_argument( - "--max_num_of_visits", - dest="max_num_of_visits", - action="store", - type=int, - default=200, - help="Max no.of visits per patient to be included", - required=False, - ) - parser.add_argument( - "--min_observation_period", - dest="min_observation_period", - action="store", - type=int, - default=1, - help="Minimum observation period in days", - required=False, - ) - parser.add_argument("--include_concept_list", dest="include_concept_list", action="store_true") - parser.add_argument( - "--include_incomplete_visit", - dest="include_incomplete_visit", - action="store_true", - ) - - ARGS = parser.parse_args() - - main( - input_folder=ARGS.input_folder, - output_folder=ARGS.output_folder, - domain_table_list=ARGS.domain_table_list, - date_filter=ARGS.date_filter, - max_num_of_visits_per_person=ARGS.max_num_of_visits, - min_observation_period=ARGS.min_observation_period, - include_concept_list=ARGS.include_concept_list, - include_incomplete_visit=ARGS.include_incomplete_visit, - ) diff --git a/src/cehrbert/spark_apps/generate_included_concept_list.py b/src/cehrbert/spark_apps/generate_included_concept_list.py deleted file mode 100644 index ac81dedd..00000000 --- a/src/cehrbert/spark_apps/generate_included_concept_list.py +++ /dev/null @@ -1,116 +0,0 @@ -""" -This module generates a qualified concept list by processing patient event data across various. - -domain tables (e.g., condition_occurrence, procedure_occurrence, drug_exposure) and applying a -patient frequency filter to retain concepts linked to a minimum number of patients. - -Key Functions: - - preprocess_domain_table: Preprocesses domain tables to prepare for event extraction. - - join_domain_tables: Joins multiple domain tables into a unified DataFrame. - - main: Coordinates the entire process of reading domain tables, applying frequency filters, - and saving the qualified concept list. - -Command-line Arguments: - - input_folder: Directory containing the input data. - - output_folder: Directory where the qualified concept list will be saved. - - min_num_of_patients: Minimum number of patients linked to a concept for it to be included. 
- - with_drug_rollup: Boolean flag indicating whether drug concept rollups should be applied.
-"""
-
-import os
-
-from pyspark.sql import SparkSession
-from pyspark.sql import functions as F
-
-from ..config.output_names import QUALIFIED_CONCEPT_LIST_PATH
-from ..const.common import MEASUREMENT
-from ..utils.spark_utils import join_domain_tables, preprocess_domain_table
-
-DOMAIN_TABLE_LIST = ["condition_occurrence", "procedure_occurrence", "drug_exposure"]
-
-
-def main(input_folder, output_folder, min_num_of_patients, with_drug_rollup: bool = True):
-    """
-    Generate a qualified concept list based on patient event data from multiple domain tables.
-
-    The qualified concept list is written to the output folder in parquet format.
-
-    Args:
-        input_folder (str): The directory where the input data is stored.
-        output_folder (str): The directory where the output (qualified concept list) will be saved.
-        min_num_of_patients (int): Minimum number of patients that a concept must be linked to for
-        inclusion.
-        with_drug_rollup (bool): If True, applies drug rollup logic to the drug_exposure domain.
-
-    The function processes patient event data across various domain tables, excludes low-frequency
-    concepts, and saves the filtered concepts to a specified output folder.
-    """
-    spark = SparkSession.builder.appName("Generate concept list").getOrCreate()
-
-    domain_tables = []
-    # Exclude measurement from domain_table_list if it exists because measurement
-    # needs to be processed in a different way
-    for domain_table_name in DOMAIN_TABLE_LIST:
-        if domain_table_name != MEASUREMENT:
-            domain_tables.append(
-                preprocess_domain_table(
-                    spark,
-                    input_folder,
-                    domain_table_name,
-                    with_drug_rollup=with_drug_rollup,
-                )
-            )
-
-    # Union all domain table records
-    patient_events = join_domain_tables(domain_tables)
-
-    # Filter out concepts that are linked to fewer than min_num_of_patients patients
-    qualified_concepts = (
-        patient_events.where("visit_occurrence_id IS NOT NULL")
-        .groupBy("standard_concept_id")
-        .agg(F.countDistinct("person_id").alias("freq"))
-        .where(F.col("freq") >= min_num_of_patients)
-    )
-
-    qualified_concepts.write.mode("overwrite").parquet(os.path.join(output_folder, QUALIFIED_CONCEPT_LIST_PATH))
-
-
-if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser(description="Arguments for generating the concept list to be included")
-    parser.add_argument(
-        "-i",
-        "--input_folder",
-        dest="input_folder",
-        action="store",
-        help="The path for your input_folder where the raw data is",
-        required=True,
-    )
-    parser.add_argument(
-        "-o",
-        "--output_folder",
-        dest="output_folder",
-        action="store",
-        help="The path for your output_folder",
-        required=True,
-    )
-    parser.add_argument(
-        "--min_num_of_patients",
-        dest="min_num_of_patients",
-        action="store",
-        type=int,
-        default=0,
-        help="Min no. of patients linked to concepts to be included",
-        required=False,
-    )
-    parser.add_argument("--with_drug_rollup", dest="with_drug_rollup", action="store_true")
-
-    ARGS = parser.parse_args()
-
-    main(
-        ARGS.input_folder,
-        ARGS.output_folder,
-        ARGS.min_num_of_patients,
-        ARGS.with_drug_rollup,
-    )
diff --git a/src/cehrbert/spark_apps/generate_information_content.py b/src/cehrbert/spark_apps/generate_information_content.py
deleted file mode 100644
index eadbdd73..00000000
--- a/src/cehrbert/spark_apps/generate_information_content.py
+++ /dev/null
@@ -1,131 +0,0 @@
-"""
-This module generates an information content table based on a list of domain tables from OMOP data.
-
-It processes patient event data, calculates the frequency of each concept, and computes information
-content using the concept ancestor hierarchy. The results are written to a specified output path.
-
-Key Functions:
- - preprocess_domain_table: Preprocess the domain tables for analysis.
- - join_domain_tables: Join multiple domain tables to generate a unified patient event table.
- - main: Orchestrates the process of reading input data, calculating concept frequencies,
-   and generating the information content table.
-
-Command-line Arguments:
- - input_folder: The folder containing the raw OMOP domain data.
- - output_folder: The folder where the results will be stored.
- - domain_table_list: A list of OMOP domain tables to include in the analysis.
- - date_filter: Optional date filter for processing the data.
-"""
-
-import datetime
-import logging
-import os
-
-from pyspark.sql import SparkSession
-from pyspark.sql import functions as F
-
-from ..config.output_names import INFORMATION_CONTENT_DATA_PATH
-from ..const.common import CONCEPT_ANCESTOR
-from ..utils.spark_utils import join_domain_tables, preprocess_domain_table, validate_table_names
-
-
-def main(input_folder, output_folder, domain_table_list, date_filter):
-    """Create the information content table.
-
-    Keyword arguments:
-    domain_table_list -- the list of OMOP domain tables used to compute concept frequencies
-    output_folder -- the path for writing the information content output
-
-    This function creates the information content table based on the given domain tables
-    """
-
-    spark = SparkSession.builder.appName("Generate the information content table").getOrCreate()
-
-    logger = logging.getLogger(__name__)
-    logger.info(
-        "input_folder: %s\noutput_folder: %s\ndomain_table_list: %s\ndate_filter: %s",
-        input_folder,
-        output_folder,
-        domain_table_list,
-        date_filter,
-    )
-
-    concept_ancestor = preprocess_domain_table(spark, input_folder, CONCEPT_ANCESTOR)
-    domain_tables = []
-    for domain_table_name in domain_table_list:
-        domain_tables.append(preprocess_domain_table(spark, input_folder, domain_table_name))
-
-    patient_events = join_domain_tables(domain_tables)
-
-    # Remove all records whose standard_concept_id is 0 (unmapped concepts)
-    patient_events = patient_events.where("standard_concept_id <> 0")
-
-    # Get the total count
-    total_count = patient_events.distinct().count()
-
-    # Count the frequency of each concept
-    concept_frequency = patient_events.distinct().groupBy("standard_concept_id").count()
-
-    # Left join between descendant_concept_id and standard_concept_id in the concept frequency table
-    freq_df = (
-        concept_frequency.join(
-            concept_ancestor,
-            F.col("descendant_concept_id") == F.col("standard_concept_id"),
-        )
-        .groupBy("ancestor_concept_id")
-        .sum("count")
-        .withColumnRenamed("ancestor_concept_id", "concept_id")
-        .withColumnRenamed("sum(count)", "count")
-    )
-
-    # Calculate information content for each concept
-    information_content = freq_df.withColumn("information_content", (-F.log(F.col("count") / total_count))).withColumn(
-        "probability", F.col("count") / total_count
-    )
-
-    information_content.write.mode("overwrite").parquet(os.path.join(output_folder, INFORMATION_CONTENT_DATA_PATH))
-
-
-if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser(description="Arguments for generating the information content table")
-    parser.add_argument(
-        "-i",
-        "--input_folder",
-        dest="input_folder",
-        action="store",
-        help="The path for your input_folder where the raw data is",
-        required=True,
-    )
-    parser.add_argument(
-
"-o", - "--output_folder", - dest="output_folder", - action="store", - help="The path for your output_folder", - required=True, - ) - parser.add_argument( - "-tc", - "--domain_table_list", - dest="domain_table_list", - nargs="+", - action="store", - help="The list of domain tables you want to download", - type=validate_table_names, - required=True, - ) - parser.add_argument( - "-d", - "--date_filter", - dest="date_filter", - type=lambda s: datetime.datetime.strptime(s, "%Y-%m-%d"), - action="store", - required=False, - default="2018-01-01", - ) - - ARGS = parser.parse_args() - - main(ARGS.input_folder, ARGS.output_folder, ARGS.domain_table_list, ARGS.date_filter) diff --git a/src/cehrbert/spark_apps/generate_required_labs.py b/src/cehrbert/spark_apps/generate_required_labs.py deleted file mode 100644 index a096e70a..00000000 --- a/src/cehrbert/spark_apps/generate_required_labs.py +++ /dev/null @@ -1,109 +0,0 @@ -import os - -from pyspark.sql import SparkSession - -from ..const.common import CONCEPT, MEASUREMENT, REQUIRED_MEASUREMENT -from ..utils.spark_utils import F, W, argparse, preprocess_domain_table - - -def main(input_folder, output_folder, num_of_numeric_labs, num_of_categorical_labs): - spark = SparkSession.builder.appName("Generate required labs").getOrCreate() - - # Load measurement as a dataframe in pyspark - measurement = preprocess_domain_table(spark, input_folder, MEASUREMENT) - concept = preprocess_domain_table(spark, input_folder, CONCEPT) - - # Create the local measurement view - measurement.createOrReplaceTempView("measurement") - - # Create the local concept view - concept.createOrReplaceTempView("concept") - - popular_labs = spark.sql( - """ - SELECT - m.measurement_concept_id, - c.concept_name, - COUNT(*) AS freq, - SUM(CASE WHEN m.value_as_number IS NOT NULL THEN 1 ELSE 0 END) / COUNT(*) AS numeric_percentage, - SUM(CASE WHEN m.value_as_concept_id IS NOT NULL AND m.value_as_concept_id <> 0 THEN 1 ELSE 0 END) / COUNT(*) AS categorical_percentage - FROM measurement AS m - JOIN concept AS c - ON m.measurement_concept_id = c.concept_id - WHERE m.measurement_concept_id <> 0 - GROUP BY m.measurement_concept_id, c.concept_name - ORDER BY COUNT(*) DESC - """ - ) - - # Cache the dataframe for faster computation in the below transformations - popular_labs.cache() - - popular_numeric_labs = ( - popular_labs.withColumn("is_numeric", F.col("numeric_percentage") >= 0.5) - .where("is_numeric") - .withColumn("rn", F.row_number().over(W.orderBy(F.desc("freq")))) - .where(F.col("rn") <= num_of_numeric_labs) - .drop("rn") - ) - - popular_categorical_labs = ( - popular_labs.withColumn("is_categorical", F.col("categorical_percentage") >= 0.5) - .where("is_categorical") - .withColumn("is_numeric", ~F.col("is_categorical")) - .withColumn("rn", F.row_number().over(W.orderBy(F.desc("freq")))) - .where(F.col("rn") <= num_of_categorical_labs) - .drop("is_categorical") - .drop("rn") - ) - - popular_numeric_labs.unionAll(popular_categorical_labs).write.mode("overwrite").parquet( - os.path.join(output_folder, REQUIRED_MEASUREMENT) - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Arguments for generate " "required labs to be included") - parser.add_argument( - "-i", - "--input_folder", - dest="input_folder", - action="store", - help="The path for your input_folder where the raw data is", - required=True, - ) - parser.add_argument( - "-o", - "--output_folder", - dest="output_folder", - action="store", - help="The path for your output_folder", - required=True, - ) - 
parser.add_argument( - "--num_of_numeric_labs", - dest="num_of_numeric_labs", - action="store", - type=int, - default=100, - help="The top most popular numeric labs to be included", - required=False, - ) - parser.add_argument( - "--num_of_categorical_labs", - dest="num_of_categorical_labs", - action="store", - type=int, - default=100, - help="The top most popular categorical labs to be included", - required=False, - ) - - ARGS = parser.parse_args() - - main( - ARGS.input_folder, - ARGS.output_folder, - ARGS.num_of_numeric_labs, - ARGS.num_of_categorical_labs, - ) diff --git a/src/cehrbert/spark_apps/generate_training_data.py b/src/cehrbert/spark_apps/generate_training_data.py deleted file mode 100644 index 5f2d16ba..00000000 --- a/src/cehrbert/spark_apps/generate_training_data.py +++ /dev/null @@ -1,343 +0,0 @@ -import datetime -import os -import shutil - -from pyspark.sql import SparkSession -from pyspark.sql.window import Window - -from ..spark_apps.decorators.patient_event_decorator import AttType -from ..utils.spark_utils import ( - MEASUREMENT, - REQUIRED_MEASUREMENT, - F, - W, - argparse, - create_sequence_data, - create_sequence_data_with_att, - join_domain_tables, - logging, - preprocess_domain_table, - process_measurement, - validate_table_names, -) - -VISIT_OCCURRENCE = "visit_occurrence" -PERSON = "person" -DEATH = "death" - - -def main( - input_folder, - output_folder, - domain_table_list, - date_filter, - include_visit_type, - is_new_patient_representation, - exclude_visit_tokens, - is_classic_bert, - include_prolonged_stay, - include_concept_list: bool, - gpt_patient_sequence: bool, - apply_age_filter: bool, - include_death: bool, - att_type: AttType, - include_sequence_information_content: bool = False, - exclude_demographic: bool = False, - use_age_group: bool = False, - with_drug_rollup: bool = True, - include_inpatient_hour_token: bool = False, - continue_from_events: bool = False, -): - spark = SparkSession.builder.appName("Generate CEHR-BERT Training Data").getOrCreate() - - logger = logging.getLogger(__name__) - logger.info( - f"input_folder: {input_folder}\n" - f"output_folder: {output_folder}\n" - f"domain_table_list: {domain_table_list}\n" - f"date_filter: {date_filter}\n" - f"include_visit_type: {include_visit_type}\n" - f"is_new_patient_representation: {is_new_patient_representation}\n" - f"exclude_visit_tokens: {exclude_visit_tokens}\n" - f"is_classic_bert: {is_classic_bert}\n" - f"include_prolonged_stay: {include_prolonged_stay}\n" - f"include_concept_list: {include_concept_list}\n" - f"gpt_patient_sequence: {gpt_patient_sequence}\n" - f"apply_age_filter: {apply_age_filter}\n" - f"include_death: {include_death}\n" - f"att_type: {att_type}\n" - f"exclude_demographic: {exclude_demographic}\n" - f"use_age_group: {use_age_group}\n" - f"with_drug_rollup: {with_drug_rollup}\n" - ) - - domain_tables = [] - for domain_table_name in domain_table_list: - if domain_table_name != MEASUREMENT: - domain_tables.append( - preprocess_domain_table( - spark, - input_folder, - domain_table_name, - with_drug_rollup=with_drug_rollup, - ) - ) - - visit_occurrence = preprocess_domain_table(spark, input_folder, VISIT_OCCURRENCE) - visit_occurrence = visit_occurrence.select( - "visit_occurrence_id", - "visit_start_date", - "visit_start_datetime", - "visit_end_date", - "visit_concept_id", - "person_id", - "discharged_to_concept_id", - ) - person = preprocess_domain_table(spark, input_folder, PERSON) - birth_datetime_udf = F.coalesce("birth_datetime", F.concat("year_of_birth", 
F.lit("-01-01")).cast("timestamp")) - person = person.select( - "person_id", - birth_datetime_udf.alias("birth_datetime"), - "race_concept_id", - "gender_concept_id", - ) - - visit_occurrence_person = visit_occurrence.join(person, "person_id").withColumn( - "age", - F.ceil(F.months_between(F.col("visit_start_date"), F.col("birth_datetime")) / F.lit(12)), - ) - visit_occurrence_person = visit_occurrence_person.drop("birth_datetime") - - death = preprocess_domain_table(spark, input_folder, DEATH) if include_death else None - - patient_events = join_domain_tables(domain_tables) - - if include_concept_list and patient_events: - column_names = patient_events.schema.fieldNames() - # Filter out concepts - qualified_concepts = preprocess_domain_table(spark, input_folder, "qualified_concept_list").select( - "standard_concept_id" - ) - - patient_events = patient_events.join(qualified_concepts, "standard_concept_id").select(column_names) - - # Process the measurement table if exists - if MEASUREMENT in domain_table_list: - measurement = preprocess_domain_table(spark, input_folder, MEASUREMENT) - required_measurement = preprocess_domain_table(spark, input_folder, REQUIRED_MEASUREMENT) - # The select is necessary to make sure the order of the columns is the same as the - # original dataframe, otherwise the union might use the wrong columns - scaled_measurement = process_measurement(spark, measurement, required_measurement, output_folder) - - if patient_events: - # Union all measurement records together with other domain records - patient_events = patient_events.unionByName(scaled_measurement) - else: - patient_events = scaled_measurement - - patient_events = ( - patient_events.join(visit_occurrence_person, "visit_occurrence_id") - .select( - [patient_events[fieldName] for fieldName in patient_events.schema.fieldNames()] - + ["visit_concept_id", "age"] - ) - .withColumn("cohort_member_id", F.col("person_id")) - ) - - # Apply the age security measure - # We only keep the patient records, whose corresponding age is less than 90 - if apply_age_filter: - patient_events = patient_events.where(F.col("age") < 90) - - if not continue_from_events: - patient_events.write.mode("overwrite").parquet(os.path.join(output_folder, "all_patient_events")) - - patient_events = spark.read.parquet(os.path.join(output_folder, "all_patient_events")) - - if is_new_patient_representation: - sequence_data = create_sequence_data_with_att( - patient_events, - visit_occurrence_person, - date_filter=date_filter, - include_visit_type=include_visit_type, - exclude_visit_tokens=exclude_visit_tokens, - patient_demographic=person if gpt_patient_sequence else None, - death=death, - att_type=att_type, - exclude_demographic=exclude_demographic, - use_age_group=use_age_group, - include_inpatient_hour_token=include_inpatient_hour_token, - ) - else: - sequence_data = create_sequence_data( - patient_events, - date_filter=date_filter, - include_visit_type=include_visit_type, - classic_bert_seq=is_classic_bert, - ) - - if include_prolonged_stay: - udf = F.when( - F.col("visit_concept_id").isin([9201, 262, 9203]), - F.coalesce( - (F.datediff("visit_end_date", "visit_start_date") > 7).cast("int"), - F.lit(0), - ), - ).otherwise(F.lit(0)) - visit_occurrence = preprocess_domain_table(spark, input_folder, VISIT_OCCURRENCE) - visit_occurrence = ( - visit_occurrence.withColumn("prolonged_length_stay", udf) - .select("person_id", "prolonged_length_stay") - .withColumn( - "prolonged_length_stay", - 
F.max("prolonged_length_stay").over(W.partitionBy("person_id")), - ) - .distinct() - ) - sequence_data = sequence_data.join(visit_occurrence, "person_id") - - if include_sequence_information_content: - concept_df = patient_events.select("person_id", F.col("standard_concept_id").alias("concept_id")) - concept_freq = ( - concept_df.groupBy("concept_id") - .count() - .withColumn("prob", F.col("count") / F.sum("count").over(Window.partitionBy())) - .withColumn("ic", -F.log("prob")) - ) - - patient_ic_df = concept_df.join(concept_freq, "concept_id").groupby("person_id").agg(F.mean("ic").alias("ic")) - - sequence_data = sequence_data.join(patient_ic_df, "person_id") - - patient_splits_folder = os.path.join(input_folder, "patient_splits") - if os.path.exists(patient_splits_folder): - patient_splits = spark.read.parquet(patient_splits_folder) - sequence_data.join(patient_splits, "person_id").write.mode("overwrite").parquet( - os.path.join(output_folder, "patient_sequence", "temp") - ) - sequence_data = spark.read.parquet(os.path.join(output_folder, "patient_sequence", "temp")) - sequence_data.where('split="train"').write.mode("overwrite").parquet( - os.path.join(output_folder, "patient_sequence/train") - ) - sequence_data.where('split="test"').write.mode("overwrite").parquet( - os.path.join(output_folder, "patient_sequence/test") - ) - shutil.rmtree(os.path.join(output_folder, "patient_sequence", "temp")) - else: - sequence_data.write.mode("overwrite").parquet(os.path.join(output_folder, "patient_sequence")) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Arguments for generate training data for Bert") - parser.add_argument( - "-i", - "--input_folder", - dest="input_folder", - action="store", - help="The path for your input_folder where the raw data is", - required=True, - ) - parser.add_argument( - "-o", - "--output_folder", - dest="output_folder", - action="store", - help="The path for your output_folder", - required=True, - ) - parser.add_argument( - "-tc", - "--domain_table_list", - dest="domain_table_list", - nargs="+", - action="store", - help="The list of domain tables you want to download", - type=validate_table_names, - required=True, - ) - parser.add_argument( - "-d", - "--date_filter", - dest="date_filter", - type=lambda s: datetime.datetime.strptime(s, "%Y-%m-%d"), - action="store", - required=False, - default="2018-01-01", - ) - parser.add_argument( - "-iv", - "--include_visit_type", - dest="include_visit_type", - action="store_true", - help="Specify whether to include visit types for generating the training data", - ) - parser.add_argument( - "-ip", - "--is_new_patient_representation", - dest="is_new_patient_representation", - action="store_true", - help="Specify whether to generate the sequence of EHR records using the new patient " "representation", - ) - parser.add_argument( - "-ib", - "--is_classic_bert_sequence", - dest="is_classic_bert_sequence", - action="store_true", - help="Specify whether to generate the sequence of EHR records using the classic BERT " "sequence", - ) - parser.add_argument( - "-ev", - "--exclude_visit_tokens", - dest="exclude_visit_tokens", - action="store_true", - help="Specify whether or not to exclude the VS and VE tokens", - ) - parser.add_argument( - "--include_prolonged_length_stay", - dest="include_prolonged_stay", - action="store_true", - help="Specify whether or not to include the data for the second learning objective for " "Med-BERT", - ) - parser.add_argument("--include_concept_list", dest="include_concept_list", 
action="store_true") - parser.add_argument("--gpt_patient_sequence", dest="gpt_patient_sequence", action="store_true") - parser.add_argument("--apply_age_filter", dest="apply_age_filter", action="store_true") - parser.add_argument("--include_death", dest="include_death", action="store_true") - parser.add_argument("--exclude_demographic", dest="exclude_demographic", action="store_true") - parser.add_argument("--use_age_group", dest="use_age_group", action="store_true") - parser.add_argument("--with_drug_rollup", dest="with_drug_rollup", action="store_true") - parser.add_argument( - "--include_inpatient_hour_token", - dest="include_inpatient_hour_token", - action="store_true", - ) - parser.add_argument("--continue_from_events", dest="continue_from_events", action="store_true") - parser.add_argument( - "--att_type", - dest="att_type", - action="store", - choices=[e.value for e in AttType], - ) - - ARGS = parser.parse_args() - - main( - ARGS.input_folder, - ARGS.output_folder, - ARGS.domain_table_list, - ARGS.date_filter, - ARGS.include_visit_type, - ARGS.is_new_patient_representation, - ARGS.exclude_visit_tokens, - ARGS.is_classic_bert_sequence, - ARGS.include_prolonged_stay, - ARGS.include_concept_list, - ARGS.gpt_patient_sequence, - ARGS.apply_age_filter, - ARGS.include_death, - AttType(ARGS.att_type), - exclude_demographic=ARGS.exclude_demographic, - use_age_group=ARGS.use_age_group, - with_drug_rollup=ARGS.with_drug_rollup, - include_inpatient_hour_token=ARGS.include_inpatient_hour_token, - continue_from_events=ARGS.continue_from_events, - ) diff --git a/src/cehrbert/spark_apps/legacy/__init__.py b/src/cehrbert/spark_apps/legacy/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/cehrbert/spark_apps/legacy/mortality.py b/src/cehrbert/spark_apps/legacy/mortality.py deleted file mode 100644 index 19c0db39..00000000 --- a/src/cehrbert/spark_apps/legacy/mortality.py +++ /dev/null @@ -1,183 +0,0 @@ -import pyspark.sql.functions as f -from pyspark.sql import DataFrame - -from ..cohorts.spark_app_base import LastVisitCohortBuilderBase -from ..spark_parse_args import create_spark_args - -QUALIFIED_DEATH_DATE_QUERY = """ -WITH max_death_date_cte AS -( - SELECT - person_id, - MAX(death_date) AS death_date - FROM global_temp.death - GROUP BY person_id -) - -SELECT - dv.person_id, - dv.death_date -FROM -( - SELECT DISTINCT - d.person_id, - d.death_date, - FIRST(DATE(v.visit_start_date)) OVER(PARTITION BY v.person_id ORDER BY DATE(v.visit_start_date) DESC) AS last_visit_start_date - FROM max_death_date_cte AS d - JOIN global_temp.visit_occurrence AS v - ON d.person_id = v.person_id -) dv -WHERE dv.last_visit_start_date <= dv.death_date -""" - -COHORT_QUERY_TEMPLATE = """ -WITH last_visit_cte AS ( - SELECT - v.*, - COUNT(CASE WHEN DATE(v.visit_start_date) >= DATE_SUB(index_date, {observation_period}) - AND DATE(v.visit_start_date) < index_date - THEN 1 ELSE NULL END) OVER (PARTITION BY v.person_id) AS num_of_visits - FROM - ( - SELECT DISTINCT - v.person_id, - v.visit_start_date, - FIRST(v.visit_occurrence_id) OVER(PARTITION BY v.person_id - ORDER BY DATE(v.visit_start_date) DESC) AS visit_occurrence_id, - FIRST(DATE(v.visit_start_date)) OVER(PARTITION BY v.person_id - ORDER BY DATE(v.visit_start_date) DESC) AS index_date, - FIRST(v.discharge_to_concept_id) OVER(PARTITION BY v.person_id - ORDER BY DATE(v.visit_start_date) DESC) AS discharge_to_concept_id, - FIRST(DATE(v.visit_start_date)) OVER(PARTITION BY v.person_id - ORDER BY DATE(v.visit_start_date)) AS 
earliest_visit_start_date
-        FROM global_temp.visit_occurrence AS v
-        -- Need to make sure that there is enough observation for the observation window.
-        -- 1) the earliest visit_start_date needs to occur before the observation period.
-        -- 2) there needs to be at least 2 visit_occurrences for every 360 days (1 year)
-    ) v
-)
-
-SELECT DISTINCT
-    v.person_id,
-    v.visit_occurrence_id,
-    v.index_date,
-    YEAR(v.index_date) - p.year_of_birth AS age,
-    p.gender_concept_id,
-    p.race_concept_id,
-    CAST(ISNOTNULL(d.person_id) AS INT) AS label
-FROM last_visit_cte AS v
-JOIN global_temp.person AS p
-    ON v.person_id = p.person_id
-LEFT JOIN global_temp.death AS d
-    ON v.person_id = d.person_id
-WHERE v.index_date BETWEEN '{date_lower_bound}' AND '{date_upper_bound}'
-    AND v.discharge_to_concept_id = 8536 --discharge to home
-    AND YEAR(v.earliest_visit_start_date) <= YEAR(DATE_SUB(index_date, {observation_period}))
-    --AND v.num_of_visits >= {num_of_visits}
-"""
-
-DOMAIN_TABLE_LIST = ["condition_occurrence", "drug_exposure", "procedure_occurrence"]
-
-COHORT_TABLE = "cohort"
-DEATH = "death"
-PERSON = "person"
-VISIT_OCCURRENCE = "visit_occurrence"
-DEPENDENCY_LIST = [DEATH, PERSON, VISIT_OCCURRENCE]
-
-
-class MortalityCohortBuilder(LastVisitCohortBuilderBase):
-
-    def preprocess_dependencies(self):
-        self.spark.sql(QUALIFIED_DEATH_DATE_QUERY).createOrReplaceGlobalTempView(DEATH)
-
-        num_of_visits = (self._observation_window // 360) + 1
-
-        cohort_query = COHORT_QUERY_TEMPLATE.format(
-            date_lower_bound=self._date_lower_bound,
-            date_upper_bound=self._date_upper_bound,
-            observation_period=self._observation_window,
-            num_of_visits=num_of_visits,
-        )
-
-        cohort = self.spark.sql(cohort_query)
-        cohort.createOrReplaceGlobalTempView(COHORT_TABLE)
-
-        self._dependency_dict[COHORT_TABLE] = cohort
-
-    def create_incident_cases(self):
-        cohort = self._dependency_dict[COHORT_TABLE]
-        return cohort.where(f.col("label") == 1)
-
-    def create_control_cases(self):
-        cohort = self._dependency_dict[COHORT_TABLE]
-        return cohort.where(f.col("label") == 0)
-
-    def create_matching_control_cases(self, incident_cases: DataFrame, control_cases: DataFrame):
-        """
-        Do not match for controls; simply return what's in the control cases.
- - :param incident_cases: - :param control_cases: - :return: - """ - return control_cases - - -def main( - cohort_name, - input_folder, - output_folder, - date_lower_bound, - date_upper_bound, - age_lower_bound, - age_upper_bound, - observation_window, - prediction_window, - hold_off_window, - index_date_match_window, - include_visit_type, - is_feature_concept_frequency, - is_roll_up_concept, -): - cohort_builder = MortalityCohortBuilder( - cohort_name, - input_folder, - output_folder, - date_lower_bound, - date_upper_bound, - age_lower_bound, - age_upper_bound, - observation_window, - prediction_window, - hold_off_window, - index_date_match_window, - DOMAIN_TABLE_LIST, - DEPENDENCY_LIST, - True, - include_visit_type, - is_feature_concept_frequency, - is_roll_up_concept, - ) - - cohort_builder.build() - - -if __name__ == "__main__": - spark_args = create_spark_args() - - main( - spark_args.cohort_name, - spark_args.input_folder, - spark_args.output_folder, - spark_args.date_lower_bound, - spark_args.date_upper_bound, - spark_args.lower_bound, - spark_args.upper_bound, - spark_args.observation_window, - spark_args.prediction_window, - spark_args.hold_off_window, - spark_args.index_date_match_window, - spark_args.include_visit_type, - spark_args.is_feature_concept_frequency, - spark_args.is_roll_up_concept, - ) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/__init__.py b/src/cehrbert/spark_apps/prediction_cohorts/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/cehrbert/spark_apps/prediction_cohorts/afib_ischemic_stroke.py b/src/cehrbert/spark_apps/prediction_cohorts/afib_ischemic_stroke.py deleted file mode 100644 index 2edd5a10..00000000 --- a/src/cehrbert/spark_apps/prediction_cohorts/afib_ischemic_stroke.py +++ /dev/null @@ -1,13 +0,0 @@ -from ..cohorts import atrial_fibrillation, ischemic_stroke -from ..cohorts.spark_app_base import create_prediction_cohort -from ..spark_parse_args import create_spark_args - -DOMAIN_TABLE_LIST = ["condition_occurrence", "drug_exposure", "procedure_occurrence"] - -if __name__ == "__main__": - create_prediction_cohort( - create_spark_args(), - atrial_fibrillation.query_builder(), - ischemic_stroke.query_builder(), - DOMAIN_TABLE_LIST, - ) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/cad_cabg_cohort.py b/src/cehrbert/spark_apps/prediction_cohorts/cad_cabg_cohort.py deleted file mode 100644 index 93296cd1..00000000 --- a/src/cehrbert/spark_apps/prediction_cohorts/cad_cabg_cohort.py +++ /dev/null @@ -1,18 +0,0 @@ -from ..cohorts import cabg -from ..cohorts import coronary_artery_disease as cad -from ..cohorts.spark_app_base import create_prediction_cohort -from ..spark_parse_args import create_spark_args - -DOMAIN_TABLE_LIST = ["condition_occurrence", "drug_exposure", "procedure_occurrence"] - -if __name__ == "__main__": - spark_args = create_spark_args() - - ehr_table_list = spark_args.ehr_table_list if spark_args.ehr_table_list else DOMAIN_TABLE_LIST - - create_prediction_cohort( - spark_args, - cad.query_builder(spark_args), - cabg.query_builder(spark_args), - ehr_table_list, - ) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/cad_hf_cohort.py b/src/cehrbert/spark_apps/prediction_cohorts/cad_hf_cohort.py deleted file mode 100644 index 767babbc..00000000 --- a/src/cehrbert/spark_apps/prediction_cohorts/cad_hf_cohort.py +++ /dev/null @@ -1,18 +0,0 @@ -from ..cohorts import coronary_artery_disease as cad -from ..cohorts import heart_failure as hf -from ..cohorts.spark_app_base import 
create_prediction_cohort -from ..spark_parse_args import create_spark_args - -DOMAIN_TABLE_LIST = [ - "condition_occurrence", - "drug_exposure", - "procedure_occurrence", - "measurement", -] - -if __name__ == "__main__": - spark_args = create_spark_args() - - ehr_table_list = spark_args.ehr_table_list if spark_args.ehr_table_list else DOMAIN_TABLE_LIST - - create_prediction_cohort(spark_args, cad.query_builder(spark_args), hf.query_builder(), ehr_table_list) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/copd_readmission.py b/src/cehrbert/spark_apps/prediction_cohorts/copd_readmission.py deleted file mode 100644 index 327ca843..00000000 --- a/src/cehrbert/spark_apps/prediction_cohorts/copd_readmission.py +++ /dev/null @@ -1,69 +0,0 @@ -from ..cohorts.query_builder import AncestorTableSpec, QueryBuilder, QuerySpec -from ..cohorts.spark_app_base import create_prediction_cohort -from ..spark_parse_args import create_spark_args - -COPD_HOSPITALIZATION_QUERY = """ -WITH copd_conditions AS ( - SELECT DISTINCT - descendant_concept_id AS concept_id - FROM global_temp.concept_ancestor AS ca - WHERE ca.ancestor_concept_id in (255573, 258780) -) - -SELECT DISTINCT - v.person_id, - v.visit_occurrence_id, - v.visit_end_date AS index_date -FROM global_temp.visit_occurrence AS v -JOIN global_temp.condition_occurrence AS co - ON v.visit_occurrence_id = co.visit_occurrence_id -JOIN copd_conditions AS copd - ON co.condition_concept_id = copd.concept_id -WHERE v.visit_concept_id IN (9201, 262) --inpatient, er-inpatient - AND v.discharged_to_concept_id = 8536 --discharge to home - AND v.visit_start_date <= co.condition_start_date -""" - -HOSPITALIZATION_QUERY = """ -SELECT DISTINCT - v.person_id, - v.visit_occurrence_id, - v.visit_start_date AS index_date -FROM global_temp.visit_occurrence AS v -WHERE v.visit_concept_id IN (9201, 262) --inpatient, er-inpatient -""" - -COPD_HOSPITALIZATION_COHORT = "copd_readmission" -HOSPITALIZATION_COHORT = "hospitalization" -DEPENDENCY_LIST = ["person", "condition_occurrence", "visit_occurrence"] -DOMAIN_TABLE_LIST = ["condition_occurrence"] - - -def main(spark_args): - copd_inpatient_query = QuerySpec( - table_name=COPD_HOSPITALIZATION_COHORT, - query_template=COPD_HOSPITALIZATION_QUERY, - parameters={}, - ) - copd_inpatient = QueryBuilder( - cohort_name=COPD_HOSPITALIZATION_COHORT, - dependency_list=DEPENDENCY_LIST, - query=copd_inpatient_query, - ) - - hospitalization_query = QuerySpec( - table_name=HOSPITALIZATION_COHORT, - query_template=HOSPITALIZATION_QUERY, - parameters={}, - ) - hospitalization = QueryBuilder( - cohort_name=HOSPITALIZATION_COHORT, - dependency_list=DEPENDENCY_LIST, - query=hospitalization_query, - ) - - create_prediction_cohort(spark_args, copd_inpatient, hospitalization, DOMAIN_TABLE_LIST) - - -if __name__ == "__main__": - main(create_spark_args()) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/covid_death.py b/src/cehrbert/spark_apps/prediction_cohorts/covid_death.py deleted file mode 100644 index 74f6c549..00000000 --- a/src/cehrbert/spark_apps/prediction_cohorts/covid_death.py +++ /dev/null @@ -1,13 +0,0 @@ -from ..cohorts import covid_inpatient, death -from ..cohorts.spark_app_base import create_prediction_cohort -from ..spark_parse_args import create_spark_args - -DOMAIN_TABLE_LIST = ["condition_occurrence", "drug_exposure", "procedure_occurrence"] - -if __name__ == "__main__": - create_prediction_cohort( - create_spark_args(), - covid_inpatient.query_builder(), - death.query_builder(), - DOMAIN_TABLE_LIST, - ) diff --git 
a/src/cehrbert/spark_apps/prediction_cohorts/covid_ventilation.py b/src/cehrbert/spark_apps/prediction_cohorts/covid_ventilation.py deleted file mode 100644 index 1063fcb1..00000000 --- a/src/cehrbert/spark_apps/prediction_cohorts/covid_ventilation.py +++ /dev/null @@ -1,13 +0,0 @@ -from ..cohorts import covid, ventilation -from ..cohorts.spark_app_base import create_prediction_cohort -from ..spark_parse_args import create_spark_args - -DOMAIN_TABLE_LIST = ["condition_occurrence", "drug_exposure", "procedure_occurrence"] - -if __name__ == "__main__": - create_prediction_cohort( - create_spark_args(), - covid.query_builder(), - ventilation.query_builder(), - DOMAIN_TABLE_LIST, - ) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/discharge_home_death.py b/src/cehrbert/spark_apps/prediction_cohorts/discharge_home_death.py deleted file mode 100644 index 58d80401..00000000 --- a/src/cehrbert/spark_apps/prediction_cohorts/discharge_home_death.py +++ /dev/null @@ -1,22 +0,0 @@ -from ..cohorts import death -from ..cohorts import last_visit_discharged_home as last -from ..cohorts.spark_app_base import create_prediction_cohort -from ..spark_parse_args import create_spark_args - -DOMAIN_TABLE_LIST = [ - "condition_occurrence", - "drug_exposure", - "procedure_occurrence", - "measurement", -] - -if __name__ == "__main__": - spark_args = create_spark_args() - ehr_table_list = spark_args.ehr_table_list if spark_args.ehr_table_list else DOMAIN_TABLE_LIST - - create_prediction_cohort( - spark_args, - last.query_builder(spark_args), - death.query_builder(), - ehr_table_list, - ) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/hf_readmission.py b/src/cehrbert/spark_apps/prediction_cohorts/hf_readmission.py deleted file mode 100644 index 42b2baed..00000000 --- a/src/cehrbert/spark_apps/prediction_cohorts/hf_readmission.py +++ /dev/null @@ -1,79 +0,0 @@ -from ..cohorts.query_builder import QueryBuilder, QuerySpec -from ..cohorts.spark_app_base import create_prediction_cohort -from ..spark_parse_args import create_spark_args - -HEART_FAILURE_HOSPITALIZATION_QUERY = """ -WITH hf_concepts AS ( - SELECT DISTINCT - descendant_concept_id AS concept_id - FROM global_temp.concept_ancestor AS ca - WHERE ca.ancestor_concept_id = 316139 -) - -SELECT DISTINCT - v.person_id, - v.visit_occurrence_id, - v.visit_end_date AS index_date -FROM global_temp.visit_occurrence AS v -JOIN global_temp.condition_occurrence AS co - ON v.visit_occurrence_id = co.visit_occurrence_id -JOIN hf_concepts AS hf - ON co.condition_concept_id = hf.concept_id -WHERE v.visit_concept_id IN (9201, 262, 8971, 8920) --inpatient, er-inpatient - AND v.discharged_to_concept_id NOT IN (4216643, 44814650, 8717, 8970, 8971) -- TBD - --AND v.discharge_to_concept_id IN (8536, 8863, 4161979) -- Home, Skilled Nursing Facility, and Patient discharged alive - AND v.visit_start_date <= co.condition_start_date - AND v.visit_end_date >= '{date_lower_bound}' -""" - -HOSPITALIZATION_QUERY = """ -SELECT DISTINCT - v.person_id, - v.visit_occurrence_id, - v.visit_start_date AS index_date -FROM global_temp.visit_occurrence AS v -WHERE v.visit_concept_id IN (9201, 262, 8971, 8920) --inpatient, er-inpatient -""" - -HF_HOSPITALIZATION_COHORT = "hf_hospitalization" -HOSPITALIZATION_COHORT = "hospitalization" -DEPENDENCY_LIST = ["person", "condition_occurrence", "visit_occurrence"] -DOMAIN_TABLE_LIST = [ - "condition_occurrence", - "drug_exposure", - "procedure_occurrence", - "measurement", -] - - -def main(spark_args): - hf_inpatient_target_query = QuerySpec( - 
table_name=HF_HOSPITALIZATION_COHORT, - query_template=HEART_FAILURE_HOSPITALIZATION_QUERY, - parameters={"date_lower_bound": spark_args.date_lower_bound}, - ) - - hf_inpatient_target_querybuilder = QueryBuilder( - cohort_name=HF_HOSPITALIZATION_COHORT, - dependency_list=DEPENDENCY_LIST, - query=hf_inpatient_target_query, - ) - - hospitalization_query = QuerySpec( - table_name=HOSPITALIZATION_COHORT, - query_template=HOSPITALIZATION_QUERY, - parameters={}, - ) - hospitalization = QueryBuilder( - cohort_name=HOSPITALIZATION_COHORT, - dependency_list=DEPENDENCY_LIST, - query=hospitalization_query, - ) - - ehr_table_list = spark_args.ehr_table_list if spark_args.ehr_table_list else DOMAIN_TABLE_LIST - - create_prediction_cohort(spark_args, hf_inpatient_target_querybuilder, hospitalization, ehr_table_list) - - -if __name__ == "__main__": - main(create_spark_args()) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/hospitalization.py b/src/cehrbert/spark_apps/prediction_cohorts/hospitalization.py deleted file mode 100644 index 918f2322..00000000 --- a/src/cehrbert/spark_apps/prediction_cohorts/hospitalization.py +++ /dev/null @@ -1,98 +0,0 @@ -from ..cohorts.query_builder import QueryBuilder, QuerySpec -from ..cohorts.spark_app_base import create_prediction_cohort -from ..spark_parse_args import create_spark_args - -HOSPITALIZATION_OUTCOME_QUERY = """ -SELECT DISTINCT - v.person_id, - visit_start_date AS index_date, - visit_occurrence_id -FROM global_temp.visit_occurrence AS v -WHERE v.visit_concept_id IN (9201, 262) -""" - -HOSPITALIZATION_TARGET_QUERY = """ -WITH INDEX_VISIT_TABLE AS -( - SELECT DISTINCT - person_id, - FIRST(visit_start_date) OVER (PARTITION BY person_id ORDER BY visit_start_date, visit_occurrence_id) AS index_date, - FIRST(visit_occurrence_id) OVER (PARTITION BY person_id ORDER BY visit_start_date, visit_occurrence_id) AS visit_occurrence_id - FROM global_temp.visit_occurrence - WHERE visit_end_date >= visit_start_date -), -HOSPITAL_TARGET AS -( - SELECT DISTINCT - iv.person_id, - iv.index_date, - count(distinct case when v1.visit_concept_id IN (9201, 262) then v1.visit_occurrence_id end) as num_of_hospitalizations, - count(distinct v1.visit_occurrence_id) as num_of_visits - FROM INDEX_VISIT_TABLE iv - JOIN global_temp.visit_occurrence v1 - ON v1.person_id = iv.person_id AND DATEDIFF(v1.visit_start_date, iv.index_date) <= {total_window} - JOIN global_temp.observation_period op - ON iv.person_id = op.person_id - AND DATEDIFF(CAST(op.observation_period_end_date AS date), CAST(op.observation_period_start_date AS date)) >= {total_window} - GROUP BY iv.person_id, iv.index_date -) - -SELECT - person_id, - index_date, - CAST(null AS INT) AS visit_occurrence_id -FROM HOSPITAL_TARGET -WHERE num_of_visits between 2 and 30 - AND index_date >= '{date_lower_bound}' -""" - -HOSPITALIZATION_TARGET_COHORT = "hospitalization_target" -HOSPITALIZATION_OUTCOME_COHORT = "hospitalization_outcome" -DEPENDENCY_LIST = ["person", "condition_occurrence", "visit_occurrence"] -DOMAIN_TABLE_LIST = [ - "condition_occurrence", - "drug_exposure", - "procedure_occurrence", - "measurement", -] - - -def main(spark_args): - total_window = spark_args.observation_window + spark_args.hold_off_window - hospitalization_target_query = QuerySpec( - table_name=HOSPITALIZATION_TARGET_COHORT, - query_template=HOSPITALIZATION_TARGET_QUERY, - parameters={ - "total_window": total_window, - "date_lower_bound": spark_args.date_lower_bound, - }, - ) - hospitalization_querybuilder = QueryBuilder( - 
cohort_name=HOSPITALIZATION_TARGET_COHORT, - dependency_list=DEPENDENCY_LIST, - query=hospitalization_target_query, - ) - - hospitalization_outcome_query = QuerySpec( - table_name=HOSPITALIZATION_OUTCOME_COHORT, - query_template=HOSPITALIZATION_OUTCOME_QUERY, - parameters={}, - ) - hospitalization_outcome_querybuilder = QueryBuilder( - cohort_name=HOSPITALIZATION_OUTCOME_COHORT, - dependency_list=DEPENDENCY_LIST, - query=hospitalization_outcome_query, - ) - - ehr_table_list = spark_args.ehr_table_list if spark_args.ehr_table_list else DOMAIN_TABLE_LIST - - create_prediction_cohort( - spark_args, - hospitalization_querybuilder, - hospitalization_outcome_querybuilder, - ehr_table_list, - ) - - -if __name__ == "__main__": - main(create_spark_args()) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/hospitalization_mortality.py b/src/cehrbert/spark_apps/prediction_cohorts/hospitalization_mortality.py deleted file mode 100644 index 6e2452b8..00000000 --- a/src/cehrbert/spark_apps/prediction_cohorts/hospitalization_mortality.py +++ /dev/null @@ -1,81 +0,0 @@ -from ..cohorts.query_builder import QueryBuilder, QuerySpec -from ..cohorts.spark_app_base import create_prediction_cohort -from ..spark_parse_args import create_spark_args - -DEPENDENCY_LIST = ["visit_occurrence"] -DOMAIN_TABLE_LIST = [ - "condition_occurrence", - "drug_exposure", - "procedure_occurrence", - "measurement", -] - -HOSPITALIZATION_QUERY = """ -SELECT DISTINCT - v.person_id, - v.visit_occurrence_id, - v.index_date, - v.expired -FROM -( - SELECT - v.person_id, - v.visit_occurrence_id, - v.visit_end_date AS index_date, - CASE - WHEN v.discharged_to_concept_id == 4216643 THEN 1 - ELSE 0 - END AS expired, - ROW_NUMBER() OVER(PARTITION BY v.person_id ORDER BY DATE(v.visit_end_date) DESC) AS rn - FROM global_temp.visit_occurrence AS v - WHERE v.visit_concept_id IN (9201, 262) --inpatient, er-inpatient - AND v.visit_end_date IS NOT NULL -) AS v - WHERE v.rn = 1 AND v.index_date >= '{date_lower_bound}' -""" - -MORTALITY_QUERY = """ -SELECT DISTINCT - v.person_id, - v.visit_occurrence_id, - v.index_date AS index_date -FROM global_temp.{target_table_name} AS v -WHERE expired = 1 -""" - -HOSPITALIZATION_TARGET_COHORT = "hospitalization_target" -MORTALITY_COHORT = "hospitalization_mortality" - -if __name__ == "__main__": - spark_args = create_spark_args() - ehr_table_list = spark_args.ehr_table_list if spark_args.ehr_table_list else DOMAIN_TABLE_LIST - - hospitalization_target_query = QuerySpec( - table_name=HOSPITALIZATION_TARGET_COHORT, - query_template=HOSPITALIZATION_QUERY, - parameters={"date_lower_bound": spark_args.date_lower_bound}, - ) - - hospitalization_querybuilder = QueryBuilder( - cohort_name=HOSPITALIZATION_TARGET_COHORT, - dependency_list=DEPENDENCY_LIST, - query=hospitalization_target_query, - ) - - hospitalization_mortality_query = QuerySpec( - table_name=MORTALITY_COHORT, - query_template=MORTALITY_QUERY, - parameters={"target_table_name": HOSPITALIZATION_TARGET_COHORT}, - ) - hospitalization_mortality_querybuilder = QueryBuilder( - cohort_name=MORTALITY_COHORT, - dependency_list=DEPENDENCY_LIST, - query=hospitalization_mortality_query, - ) - - create_prediction_cohort( - spark_args, - hospitalization_querybuilder, - hospitalization_mortality_querybuilder, - ehr_table_list, - ) diff --git a/src/cehrbert/spark_apps/prediction_cohorts/t2dm_hf_cohort.py b/src/cehrbert/spark_apps/prediction_cohorts/t2dm_hf_cohort.py deleted file mode 100644 index c737cfdb..00000000 --- 
a/src/cehrbert/spark_apps/prediction_cohorts/t2dm_hf_cohort.py +++ /dev/null @@ -1,18 +0,0 @@ -from ..cohorts import heart_failure as hf -from ..cohorts import type_two_diabietes as t2dm -from ..cohorts.spark_app_base import create_prediction_cohort -from ..spark_parse_args import create_spark_args - -DOMAIN_TABLE_LIST = [ - "condition_occurrence", - "drug_exposure", - "procedure_occurrence", - "measurement", -] - -if __name__ == "__main__": - spark_args = create_spark_args() - - ehr_table_list = spark_args.ehr_table_list if spark_args.ehr_table_list else DOMAIN_TABLE_LIST - - create_prediction_cohort(spark_args, t2dm.query_builder(spark_args), hf.query_builder(), ehr_table_list) diff --git a/src/cehrbert/spark_apps/spark_parse_args.py b/src/cehrbert/spark_apps/spark_parse_args.py deleted file mode 100644 index 518b67ad..00000000 --- a/src/cehrbert/spark_apps/spark_parse_args.py +++ /dev/null @@ -1,350 +0,0 @@ -""" -This module defines functions for parsing command-line arguments for Spark applications. - -that generate cohort definitions. It includes argument parsing for cohort specifications, -date ranges, patient information, and EHR data extraction settings. - -Functions: - valid_date: Validates and converts a date string into a datetime object. - create_spark_args: Defines and parses command-line arguments for cohort generation and EHR - processing. -""" - -import argparse -import datetime - -from .decorators.patient_event_decorator import AttType - - -def valid_date(s): - """ - Validates and converts a date string into a datetime object. - - Args: - s (str): The date string in the format 'YYYY-MM-DD'. - Returns: - datetime.datetime: The parsed date. - Raises: - argparse.ArgumentTypeError: If the date string is not valid. - """ - try: - return datetime.datetime.strptime(s, "%Y-%m-%d") - except ValueError as e: - raise argparse.ArgumentTypeError(e) - - -def create_spark_args(): - """ - Defines and parses the command-line arguments for Spark applications. - - that generate cohort definitions based on EHR data. - - Returns: - argparse.Namespace: The parsed arguments as a namespace object containing the user - inputs. - - Command-line Arguments: - -c, --cohort_name: The name of the cohort being generated. - -i, --input_folder: The folder path containing the input data. - --patient_splits_folder: The folder containing patient splits data. - -o, --output_folder: The folder path to store the output data. - --ehr_table_list: List of EHR domain tables for feature extraction. - -dl, --date_lower_bound: The lower bound for date filtering. - -du, --date_upper_bound: The upper bound for date filtering. - -l, --age_lower_bound: The minimum age filter for cohort inclusion. - -u, --age_upper_bound: The maximum age filter for cohort inclusion. - -ow, --observation_window: The observation window duration in days. - -pw, --prediction_window: The prediction window duration in days. - -ps, --prediction_start_days: The start point of the prediction window in days. - -hw, --hold_off_window: The hold-off window for excluding certain features. - --num_of_visits: The minimum number of visits required for cohort inclusion. - --num_of_concepts: The minimum number of concepts required for cohort inclusion. - -iw, --is_window_post_index: Whether the observation window is post-index. - -iv, --include_visit_type: Whether to include visit types in feature generation. - -ev, --exclude_visit_tokens: Whether to exclude certain visit tokens (VS and VE). 
- -f, --is_feature_concept_frequency: Whether the features are based on concept counts. - -ir, --is_roll_up_concept: Whether to roll up concepts to their ancestors. - -ip, --is_new_patient_representation: Whether to use a new patient representation. - --gpt_patient_sequence: Whether to generate GPT sequences for EHR records. - -ih, --is_hierarchical_bert: Whether to use a hierarchical patient representation for BERT. - -cbs, --classic_bert_seq: Whether to use classic BERT sequence representation with SEP. - --is_first_time_outcome: Whether the outcome is the first-time occurrence. - --is_remove_index_prediction_starts: Whether to remove outcomes between index and prediction - start. - --is_prediction_window_unbounded: Whether the prediction window end is unbounded. - --is_observation_window_unbounded: Whether the observation window is unbounded. - --include_concept_list: Whether to apply filters for low-frequency concepts. - --allow_measurement_only: Whether patients with only measurements are allowed. - --is_population_estimation: Whether cohort is constructed for population-level estimation. - --att_type: The attribute type used for cohort definitions. - --exclude_demographic: Whether to exclude demographic prompts in patient sequences. - --use_age_group: Whether to represent age using age groups in patient sequences. - --single_contribution: Whether patients should contribute only once to the training data. - """ - parser = argparse.ArgumentParser(description="Arguments for spark applications for generating cohort definitions") - parser.add_argument( - "-c", - "--cohort_name", - dest="cohort_name", - action="store", - help="The cohort name", - required=True, - ) - parser.add_argument( - "-i", - "--input_folder", - dest="input_folder", - action="store", - help="The path for your input_folder where the sequence data is", - required=True, - ) - parser.add_argument( - "--patient_splits_folder", - dest="patient_splits_folder", - action="store", - help="The folder that contains the patient_splits data", - required=False, - ) - parser.add_argument( - "-o", - "--output_folder", - dest="output_folder", - action="store", - help="The path for your output_folder", - required=True, - ) - parser.add_argument( - "--ehr_table_list", - dest="ehr_table_list", - nargs="+", - action="store", - help="The list of domain tables you want to include for feature extraction", - required=False, - ) - parser.add_argument( - "-dl", - "--date_lower_bound", - dest="date_lower_bound", - action="store", - help="The date filter lower bound for filtering training data", - required=True, - type=valid_date, - ) - parser.add_argument( - "-du", - "--date_upper_bound", - dest="date_upper_bound", - action="store", - help="The date filter upper bound for filtering training data", - required=True, - type=valid_date, - ) - parser.add_argument( - "-l", - "--age_lower_bound", - dest="age_lower_bound", - action="store", - help="The age lower bound", - required=False, - type=int, - default=0, - ) - parser.add_argument( - "-u", - "--age_upper_bound", - dest="age_upper_bound", - action="store", - help="The age upper bound", - required=False, - type=int, - default=100, - ) - parser.add_argument( - "-ow", - "--observation_window", - dest="observation_window", - action="store", - help="The observation window in days for extracting features", - required=False, - type=int, - default=365, - ) - parser.add_argument( - "-pw", - "--prediction_window", - dest="prediction_window", - action="store", - help="The prediction window in which the 
prediction is made", - required=False, - type=int, - default=180, - ) - parser.add_argument( - "-ps", - "--prediction_start_days", - dest="prediction_start_days", - action="store", - help="The prediction start days in which the prediction is made", - required=False, - type=int, - default=1, - ) - parser.add_argument( - "-hw", - "--hold_off_window", - dest="hold_off_window", - action="store", - help="The hold off window for excluding the features", - required=False, - type=int, - default=0, - ) - parser.add_argument( - "--num_of_visits", - dest="num_of_visits", - action="store", - help="The number of visits to qualify for the inclusion of the cohorts", - required=False, - type=int, - default=0, - ) - parser.add_argument( - "--num_of_concepts", - dest="num_of_concepts", - action="store", - help="The number of concepts to qualify for the inclusion of the cohorts", - required=False, - type=int, - default=0, - ) - parser.add_argument( - "-iw", - "--is_window_post_index", - dest="is_window_post_index", - action="store_true", - help="Indicate if the observation window is pre/post the index date", - ) - parser.add_argument( - "-iv", - "--include_visit_type", - dest="include_visit_type", - action="store_true", - help="Specify whether to include visit types for " "generating the training data", - ) - parser.add_argument( - "-ev", - "--exclude_visit_tokens", - dest="exclude_visit_tokens", - action="store_true", - help="Specify whether or not to exclude the VS and VE tokens", - ) - parser.add_argument( - "-f", - "--is_feature_concept_frequency", - dest="is_feature_concept_frequency", - action="store_true", - help="Specify whether the features are concept counts or not", - ) - parser.add_argument( - "-ir", - "--is_roll_up_concept", - dest="is_roll_up_concept", - action="store_true", - help="Specify whether to roll up the concepts to their ancestors", - ) - parser.add_argument( - "-ip", - "--is_new_patient_representation", - dest="is_new_patient_representation", - action="store_true", - help="Specify whether to generate the sequence of " "EHR records using the new patient representation", - ) - parser.add_argument( - "--gpt_patient_sequence", - dest="gpt_patient_sequence", - action="store_true", - help="Specify whether to generate the GPT sequence of " "EHR records using the new patient representation", - ) - parser.add_argument( - "-ih", - "--is_hierarchical_bert", - dest="is_hierarchical_bert", - action="store_true", - help="Specify whether to generate the sequence of " "EHR records using the hierarchical patient representation", - ) - parser.add_argument( - "-cbs", - "--classic_bert_seq", - dest="classic_bert_seq", - action="store_true", - help="Specify whether to generate the sequence of " - "EHR records using the classic BERT sequence representation where " - "visits are separated by a SEP token", - ) - parser.add_argument( - "--is_first_time_outcome", - dest="is_first_time_outcome", - action="store_true", - help="is the outcome the first time occurrence?", - ) - parser.add_argument( - "--is_remove_index_prediction_starts", - dest="is_remove_index_prediction_starts", - action="store_true", - help="is outcome between index_date and prediction start window removed?", - ) - parser.add_argument( - "--is_prediction_window_unbounded", - dest="is_prediction_window_unbounded", - action="store_true", - help="is the end of the prediction window unbounded?", - ) - parser.add_argument( - "--is_observation_window_unbounded", - dest="is_observation_window_unbounded", - action="store_true", - help="is the 
observation window unbounded?", - ) - parser.add_argument( - "--include_concept_list", - dest="include_concept_list", - action="store_true", - help="Apply the filter to remove low-frequency concepts", - ) - parser.add_argument( - "--allow_measurement_only", - dest="allow_measurement_only", - action="store_true", - help="Indicate whether we allow patients with measurements only", - ) - parser.add_argument( - "--is_population_estimation", - dest="is_population_estimation", - action="store_true", - help="Indicate whether the cohort is constructed for population level " "estimation", - ) - parser.add_argument( - "--att_type", - dest="att_type", - action="store", - choices=[e.value for e in AttType], - ) - parser.add_argument( - "--exclude_demographic", - dest="exclude_demographic", - action="store_true", - help="Indicate whether we should exclude the demographic prompt of the patient sequence", - ) - parser.add_argument( - "--use_age_group", - dest="use_age_group", - action="store_true", - help="Indicate whether we should age group to represent the age at the first event in the " "patient sequence", - ) - parser.add_argument( - "--single_contribution", - dest="single_contribution", - action="store_true", - help="Indicate whether we should contribute once to the training data", - ) - return parser.parse_args() diff --git a/src/cehrbert/spark_apps/sql_templates.py b/src/cehrbert/spark_apps/sql_templates.py deleted file mode 100644 index 559bc5ca..00000000 --- a/src/cehrbert/spark_apps/sql_templates.py +++ /dev/null @@ -1,42 +0,0 @@ -measurement_unit_stats_query = """ -WITH measurement_percentile AS -( - SELECT - m.measurement_concept_id, - m.unit_concept_id, - MEAN(m.value_as_number) AS mean_value, - MIN(m.value_as_number) AS min_value, - MAX(m.value_as_number) AS max_value, - percentile_approx(m.value_as_number, 0.01) AS lower_bound, - percentile_approx(m.value_as_number, 0.99) AS upper_bound - FROM measurement AS m - WHERE EXISTS ( - SELECT - 1 - FROM required_measurement AS r - WHERE r.measurement_concept_id = m.measurement_concept_id - AND r.is_numeric = true - ) - GROUP BY m.measurement_concept_id, m.unit_concept_id -) - -SELECT - m.measurement_concept_id, - m.unit_concept_id, - MEAN(m.value_as_number) AS value_mean, - STDDEV(m.value_as_number) AS value_stddev, - COUNT(*) AS measurement_freq, - FIRST(mp.lower_bound) AS lower_bound, - FIRST(mp.upper_bound) AS upper_bound -FROM measurement AS m -JOIN measurement_percentile AS mp - ON m.measurement_concept_id = mp.measurement_concept_id - AND m.unit_concept_id = mp.unit_concept_id -WHERE - m.value_as_number BETWEEN mp.lower_bound AND mp.upper_bound - AND m.visit_occurrence_id IS NOT NULL - AND m.unit_concept_id <> 0 - AND m.measurement_concept_id <> 0 -GROUP BY m.measurement_concept_id, m.unit_concept_id -HAVING COUNT(*) >= 100 -""" diff --git a/src/cehrbert/trainers/model_trainer.py b/src/cehrbert/trainers/model_trainer.py index 5ff74d27..44d0f636 100644 --- a/src/cehrbert/trainers/model_trainer.py +++ b/src/cehrbert/trainers/model_trainer.py @@ -8,12 +8,12 @@ import pandas as pd import tensorflow as tf -from ..data_generators.data_generator_base import AbstractDataGeneratorBase -from ..models.layers.custom_layers import get_custom_objects -from ..models.loss_schedulers import CosineLRSchedule -from ..utils.checkpoint_utils import MODEL_CONFIG_FILE, get_checkpoint_epoch -from ..utils.logging_utils import logging -from ..utils.model_utils import create_folder_if_not_exist, log_function_decorator, save_training_history +from 
cehrbert.data_generators.data_generator_base import AbstractDataGeneratorBase +from cehrbert.models.layers.custom_layers import get_custom_objects +from cehrbert.models.loss_schedulers import CosineLRSchedule +from cehrbert.utils.checkpoint_utils import MODEL_CONFIG_FILE, get_checkpoint_epoch +from cehrbert.utils.logging_utils import logging +from cehrbert.utils.model_utils import create_folder_if_not_exist, log_function_decorator, save_training_history class AbstractModel(ABC): diff --git a/src/cehrbert/trainers/train_cehr_bert.py b/src/cehrbert/trainers/train_cehr_bert.py index 0ddd3018..e1159d37 100644 --- a/src/cehrbert/trainers/train_cehr_bert.py +++ b/src/cehrbert/trainers/train_cehr_bert.py @@ -1,17 +1,17 @@ import tensorflow as tf from tensorflow.keras import optimizers -from ..data_generators.data_generator_base import ( +from cehrbert.data_generators.data_generator_base import ( BertDataGenerator, BertVisitPredictionDataGenerator, MedBertDataGenerator, ) -from ..keras_transformer.bert import MaskedPenalizedSparseCategoricalCrossentropy, masked_perplexity -from ..models.bert_models import transformer_bert_model -from ..models.bert_models_visit_prediction import transformer_bert_model_visit_prediction -from ..models.parse_args import create_parse_args_base_bert -from ..trainers.model_trainer import AbstractConceptEmbeddingTrainer -from ..utils.model_utils import tokenize_one_field +from cehrbert.keras_transformer.bert import MaskedPenalizedSparseCategoricalCrossentropy, masked_perplexity +from cehrbert.models.bert_models import transformer_bert_model +from cehrbert.models.bert_models_visit_prediction import transformer_bert_model_visit_prediction +from cehrbert.models.parse_args import create_parse_args_base_bert +from cehrbert.trainers.model_trainer import AbstractConceptEmbeddingTrainer +from cehrbert.utils.model_utils import tokenize_one_field class VanillaBertTrainer(AbstractConceptEmbeddingTrainer): diff --git a/src/cehrbert/utils/spark_utils.py b/src/cehrbert/utils/spark_utils.py deleted file mode 100644 index 3ad03c11..00000000 --- a/src/cehrbert/utils/spark_utils.py +++ /dev/null @@ -1,1404 +0,0 @@ -import argparse -from os import path -from typing import List, Tuple - -import pandas as pd -import pyspark.sql.functions as F -import pyspark.sql.types as T -from pyspark.sql import Window as W -from pyspark.sql.functions import broadcast -from pyspark.sql.pandas.functions import pandas_udf - -from ..config.output_names import QUALIFIED_CONCEPT_LIST_PATH -from ..const.common import ( - CATEGORICAL_MEASUREMENT, - CDM_TABLES, - MEASUREMENT, - PERSON, - REQUIRED_MEASUREMENT, - UNKNOWN_CONCEPT, - VISIT_OCCURRENCE, -) -from ..spark_apps.decorators.patient_event_decorator import ( - AttType, - DeathEventDecorator, - DemographicPromptDecorator, - PatientEventAttDecorator, - PatientEventBaseDecorator, - time_token_func, -) -from ..spark_apps.sql_templates import measurement_unit_stats_query -from ..utils.logging_utils import logging - -DOMAIN_KEY_FIELDS = { - "condition_occurrence_id": [ - ( - "condition_concept_id", - "condition_start_date", - "condition_start_datetime", - "condition", - ) - ], - "procedure_occurrence_id": [("procedure_concept_id", "procedure_date", "procedure_datetime", "procedure")], - "drug_exposure_id": [ - ( - "drug_concept_id", - "drug_exposure_start_date", - "drug_exposure_start_datetime", - "drug", - ) - ], - "measurement_id": [ - ( - "measurement_concept_id", - "measurement_date", - "measurement_datetime", - "measurement", - ) - ], - "death_date": 
[("person_id", "death_date", "death_datetime", "death")], - "visit_concept_id": [ - ("visit_concept_id", "visit_start_date", "visit"), - ("discharged_to_concept_id", "visit_end_date", "visit"), - ], -} - -LOGGER = logging.getLogger(__name__) - - -def get_key_fields(domain_table) -> List[Tuple[str, str, str, str]]: - field_names = domain_table.schema.fieldNames() - for k, v in DOMAIN_KEY_FIELDS.items(): - if k in field_names: - return v - return [ - ( - get_concept_id_field(domain_table), - get_domain_date_field(domain_table), - get_domain_datetime_field(domain_table), - get_domain_field(domain_table), - ) - ] - - -def get_domain_date_field(domain_table): - # extract the domain start_date column - return [f for f in domain_table.schema.fieldNames() if "date" in f][0] - - -def get_domain_datetime_field(domain_table): - # extract the domain start_date column - return [f for f in domain_table.schema.fieldNames() if "datetime" in f][0] - - -def get_concept_id_field(domain_table): - return [f for f in domain_table.schema.fieldNames() if "concept_id" in f][0] - - -def get_domain_field(domain_table): - return get_concept_id_field(domain_table).replace("_concept_id", "") - - -def create_file_path(input_folder, table_name): - if input_folder[-1] == "/": - file_path = input_folder + table_name - else: - file_path = input_folder + "/" + table_name - - return file_path - - -def join_domain_tables(domain_tables): - """Standardize the format of OMOP domain tables using a time frame. - - Keyword arguments: - domain_tables -- the array containing the OMOOP domain tabls except visit_occurrence - except measurement - - The the output columns of the domain table is converted to the same standard format as the following - (person_id, standard_concept_id, date, lower_bound, upper_bound, domain). - In this case, co-occurrence is defined as those concept ids that have co-occurred - within the same time window of a patient. - """ - patient_event = None - - for domain_table in domain_tables: - # extract the domain concept_id from the table fields. E.g. 
condition_concept_id from - # condition_occurrence extract the domain start_date column extract the name of the table - for ( - concept_id_field, - date_field, - datetime_field, - table_domain_field, - ) in get_key_fields(domain_table): - # Remove records that don't have a date or standard_concept_id - sub_domain_table = domain_table.where(F.col(date_field).isNotNull()).where( - F.col(concept_id_field).isNotNull() - ) - datetime_field_udf = F.to_timestamp(F.coalesce(datetime_field, date_field), "yyyy-MM-dd HH:mm:ss") - sub_domain_table = ( - sub_domain_table.where(F.col(concept_id_field).cast("string") != "0") - .withColumn("date", F.to_date(F.col(date_field))) - .withColumn("datetime", datetime_field_udf) - ) - - sub_domain_table = sub_domain_table.select( - sub_domain_table["person_id"], - sub_domain_table[concept_id_field].alias("standard_concept_id"), - sub_domain_table["date"].cast("date"), - sub_domain_table["datetime"], - sub_domain_table["visit_occurrence_id"], - F.lit(table_domain_field).alias("domain"), - F.lit(-1).alias("concept_value"), - ).distinct() - - # Remove "Patient Died" from condition_occurrence - if sub_domain_table == "condition_occurrence": - sub_domain_table = sub_domain_table.where("condition_concept_id != 4216643") - - if patient_event is None: - patient_event = sub_domain_table - else: - patient_event = patient_event.union(sub_domain_table) - - return patient_event - - -def preprocess_domain_table( - spark, - input_folder, - domain_table_name, - with_diagnosis_rollup=False, - with_drug_rollup=True, -): - domain_table = spark.read.parquet(create_file_path(input_folder, domain_table_name)) - - if "concept" in domain_table_name.lower(): - return domain_table - - # lowercase the schema fields - domain_table = domain_table.select([F.col(f_n).alias(f_n.lower()) for f_n in domain_table.schema.fieldNames()]) - - for f_n in domain_table.schema.fieldNames(): - if "date" in f_n and "datetime" not in f_n: - # convert date columns to the date type - domain_table = domain_table.withColumn(f_n, F.to_date(f_n)) - elif "datetime" in f_n: - # convert date columns to the datetime type - domain_table = domain_table.withColumn(f_n, F.to_timestamp(f_n)) - - if domain_table_name == "visit_occurrence": - # This is CDM 5.2, we need to rename this column to be CDM 5.3 compatible - if "discharge_to_concept_id" in domain_table.schema.fieldNames(): - domain_table = domain_table.withColumnRenamed("discharge_to_concept_id", "discharged_to_concept_id") - - if with_drug_rollup: - if ( - domain_table_name == "drug_exposure" - and path.exists(create_file_path(input_folder, "concept")) - and path.exists(create_file_path(input_folder, "concept_ancestor")) - ): - concept = spark.read.parquet(create_file_path(input_folder, "concept")) - concept_ancestor = spark.read.parquet(create_file_path(input_folder, "concept_ancestor")) - domain_table = roll_up_to_drug_ingredients(domain_table, concept, concept_ancestor) - - if with_diagnosis_rollup: - if ( - domain_table_name == "condition_occurrence" - and path.exists(create_file_path(input_folder, "concept")) - and path.exists(create_file_path(input_folder, "concept_relationship")) - ): - concept = spark.read.parquet(create_file_path(input_folder, "concept")) - concept_relationship = spark.read.parquet(create_file_path(input_folder, "concept_relationship")) - domain_table = roll_up_diagnosis(domain_table, concept, concept_relationship) - - if ( - domain_table_name == "procedure_occurrence" - and path.exists(create_file_path(input_folder, "concept")) - and 
path.exists(create_file_path(input_folder, "concept_ancestor")) - ): - concept = spark.read.parquet(create_file_path(input_folder, "concept")) - concept_ancestor = spark.read.parquet(create_file_path(input_folder, "concept_ancestor")) - domain_table = roll_up_procedure(domain_table, concept, concept_ancestor) - - return domain_table - - -def roll_up_to_drug_ingredients(drug_exposure, concept, concept_ancestor): - # lowercase the schema fields - drug_exposure = drug_exposure.select([F.col(f_n).alias(f_n.lower()) for f_n in drug_exposure.schema.fieldNames()]) - - drug_ingredient = ( - drug_exposure.select("drug_concept_id") - .distinct() - .join(concept_ancestor, F.col("drug_concept_id") == F.col("descendant_concept_id")) - .join(concept, F.col("ancestor_concept_id") == F.col("concept_id")) - .where(concept["concept_class_id"] == "Ingredient") - .select(F.col("drug_concept_id"), F.col("concept_id").alias("ingredient_concept_id")) - ) - - drug_ingredient_fields = [ - F.coalesce(F.col("ingredient_concept_id"), F.col("drug_concept_id")).alias("drug_concept_id") - ] - drug_ingredient_fields.extend( - [F.col(field_name) for field_name in drug_exposure.schema.fieldNames() if field_name != "drug_concept_id"] - ) - - drug_exposure = drug_exposure.join(drug_ingredient, "drug_concept_id", "left_outer").select(drug_ingredient_fields) - - return drug_exposure - - -def roll_up_diagnosis(condition_occurrence, concept, concept_relationship): - list_3dig_code = [ - "3-char nonbill code", - "3-dig nonbill code", - "3-char billing code", - "3-dig billing code", - "3-dig billing E code", - "3-dig billing V code", - "3-dig nonbill E code", - "3-dig nonbill V code", - ] - - condition_occurrence = condition_occurrence.select( - [F.col(f_n).alias(f_n.lower()) for f_n in condition_occurrence.schema.fieldNames()] - ) - - condition_icd = ( - condition_occurrence.select("condition_source_concept_id") - .distinct() - .join(concept, (F.col("condition_source_concept_id") == F.col("concept_id"))) - .where(concept["domain_id"] == "Condition") - .where(concept["vocabulary_id"] != "SNOMED") - .select( - F.col("condition_source_concept_id"), - F.col("vocabulary_id").alias("child_vocabulary_id"), - F.col("concept_class_id").alias("child_concept_class_id"), - ) - ) - - condition_icd_hierarchy = ( - condition_icd.join( - concept_relationship, - F.col("condition_source_concept_id") == F.col("concept_id_1"), - ) - .join( - concept, - (F.col("concept_id_2") == F.col("concept_id")) & (F.col("concept_class_id").isin(list_3dig_code)), - how="left", - ) - .select( - F.col("condition_source_concept_id").alias("source_concept_id"), - F.col("child_concept_class_id"), - F.col("concept_id").alias("parent_concept_id"), - F.col("concept_name").alias("parent_concept_name"), - F.col("vocabulary_id").alias("parent_vocabulary_id"), - F.col("concept_class_id").alias("parent_concept_class_id"), - ) - .distinct() - ) - - condition_icd_hierarchy = condition_icd_hierarchy.withColumn( - "ancestor_concept_id", - F.when( - F.col("child_concept_class_id").isin(list_3dig_code), - F.col("source_concept_id"), - ).otherwise(F.col("parent_concept_id")), - ).dropna(subset="ancestor_concept_id") - - condition_occurrence_fields = [ - F.col(f_n).alias(f_n.lower()) - for f_n in condition_occurrence.schema.fieldNames() - if f_n != "condition_source_concept_id" - ] - condition_occurrence_fields.append( - F.coalesce(F.col("ancestor_concept_id"), F.col("condition_source_concept_id")).alias( - "condition_source_concept_id" - ) - ) - - condition_occurrence = ( - 
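# Left-join each condition back to its 3-digit ICD ancestor, fall back to the original
# condition_source_concept_id when no ancestor was found, and then overwrite
# condition_concept_id with the rolled-up source concept so downstream code sees the
# coarser diagnosis codes.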
condition_occurrence.join( - condition_icd_hierarchy, - condition_occurrence["condition_source_concept_id"] == condition_icd_hierarchy["source_concept_id"], - how="left", - ) - .select(condition_occurrence_fields) - .withColumn("condition_concept_id", F.col("condition_source_concept_id")) - ) - return condition_occurrence - - -def roll_up_procedure(procedure_occurrence, concept, concept_ancestor): - def extract_parent_code(concept_code): - return concept_code.split(".")[0] - - parent_code_udf = F.udf(extract_parent_code, T.StringType()) - - procedure_code = ( - procedure_occurrence.select("procedure_source_concept_id") - .distinct() - .join(concept, F.col("procedure_source_concept_id") == F.col("concept_id")) - .where(concept["domain_id"] == "Procedure") - .select( - F.col("procedure_source_concept_id").alias("source_concept_id"), - F.col("vocabulary_id").alias("child_vocabulary_id"), - F.col("concept_class_id").alias("child_concept_class_id"), - F.col("concept_code").alias("child_concept_code"), - ) - ) - - # cpt code rollup - cpt_code = procedure_code.where(F.col("child_vocabulary_id") == "CPT4") - - cpt_hierarchy = ( - cpt_code.join( - concept_ancestor, - cpt_code["source_concept_id"] == concept_ancestor["descendant_concept_id"], - ) - .join(concept, concept_ancestor["ancestor_concept_id"] == concept["concept_id"]) - .where(concept["vocabulary_id"] == "CPT4") - .select( - F.col("source_concept_id"), - F.col("child_concept_class_id"), - F.col("ancestor_concept_id").alias("parent_concept_id"), - F.col("min_levels_of_separation"), - F.col("concept_class_id").alias("parent_concept_class_id"), - ) - ) - - cpt_hierarchy_level_1 = ( - cpt_hierarchy.where(F.col("min_levels_of_separation") == 1) - .where(F.col("child_concept_class_id") == "CPT4") - .where(F.col("parent_concept_class_id") == "CPT4 Hierarchy") - .select(F.col("source_concept_id"), F.col("parent_concept_id")) - ) - - cpt_hierarchy_level_1 = cpt_hierarchy_level_1.join( - concept_ancestor, - (cpt_hierarchy_level_1["source_concept_id"] == concept_ancestor["descendant_concept_id"]) - & (concept_ancestor["min_levels_of_separation"] == 1), - how="left", - ).select( - F.col("source_concept_id"), - F.col("parent_concept_id"), - F.col("ancestor_concept_id").alias("root_concept_id"), - ) - - cpt_hierarchy_level_1 = cpt_hierarchy_level_1.withColumn( - "isroot", - F.when( - cpt_hierarchy_level_1["root_concept_id"] == 45889197, - cpt_hierarchy_level_1["source_concept_id"], - ).otherwise(cpt_hierarchy_level_1["parent_concept_id"]), - ).select(F.col("source_concept_id"), F.col("isroot").alias("ancestor_concept_id")) - - cpt_hierarchy_level_0 = ( - cpt_hierarchy.groupby("source_concept_id") - .max() - .where(F.col("max(min_levels_of_separation)") == 0) - .select(F.col("source_concept_id").alias("cpt_level_0_concept_id")) - ) - - cpt_hierarchy_level_0 = cpt_hierarchy.join( - cpt_hierarchy_level_0, - cpt_hierarchy["source_concept_id"] == cpt_hierarchy_level_0["cpt_level_0_concept_id"], - ).select( - F.col("source_concept_id"), - F.col("parent_concept_id").alias("ancestor_concept_id"), - ) - - cpt_hierarchy_rollup_all = cpt_hierarchy_level_1.union(cpt_hierarchy_level_0).drop_duplicates() - - # ICD code rollup - icd_list = ["ICD9CM", "ICD9Proc", "ICD10CM"] - - procedure_icd = procedure_code.where(F.col("vocabulary_id").isin(icd_list)) - - procedure_icd = ( - procedure_icd.withColumn("parent_concept_code", parent_code_udf(F.col("child_concept_code"))) - .withColumnRenamed("procedure_source_concept_id", "source_concept_id") - 
.withColumnRenamed("concept_name", "child_concept_name") - .withColumnRenamed("vocabulary_id", "child_vocabulary_id") - .withColumnRenamed("concept_code", "child_concept_code") - .withColumnRenamed("concept_class_id", "child_concept_class_id") - ) - - procedure_icd_map = ( - procedure_icd.join( - concept, - (procedure_icd["parent_concept_code"] == concept["concept_code"]) - & (procedure_icd["child_vocabulary_id"] == concept["vocabulary_id"]), - how="left", - ) - .select("source_concept_id", F.col("concept_id").alias("ancestor_concept_id")) - .distinct() - ) - - # ICD10PCS rollup - procedure_10pcs = procedure_code.where(F.col("vocabulary_id") == "ICD10PCS") - - procedure_10pcs = ( - procedure_10pcs.withColumn("parent_concept_code", F.substring(F.col("child_concept_code"), 1, 3)) - .withColumnRenamed("procedure_source_concept_id", "source_concept_id") - .withColumnRenamed("concept_name", "child_concept_name") - .withColumnRenamed("vocabulary_id", "child_vocabulary_id") - .withColumnRenamed("concept_code", "child_concept_code") - .withColumnRenamed("concept_class_id", "child_concept_class_id") - ) - - procedure_10pcs_map = ( - procedure_10pcs.join( - concept, - (procedure_10pcs["parent_concept_code"] == concept["concept_code"]) - & (procedure_10pcs["child_vocabulary_id"] == concept["vocabulary_id"]), - how="left", - ) - .select("source_concept_id", F.col("concept_id").alias("ancestor_concept_id")) - .distinct() - ) - - # HCPCS rollup --- keep the concept_id itself - procedure_hcpcs = procedure_code.where(F.col("child_vocabulary_id") == "HCPCS") - procedure_hcpcs_map = ( - procedure_hcpcs.withColumn("ancestor_concept_id", F.col("source_concept_id")) - .select("source_concept_id", "ancestor_concept_id") - .distinct() - ) - - procedure_hierarchy = ( - cpt_hierarchy_rollup_all.union(procedure_icd_map) - .union(procedure_10pcs_map) - .union(procedure_hcpcs_map) - .distinct() - ) - procedure_occurrence_fields = [ - F.col(f_n).alias(f_n.lower()) - for f_n in procedure_occurrence.schema.fieldNames() - if f_n != "procedure_source_concept_id" - ] - procedure_occurrence_fields.append( - F.coalesce(F.col("ancestor_concept_id"), F.col("procedure_source_concept_id")).alias( - "procedure_source_concept_id" - ) - ) - - procedure_occurrence = ( - procedure_occurrence.join( - procedure_hierarchy, - procedure_occurrence["procedure_source_concept_id"] == procedure_hierarchy["source_concept_id"], - how="left", - ) - .select(procedure_occurrence_fields) - .withColumn("procedure_concept_id", F.col("procedure_source_concept_id")) - ) - return procedure_occurrence - - -def create_sequence_data(patient_event, date_filter=None, include_visit_type=False, classic_bert_seq=False): - """ - Create a sequence of the events associated with one patient in a chronological order. 
- - :param patient_event: - :param date_filter: - :param include_visit_type: - :param classic_bert_seq: - :return: - """ - - if date_filter: - patient_event = patient_event.where(F.col("date") >= date_filter) - - # Define a list of custom UDFs for creating custom columns - date_conversion_udf = (F.unix_timestamp("date") / F.lit(24 * 60 * 60 * 7)).cast("int") - earliest_visit_date_udf = F.min("date_in_week").over(W.partitionBy("visit_occurrence_id")) - - visit_rank_udf = F.dense_rank().over(W.partitionBy("cohort_member_id", "person_id").orderBy("earliest_visit_date")) - visit_segment_udf = F.col("visit_rank_order") % F.lit(2) + 1 - - # Derive columns - patient_event = ( - patient_event.where("visit_occurrence_id IS NOT NULL") - .withColumn("date_in_week", date_conversion_udf) - .withColumn("earliest_visit_date", earliest_visit_date_udf) - .withColumn("visit_rank_order", visit_rank_udf) - .withColumn("visit_segment", visit_segment_udf) - .withColumn("priority", F.lit(0)) - ) - - if classic_bert_seq: - # Udf for identifying the earliest date associated with a visit_occurrence_id - visit_start_date_udf = F.first("date").over( - W.partitionBy("cohort_member_id", "person_id", "visit_occurrence_id").orderBy("date") - ) - - # Udf for identifying the previous visit_occurrence_id - prev_visit_occurrence_id_udf = F.lag("visit_occurrence_id").over( - W.partitionBy("cohort_member_id", "person_id").orderBy("visit_start_date", "visit_occurrence_id") - ) - - # We can achieve this by overwriting the record with the earliest time stamp - separator_events = ( - patient_event.withColumn("visit_start_date", visit_start_date_udf) - .withColumn("prev_visit_occurrence_id", prev_visit_occurrence_id_udf) - .where("prev_visit_occurrence_id IS NOT NULL") - .where("visit_occurrence_id <> prev_visit_occurrence_id") - .withColumn("domain", F.lit("Separator")) - .withColumn("standard_concept_id", F.lit("SEP")) - .withColumn("priority", F.lit(-1)) - .withColumn("visit_segment", F.lit(0)) - .select(patient_event.schema.fieldNames()) - ) - - # Combine this artificial token SEP with the original data - patient_event = patient_event.union(separator_events) - - order_udf = F.row_number().over( - W.partitionBy("cohort_member_id", "person_id").orderBy( - "earliest_visit_date", - "visit_occurrence_id", - "priority", - "date_in_week", - "standard_concept_id", - ) - ) - # Group the data into sequences - output_columns = [ - "order", - "date_in_week", - "standard_concept_id", - "visit_segment", - "age", - "visit_rank_order", - ] - - if include_visit_type: - output_columns.append("visit_concept_id") - - # Group by data by person_id and put all the events into a list - # The order of the list is determined by the order column - patient_grouped_events = ( - patient_event.withColumn("order", order_udf) - .withColumn("date_concept_id_period", F.struct(output_columns)) - .groupBy("person_id", "cohort_member_id") - .agg( - F.sort_array(F.collect_set("date_concept_id_period")).alias("date_concept_id_period"), - F.min("earliest_visit_date").alias("earliest_visit_date"), - F.max("date").alias("max_event_date"), - F.max("visit_rank_order").alias("num_of_visits"), - F.count("standard_concept_id").alias("num_of_concepts"), - ) - .withColumn( - "orders", - F.col("date_concept_id_period.order").cast(T.ArrayType(T.IntegerType())), - ) - .withColumn("dates", F.col("date_concept_id_period.date_in_week")) - .withColumn("concept_ids", F.col("date_concept_id_period.standard_concept_id")) - .withColumn("visit_segments", 
F.col("date_concept_id_period.visit_segment")) - .withColumn("ages", F.col("date_concept_id_period.age")) - .withColumn("visit_concept_orders", F.col("date_concept_id_period.visit_rank_order")) - ) - - # Default columns in the output dataframe - columns_for_output = [ - "cohort_member_id", - "person_id", - "earliest_visit_date", - "max_event_date", - "orders", - "dates", - "ages", - "concept_ids", - "visit_segments", - "visit_concept_orders", - "num_of_visits", - "num_of_concepts", - ] - - if include_visit_type: - patient_grouped_events = patient_grouped_events.withColumn( - "visit_concept_ids", F.col("date_concept_id_period.visit_concept_id") - ) - columns_for_output.append("visit_concept_ids") - - return patient_grouped_events.select(columns_for_output) - - -def create_sequence_data_with_att( - patient_events, - visit_occurrence, - date_filter=None, - include_visit_type=False, - exclude_visit_tokens=False, - patient_demographic=None, - death=None, - att_type: AttType = AttType.CEHR_BERT, - exclude_demographic: bool = True, - use_age_group: bool = False, - include_inpatient_hour_token: bool = False, -): - """ - Create a sequence of the events associated with one patient in a chronological order. - - :param patient_events: - :param visit_occurrence: - :param date_filter: - :param include_visit_type: - :param exclude_visit_tokens: - :param patient_demographic: - :param death: - :param att_type: - :param exclude_demographic: - :param use_age_group: - :param include_inpatient_hour_token: - - :return: - """ - if date_filter: - patient_events = patient_events.where(F.col("date").cast("date") >= date_filter) - - decorators = [ - PatientEventBaseDecorator(visit_occurrence), - PatientEventAttDecorator( - visit_occurrence, - include_visit_type, - exclude_visit_tokens, - att_type, - include_inpatient_hour_token, - ), - # DemographicPromptDecorator(patient_demographic), - DeathEventDecorator(death, att_type), - ] - - if not exclude_demographic: - decorators.append(DemographicPromptDecorator(patient_demographic, use_age_group)) - - for decorator in decorators: - patient_events = decorator.decorate(patient_events) - - # add randomness to the order of the concepts that have the same time stamp - order_udf = F.row_number().over( - W.partitionBy("cohort_member_id", "person_id").orderBy( - "visit_rank_order", - "concept_order", - "priority", - "datetime", - "standard_concept_id", - ) - ) - - dense_rank_udf = F.dense_rank().over( - W.partitionBy("cohort_member_id", "person_id").orderBy( - "visit_rank_order", "concept_order", "priority", "datetime" - ) - ) - - # Those columns are derived from the previous decorators - struct_columns = [ - "order", - "record_rank", - "date_in_week", - "standard_concept_id", - "visit_segment", - "age", - "visit_rank_order", - "concept_value_mask", - "concept_value", - "mlm_skip_value", - "visit_concept_id", - "visit_concept_order", - "concept_order", - "priority", - ] - output_columns = [ - "cohort_member_id", - "person_id", - "concept_ids", - "visit_segments", - "orders", - "dates", - "ages", - "visit_concept_orders", - "num_of_visits", - "num_of_concepts", - "concept_value_masks", - "concept_values", - "mlm_skip_values", - "priorities", - "visit_concept_ids", - "visit_rank_orders", - "concept_orders", - "record_ranks", - ] - - patient_grouped_events = ( - patient_events.withColumn("order", order_udf) - .withColumn("record_rank", dense_rank_udf) - .withColumn("data_for_sorting", F.struct(struct_columns)) - .groupBy("cohort_member_id", "person_id") - .agg( - 
F.sort_array(F.collect_set("data_for_sorting")).alias("data_for_sorting"), - F.max("visit_rank_order").alias("num_of_visits"), - F.count("standard_concept_id").alias("num_of_concepts"), - ) - .withColumn("orders", F.col("data_for_sorting.order").cast(T.ArrayType(T.IntegerType()))) - .withColumn( - "record_ranks", - F.col("data_for_sorting.record_rank").cast(T.ArrayType(T.IntegerType())), - ) - .withColumn("dates", F.col("data_for_sorting.date_in_week")) - .withColumn("concept_ids", F.col("data_for_sorting.standard_concept_id")) - .withColumn("visit_segments", F.col("data_for_sorting.visit_segment")) - .withColumn("ages", F.col("data_for_sorting.age")) - .withColumn("visit_rank_orders", F.col("data_for_sorting.visit_rank_order")) - .withColumn("visit_concept_orders", F.col("data_for_sorting.visit_concept_order")) - .withColumn("concept_orders", F.col("data_for_sorting.concept_order")) - .withColumn("priorities", F.col("data_for_sorting.priority")) - .withColumn("concept_value_masks", F.col("data_for_sorting.concept_value_mask")) - .withColumn("concept_values", F.col("data_for_sorting.concept_value")) - .withColumn("mlm_skip_values", F.col("data_for_sorting.mlm_skip_value")) - .withColumn("visit_concept_ids", F.col("data_for_sorting.visit_concept_id")) - ) - - return patient_grouped_events.select(output_columns) - - -def create_concept_frequency_data(patient_event, date_filter=None): - if date_filter: - patient_event = patient_event.where(F.col("date") >= date_filter) - - take_concept_ids_udf = F.udf(lambda rows: [row[0] for row in rows], T.ArrayType(T.StringType())) - take_freqs_udf = F.udf(lambda rows: [row[1] for row in rows], T.ArrayType(T.IntegerType())) - - num_of_visits_concepts = patient_event.groupBy("cohort_member_id", "person_id").agg( - F.countDistinct("visit_occurrence_id").alias("num_of_visits"), - F.count("standard_concept_id").alias("num_of_concepts"), - ) - - patient_event = ( - patient_event.groupBy("cohort_member_id", "person_id", "standard_concept_id") - .count() - .withColumn("concept_id_freq", F.struct("standard_concept_id", "count")) - .groupBy("cohort_member_id", "person_id") - .agg(F.collect_list("concept_id_freq").alias("sequence")) - .withColumn("concept_ids", take_concept_ids_udf("sequence")) - .withColumn("frequencies", take_freqs_udf("sequence")) - .select("cohort_member_id", "person_id", "concept_ids", "frequencies") - .join(num_of_visits_concepts, ["person_id", "cohort_member_id"]) - ) - - return patient_event - - -def extract_ehr_records( - spark, - input_folder, - domain_table_list, - include_visit_type=False, - with_rollup=False, - include_concept_list=False, -): - """ - Extract the ehr records for domain_table_list from input_folder. 
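# A minimal, self-contained sketch (not taken from the original file; the toy DataFrame and
# column values are invented) of the grouping pattern used by create_sequence_data and
# create_sequence_data_with_att above: pack the per-event columns into a struct, collect and
# sort the structs per patient, then read each struct field back out as an aligned array column.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
events = spark.createDataFrame(
    [(1, 2, "c2"), (1, 1, "c1"), (2, 1, "c3")],
    ["person_id", "order", "standard_concept_id"],
)
sequences = (
    events.withColumn("event", F.struct("order", "standard_concept_id"))
    .groupBy("person_id")
    .agg(F.sort_array(F.collect_set("event")).alias("events"))
    # Field access on an array of structs returns an array of that field, so the
    # resulting "orders" and "concept_ids" columns stay aligned with each other.
    .withColumn("orders", F.col("events.order"))
    .withColumn("concept_ids", F.col("events.standard_concept_id"))
    .drop("events")
)
sequences.show()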
- - :param spark: - :param input_folder: - :param domain_table_list: - :param include_visit_type: whether or not to include the visit type to the ehr records - :param with_rollup: whether ot not to roll up the concepts to the parent levels - :param include_concept_list: - :return: - """ - domain_tables = [] - for domain_table_name in domain_table_list: - if domain_table_name != MEASUREMENT: - domain_tables.append(preprocess_domain_table(spark, input_folder, domain_table_name, with_rollup)) - patient_ehr_records = join_domain_tables(domain_tables) - - if include_concept_list and patient_ehr_records: - # Filter out concepts - qualified_concepts = preprocess_domain_table(spark, input_folder, QUALIFIED_CONCEPT_LIST_PATH).select( - "standard_concept_id" - ) - - patient_ehr_records = patient_ehr_records.join(qualified_concepts, "standard_concept_id") - - # Process the measurement table if exists - if MEASUREMENT in domain_table_list: - measurement = preprocess_domain_table(spark, input_folder, MEASUREMENT) - required_measurement = preprocess_domain_table(spark, input_folder, REQUIRED_MEASUREMENT) - scaled_measurement = process_measurement(spark, measurement, required_measurement) - - if patient_ehr_records: - # Union all measurement records together with other domain records - patient_ehr_records = patient_ehr_records.union(scaled_measurement) - else: - patient_ehr_records = scaled_measurement - - patient_ehr_records = patient_ehr_records.where("visit_occurrence_id IS NOT NULL").distinct() - - person = preprocess_domain_table(spark, input_folder, PERSON) - person = person.withColumn( - "birth_datetime", - F.coalesce( - "birth_datetime", - F.concat("year_of_birth", F.lit("-01-01")).cast("timestamp"), - ), - ) - patient_ehr_records = patient_ehr_records.join(person, "person_id").withColumn( - "age", - F.ceil(F.months_between(F.col("date"), F.col("birth_datetime")) / F.lit(12)), - ) - if include_visit_type: - visit_occurrence = preprocess_domain_table(spark, input_folder, VISIT_OCCURRENCE) - patient_ehr_records = patient_ehr_records.join(visit_occurrence, "visit_occurrence_id").select( - patient_ehr_records["person_id"], - patient_ehr_records["standard_concept_id"], - patient_ehr_records["date"], - patient_ehr_records["visit_occurrence_id"], - patient_ehr_records["domain"], - visit_occurrence["visit_concept_id"], - patient_ehr_records["age"], - ) - - return patient_ehr_records - - -def build_ancestry_table_for(spark, concept_ids): - initial_query = """ - SELECT - cr.concept_id_1 AS ancestor_concept_id, - cr.concept_id_2 AS descendant_concept_id, - 1 AS distance - FROM global_temp.concept_relationship AS cr - WHERE cr.concept_id_1 in ({concept_ids}) AND cr.relationship_id = 'Subsumes' - """ - - recurring_query = """ - SELECT - i.ancestor_concept_id AS ancestor_concept_id, - cr.concept_id_2 AS descendant_concept_id, - i.distance + 1 AS distance - FROM global_temp.ancestry_table AS i - JOIN global_temp.concept_relationship AS cr - ON i.descendant_concept_id = cr.concept_id_1 AND cr.relationship_id = 'Subsumes' - LEFT JOIN global_temp.ancestry_table AS i2 - ON cr.concept_id_2 = i2.descendant_concept_id - WHERE i2.descendant_concept_id IS NULL - """ - - union_query = """ - SELECT - * - FROM global_temp.ancestry_table - - UNION - - SELECT - * - FROM global_temp.candidate - """ - - ancestry_table = spark.sql(initial_query.format(concept_ids=",".join([str(c) for c in concept_ids]))) - ancestry_table.createOrReplaceGlobalTempView("ancestry_table") - - candidate_set = spark.sql(recurring_query) - 
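# The while-loop below is an iterative transitive closure: each pass follows one more level of
# 'Subsumes' relationships from the current ancestry_table, unions the newly reached descendants
# back in, and stops once a pass produces no concepts that are not already present.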
candidate_set.createOrReplaceGlobalTempView("candidate") - - while candidate_set.count() != 0: - spark.sql(union_query).createOrReplaceGlobalTempView("ancestry_table") - candidate_set = spark.sql(recurring_query) - candidate_set.createOrReplaceGlobalTempView("candidate") - - ancestry_table = spark.sql( - """ - SELECT - * - FROM global_temp.ancestry_table - """ - ) - - spark.sql( - """ - DROP VIEW global_temp.ancestry_table - """ - ) - - return ancestry_table - - -def get_descendant_concept_ids(spark, concept_ids): - """ - Query concept_ancestor table to get all descendant_concept_ids for the given list of concept_ids. - - :param spark: - :param concept_ids: - :return: - """ - sanitized_concept_ids = [int(c) for c in concept_ids] - # Join the sanitized IDs into a string for the query - concept_ids_str = ",".join(map(str, sanitized_concept_ids)) - # Construct and execute the SQL query using the sanitized string - descendant_concept_ids = spark.sql( - f""" - SELECT DISTINCT - c.* - FROM global_temp.concept_ancestor AS ca - JOIN global_temp.concept AS c - ON ca.descendant_concept_id = c.concept_id - WHERE ca.ancestor_concept_id IN ({concept_ids_str}) - """ - ) - return descendant_concept_ids - - -def get_standard_concept_ids(spark, concept_ids): - standard_concept_ids = spark.sql( - """ - SELECT DISTINCT - c.* - FROM global_temp.concept_relationship AS cr - JOIN global_temp.concept AS c - ON ca.concept_id_2 = c.concept_id AND cr.relationship_id = 'Maps to' - WHERE ca.concept_id_1 IN ({concept_ids}) - """.format( - concept_ids=",".join([str(c) for c in concept_ids]) - ) - ) - return standard_concept_ids - - -def get_table_column_refs(dataframe): - return [dataframe[fieldName] for fieldName in dataframe.schema.fieldNames()] - - -def create_hierarchical_sequence_data( - person, - visit_occurrence, - patient_events, - date_filter=None, - max_num_of_visits_per_person=None, - include_incomplete_visit=True, - allow_measurement_only=False, -): - """ - This creates a hierarchical data frame for the hierarchical bert model. 
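# In outline: events are first grouped per visit into sorted structs (with an artificial CLS
# token prepended to every visit), and the per-visit records are then grouped again per patient
# into a sorted patient_list, producing the two-level visit/event hierarchy used by the
# hierarchical BERT model.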
- - :param person: - :param visit_occurrence: - :param patient_events: - :param date_filter: - :param max_num_of_visits_per_person: - :param include_incomplete_visit: - :param allow_measurement_only: - :return: - """ - - if date_filter: - visit_occurrence = visit_occurrence.where(F.col("visit_start_date").cast("date") >= date_filter) - - # Construct visit information with the person demographic - visit_occurrence_person = create_visit_person_join(person, visit_occurrence, include_incomplete_visit) - - # Retrieve all visit column references - visit_column_refs = get_table_column_refs(visit_occurrence_person) - - # Construct the patient event column references - pat_col_refs = [ - F.coalesce(patient_events["cohort_member_id"], visit_occurrence["person_id"]).alias("cohort_member_id"), - F.coalesce(patient_events["standard_concept_id"], F.lit(UNKNOWN_CONCEPT)).alias("standard_concept_id"), - F.coalesce(patient_events["date"], visit_occurrence["visit_start_date"]).alias("date"), - F.coalesce(patient_events["domain"], F.lit("unknown")).alias("domain"), - F.coalesce(patient_events["concept_value"], F.lit(-1.0)).alias("concept_value"), - ] - - # Convert standard_concept_id to string type, this is needed for the tokenization - # Calculate the age w.r.t to the event - patient_events = ( - visit_occurrence_person.join(patient_events, "visit_occurrence_id", "left_outer") - .select(visit_column_refs + pat_col_refs) - .withColumn("standard_concept_id", F.col("standard_concept_id").cast("string")) - .withColumn( - "age", - F.ceil(F.months_between(F.col("date"), F.col("birth_datetime")) / F.lit(12)), - ) - .withColumn("concept_value_mask", (F.col("domain") == MEASUREMENT).cast("int")) - .withColumn( - "mlm_skip", - (F.col("domain").isin([MEASUREMENT, CATEGORICAL_MEASUREMENT])).cast("int"), - ) - .withColumn("condition_mask", (F.col("domain") == "condition").cast("int")) - ) - - if not allow_measurement_only: - # We only allow persons that have a non measurement record in the dataset - qualified_person_df = ( - patient_events.where(~F.col("domain").isin([MEASUREMENT, CATEGORICAL_MEASUREMENT])) - .where(F.col("standard_concept_id") != UNKNOWN_CONCEPT) - .select("person_id") - .distinct() - ) - - patient_events = patient_events.join(qualified_person_df, "person_id") - - # Create the udf for calculating the weeks since the epoch time 1970-01-01 - weeks_since_epoch_udf = (F.unix_timestamp("date") / F.lit(24 * 60 * 60 * 7)).cast("int") - - # UDF for creating the concept orders within each visit - visit_concept_order_udf = F.row_number().over( - W.partitionBy("cohort_member_id", "person_id", "visit_occurrence_id").orderBy("date", "standard_concept_id") - ) - - patient_events = ( - patient_events.withColumn("date", F.col("date").cast("date")) - .withColumn("date_in_week", weeks_since_epoch_udf) - .withColumn("visit_concept_order", visit_concept_order_udf) - ) - - # Insert a CLS token at the beginning of each visit, this CLS token will be used as the visit - # summary in pre-training / fine-tuning. 
We basically make a copy of the first concept of - # each visit and change it to CLS, and set the concept order to 0 to make sure this is always - # the first token of each visit - insert_cls_tokens = ( - patient_events.where("visit_concept_order == 1") - .withColumn("standard_concept_id", F.lit("CLS")) - .withColumn("domain", F.lit("CLS")) - .withColumn("visit_concept_order", F.lit(0)) - .withColumn("date", F.col("visit_start_date")) - .withColumn("concept_value_mask", F.lit(0)) - .withColumn("concept_value", F.lit(-1.0)) - .withColumn("mlm_skip", F.lit(1)) - .withColumn("condition_mask", F.lit(0)) - ) - - # Declare a list of columns that need to be collected per each visit - struct_columns = [ - "visit_concept_order", - "standard_concept_id", - "date_in_week", - "age", - "concept_value_mask", - "concept_value", - "mlm_skip", - "condition_mask", - ] - - # Merge the first CLS tokens into patient sequence and collect events for each visit - patent_visit_sequence = ( - patient_events.union(insert_cls_tokens) - .withColumn("visit_struct_data", F.struct(struct_columns)) - .groupBy("cohort_member_id", "person_id", "visit_occurrence_id") - .agg( - F.sort_array(F.collect_set("visit_struct_data")).alias("visit_struct_data"), - F.first("visit_start_date").alias("visit_start_date"), - F.first("visit_rank_order").alias("visit_rank_order"), - F.first("visit_concept_id").alias("visit_concept_id"), - F.first("is_readmission").alias("is_readmission"), - F.first("is_inpatient").alias("is_inpatient"), - F.first("visit_segment").alias("visit_segment"), - F.first("time_interval_att").alias("time_interval_att"), - F.first("prolonged_stay").alias("prolonged_stay"), - F.count("standard_concept_id").alias("num_of_concepts"), - ) - .orderBy(["person_id", "visit_rank_order"]) - ) - - patient_visit_sequence = ( - patent_visit_sequence.withColumn("visit_concept_orders", F.col("visit_struct_data.visit_concept_order")) - .withColumn("visit_concept_ids", F.col("visit_struct_data.standard_concept_id")) - .withColumn("visit_concept_dates", F.col("visit_struct_data.date_in_week")) - .withColumn("visit_concept_ages", F.col("visit_struct_data.age")) - .withColumn("concept_value_masks", F.col("visit_struct_data.concept_value_mask")) - .withColumn("concept_values", F.col("visit_struct_data.concept_value")) - .withColumn("mlm_skip_values", F.col("visit_struct_data.mlm_skip")) - .withColumn("condition_masks", F.col("visit_struct_data.condition_mask")) - .withColumn("visit_mask", F.lit(0)) - .drop("visit_struct_data") - ) - - visit_struct_data_columns = [ - "visit_rank_order", - "visit_occurrence_id", - "visit_start_date", - "visit_concept_id", - "prolonged_stay", - "visit_mask", - "visit_segment", - "num_of_concepts", - "is_readmission", - "is_inpatient", - "time_interval_att", - "visit_concept_orders", - "visit_concept_ids", - "visit_concept_dates", - "visit_concept_ages", - "concept_values", - "concept_value_masks", - "mlm_skip_values", - "condition_masks", - ] - - visit_weeks_since_epoch_udf = ( - F.unix_timestamp(F.col("visit_start_date").cast("date")) / F.lit(24 * 60 * 60 * 7) - ).cast("int") - - patient_sequence = ( - patient_visit_sequence.withColumn("visit_start_date", visit_weeks_since_epoch_udf) - .withColumn( - "visit_struct_data", - F.struct(visit_struct_data_columns).alias("visit_struct_data"), - ) - .groupBy("cohort_member_id", "person_id") - .agg( - F.sort_array(F.collect_list("visit_struct_data")).alias("patient_list"), - F.sum(F.lit(1) - F.col("visit_mask")).alias("num_of_visits"), - 
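# visit_mask is a constant 0 placeholder at this point, so summing (1 - visit_mask)
# simply counts the number of visits contributed by each patient.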
F.sum("num_of_concepts").alias("num_of_concepts"), - ) - ) - - if max_num_of_visits_per_person: - patient_sequence = patient_sequence.where(F.col("num_of_visits") <= max_num_of_visits_per_person) - - patient_sequence = ( - patient_sequence.withColumn("visit_rank_orders", F.col("patient_list.visit_rank_order")) - .withColumn("concept_orders", F.col("patient_list.visit_concept_orders")) - .withColumn("concept_ids", F.col("patient_list.visit_concept_ids")) - .withColumn("dates", F.col("patient_list.visit_concept_dates")) - .withColumn("ages", F.col("patient_list.visit_concept_ages")) - .withColumn("visit_dates", F.col("patient_list.visit_start_date")) - .withColumn("visit_segments", F.col("patient_list.visit_segment")) - .withColumn("visit_masks", F.col("patient_list.visit_mask")) - .withColumn( - "visit_concept_ids", - F.col("patient_list.visit_concept_id").cast(T.ArrayType(T.StringType())), - ) - .withColumn("time_interval_atts", F.col("patient_list.time_interval_att")) - .withColumn("concept_values", F.col("patient_list.concept_values")) - .withColumn("concept_value_masks", F.col("patient_list.concept_value_masks")) - .withColumn("mlm_skip_values", F.col("patient_list.mlm_skip_values")) - .withColumn("condition_masks", F.col("patient_list.condition_masks")) - .withColumn( - "is_readmissions", - F.col("patient_list.is_readmission").cast(T.ArrayType(T.IntegerType())), - ) - .withColumn( - "is_inpatients", - F.col("patient_list.is_inpatient").cast(T.ArrayType(T.IntegerType())), - ) - .withColumn( - "visit_prolonged_stays", - F.col("patient_list.prolonged_stay").cast(T.ArrayType(T.IntegerType())), - ) - .drop("patient_list") - ) - - return patient_sequence - - -def create_visit_person_join(person, visit_occurrence, include_incomplete_visit=True): - """ - Create a new spark data frame based on person and visit_occurrence. - - :param person: - :param visit_occurrence: - :param include_incomplete_visit: - :return: - """ - - # Create a pandas udf for generating the att token between two neighboring visits - @pandas_udf("string") - def pandas_udf_to_att(time_intervals: pd.Series) -> pd.Series: - return time_intervals.apply(time_token_func) - - visit_rank_udf = F.row_number().over( - W.partitionBy("person_id").orderBy("visit_start_date", "visit_end_date", "visit_occurrence_id") - ) - visit_segment_udf = F.col("visit_rank_order") % F.lit(2) + 1 - visit_windowing = W.partitionBy("person_id").orderBy("visit_start_date", "visit_end_date", "visit_occurrence_id") - # Check whehter or not the visit is either an inpatient visit or E-I visit - is_inpatient_logic = F.col("visit_concept_id").isin([9201, 262]).cast("integer") - # Construct the logic for readmission, which is defined as inpatient visit occurred within 30 - # days of the discharge - readmission_logic = F.coalesce( - ( - (F.col("time_interval") <= 30) - & (F.col("visit_concept_id").isin([9201, 262])) - & (F.col("prev_visit_concept_id").isin([9201, 262])) - ).cast("integer"), - F.lit(0), - ) - - # Create prolonged inpatient stay - # For the incomplete visit, we set prolonged_length_stay_logic to 0 - prolonged_length_stay_logic = F.coalesce( - (F.datediff("visit_end_date", "visit_start_date") >= 7).cast("integer"), - F.lit(0), - ) - - visit_filter = "visit_start_date IS NOT NULL" - if not include_incomplete_visit: - visit_filter = f"{visit_filter} AND visit_end_date IS NOT NULL" - - # Select the subset of columns and create derived columns using the UDF or spark sql - # functions. 
-    # In addition, visits with a NULL visit_end_date are kept when include_incomplete_visit is
-    # True, indicating the visit is still ongoing
-    visit_occurrence = (
-        visit_occurrence.select(
-            "visit_occurrence_id",
-            "person_id",
-            "visit_concept_id",
-            "visit_start_date",
-            "visit_end_date",
-        )
-        .where(visit_filter)
-        .withColumn("visit_rank_order", visit_rank_udf)
-        .withColumn("visit_segment", visit_segment_udf)
-        .withColumn(
-            "prev_visit_occurrence_id",
-            F.lag("visit_occurrence_id").over(visit_windowing),
-        )
-        .withColumn("prev_visit_concept_id", F.lag("visit_concept_id").over(visit_windowing))
-        .withColumn("prev_visit_start_date", F.lag("visit_start_date").over(visit_windowing))
-        .withColumn("prev_visit_end_date", F.lag("visit_end_date").over(visit_windowing))
-        .withColumn("time_interval", F.datediff("visit_start_date", "prev_visit_end_date"))
-        .withColumn(
-            "time_interval",
-            F.when(F.col("time_interval") < 0, F.lit(0)).otherwise(F.col("time_interval")),
-        )
-        .withColumn("time_interval_att", pandas_udf_to_att("time_interval"))
-        .withColumn("is_inpatient", is_inpatient_logic)
-        .withColumn("is_readmission", readmission_logic)
-    )
-
-    visit_occurrence = visit_occurrence.withColumn("prolonged_stay", prolonged_length_stay_logic).select(
-        "visit_occurrence_id",
-        "visit_concept_id",
-        "person_id",
-        "prolonged_stay",
-        "is_readmission",
-        "is_inpatient",
-        "time_interval_att",
-        "visit_rank_order",
-        "visit_start_date",
-        "visit_segment",
-    )
-    # Assume the birthday to be the first day of the birth year if birth_datetime is missing
-    person = person.select(
-        "person_id",
-        F.coalesce(
-            "birth_datetime",
-            F.concat("year_of_birth", F.lit("-01-01")).cast("timestamp"),
-        ).alias("birth_datetime"),
-    )
-    return visit_occurrence.join(person, "person_id")
-
-
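For intuition, here is a self-contained sketch of the visit-ranking and 30-day readmission logic removed above, run on a toy DataFrame; the local Spark session, sample rows, and explicit date casts are illustrative assumptions, not part of the original module.

from pyspark.sql import SparkSession, Window as W, functions as F

spark = SparkSession.builder.master("local[1]").appName("readmission_sketch").getOrCreate()

# Two inpatient visits (visit_concept_id 9201) for the same person, 15 days apart
toy_visits = spark.createDataFrame(
    [
        (1, 100, "2020-01-01", "2020-01-05", 9201),
        (1, 101, "2020-01-20", "2020-01-22", 9201),
    ],
    ["person_id", "visit_occurrence_id", "visit_start_date", "visit_end_date", "visit_concept_id"],
)

w = W.partitionBy("person_id").orderBy("visit_start_date", "visit_end_date", "visit_occurrence_id")
ranked = (
    toy_visits.withColumn("visit_rank_order", F.row_number().over(w))
    .withColumn("visit_segment", F.col("visit_rank_order") % F.lit(2) + 1)
    .withColumn("prev_visit_end_date", F.lag("visit_end_date").over(w))
    .withColumn("prev_visit_concept_id", F.lag("visit_concept_id").over(w))
    # dates are plain strings in the toy rows, so cast explicitly before datediff
    .withColumn(
        "time_interval",
        F.datediff(F.col("visit_start_date").cast("date"), F.col("prev_visit_end_date").cast("date")),
    )
    .withColumn(
        "is_readmission",
        F.coalesce(
            (
                (F.col("time_interval") <= 30)
                & F.col("visit_concept_id").isin([9201, 262])
                & F.col("prev_visit_concept_id").isin([9201, 262])
            ).cast("integer"),
            F.lit(0),
        ),
    )
)
ranked.show()  # the second visit gets is_readmission = 1 (15-day gap between two inpatient stays)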
-def process_measurement(spark, measurement, required_measurement, output_folder: str = None):
-    """
-    Remove the measurement values that are outside the 0.01-0.99 quantiles, and scale the
-    measurement value by subtracting the mean and dividing by the standard deviation.
-
-    :param spark:
-    :param measurement:
-    :param required_measurement:
-    :return:
-    """
-    # Register the tables in spark context
-    measurement.createOrReplaceTempView(MEASUREMENT)
-    required_measurement.createOrReplaceTempView(REQUIRED_MEASUREMENT)
-    measurement_unit_stats_df = spark.sql(measurement_unit_stats_query)
-
-    if output_folder:
-        measurement_unit_stats_df.repartition(10).write.mode("overwrite").parquet(
-            path.join(output_folder, "measurement_unit_stats")
-        )
-        measurement_unit_stats_df = spark.read.parquet(path.join(output_folder, "measurement_unit_stats"))
-
-    # Cache the stats in memory
-    measurement_unit_stats_df.cache()
-    # Broadcast df to local executors
-    broadcast(measurement_unit_stats_df)
-    # Create the temp view for this dataframe
-    measurement_unit_stats_df.createOrReplaceTempView("measurement_unit_stats")
-
-    scaled_numeric_lab = spark.sql(
-        """
-        SELECT
-            m.person_id,
-            m.measurement_concept_id AS standard_concept_id,
-            CAST(m.measurement_date AS DATE) AS date,
-            CAST(COALESCE(m.measurement_datetime, m.measurement_date) AS TIMESTAMP) AS datetime,
-            m.visit_occurrence_id,
-            'measurement' AS domain,
-            (m.value_as_number - s.value_mean) / value_stddev AS concept_value
-        FROM measurement AS m
-        JOIN measurement_unit_stats AS s
-            ON s.measurement_concept_id = m.measurement_concept_id
-                AND s.unit_concept_id = m.unit_concept_id
-        WHERE m.visit_occurrence_id IS NOT NULL
-            AND m.value_as_number IS NOT NULL
-            AND m.value_as_number BETWEEN s.lower_bound AND s.upper_bound
-        """
-    )
-
-    # For categorical measurements in required_measurement, we concatenate measurement_concept_id
-    # with value_as_concept_id to construct a new standard_concept_id
-    categorical_lab = spark.sql(
-        """
-        SELECT
-            m.person_id,
-            CASE
-                WHEN value_as_concept_id IS NOT NULL AND value_as_concept_id <> 0
-                    THEN CONCAT(CAST(measurement_concept_id AS STRING), '-', CAST(value_as_concept_id AS STRING))
-                ELSE CAST(measurement_concept_id AS STRING)
-            END AS standard_concept_id,
-            CAST(m.measurement_date AS DATE) AS date,
-            CAST(COALESCE(m.measurement_datetime, m.measurement_date) AS TIMESTAMP) AS datetime,
-            m.visit_occurrence_id,
-            'categorical_measurement' AS domain,
-            -1.0 AS concept_value
-        FROM measurement AS m
-        WHERE EXISTS (
-            SELECT
-                1
-            FROM required_measurement AS r
-            WHERE r.measurement_concept_id = m.measurement_concept_id
-                AND r.is_numeric = false
-        )
-        """
-    )
-
-    processed_measurement_df = scaled_numeric_lab.unionAll(categorical_lab)
-
-    if output_folder:
-        processed_measurement_df.write.mode("overwrite").parquet(path.join(output_folder, "processed_measurement"))
-        processed_measurement_df = spark.read.parquet(path.join(output_folder, "processed_measurement"))
-
-    return processed_measurement_df
-
-
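The two queries above amount to per-(measurement_concept_id, unit_concept_id) outlier trimming followed by z-scoring; a rough pandas analogue for intuition, assuming the 0.01/0.99 quantile bounds mirror what measurement_unit_stats_query (defined elsewhere in the repo) computes.

import pandas as pd

def q01(s: pd.Series) -> float:
    return s.quantile(0.01)

def q99(s: pd.Series) -> float:
    return s.quantile(0.99)

def scale_numeric_measurements(measurement: pd.DataFrame) -> pd.DataFrame:
    # Expected columns: measurement_concept_id, unit_concept_id, value_as_number
    stats = (
        measurement.groupby(["measurement_concept_id", "unit_concept_id"])["value_as_number"]
        .agg(lower_bound=q01, upper_bound=q99, value_mean="mean", value_stddev="std")
        .reset_index()
    )
    merged = measurement.merge(stats, on=["measurement_concept_id", "unit_concept_id"])
    # Drop values outside the per-(concept, unit) 0.01-0.99 quantile bounds
    in_range = merged["value_as_number"].between(merged["lower_bound"], merged["upper_bound"])
    merged = merged[in_range].copy()
    # Z-score the remaining values, mirroring (value - mean) / stddev in the SQL
    merged["concept_value"] = (merged["value_as_number"] - merged["value_mean"]) / merged["value_stddev"]
    return merged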
-def get_mlm_skip_domains(spark, input_folder, mlm_skip_table_list):
-    """
-    Translate the domain_table_name to the domain name.
-
-    :param spark:
-    :param input_folder:
-    :param mlm_skip_table_list:
-    :return:
-    """
-    domain_tables = [
-        preprocess_domain_table(spark, input_folder, domain_table_name) for domain_table_name in mlm_skip_table_list
-    ]
-
-    return list(map(get_domain_field, domain_tables))
-
-
-def validate_table_names(domain_names):
-    for domain_name in domain_names.split(" "):
-        if domain_name not in CDM_TABLES:
-            raise argparse.ArgumentTypeError(f"{domain_name} is an invalid CDM table name")
-    return domain_names
diff --git a/tests/unit_tests/data_generators/hf_data_generator/hf_med_to_cehrbert_mapping_test.py b/tests/unit_tests/data_generators/hf_data_generator/hf_med_to_cehrbert_mapping_test.py
index f0d76a7b..1f0af1f3 100644
--- a/tests/unit_tests/data_generators/hf_data_generator/hf_med_to_cehrbert_mapping_test.py
+++ b/tests/unit_tests/data_generators/hf_data_generator/hf_med_to_cehrbert_mapping_test.py
@@ -1,10 +1,11 @@
 import unittest
 from datetime import datetime
 
+from cehrbert_data.decorators.patient_event_decorator import AttType
+
 from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping
 from cehrbert.med_extension.schema_extension import CehrBertPatient, Event, Visit
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
-from cehrbert.spark_apps.decorators.patient_event_decorator import AttType
 
 # Actual test class