From 5e2e99b2d525e40a396e6b4833ad230b0841539a Mon Sep 17 00:00:00 2001
From: Chao Pang
Date: Thu, 5 Sep 2024 14:44:34 -0400
Subject: [PATCH 1/5] Create pylint.yml

---
 .github/workflows/pylint.yml | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)
 create mode 100644 .github/workflows/pylint.yml

diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml
new file mode 100644
index 00000000..545ed1fe
--- /dev/null
+++ b/.github/workflows/pylint.yml
@@ -0,0 +1,23 @@
+name: Pylint
+
+on: [push]
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.9", "3.10", "3.11"]
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v3
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install pylint
+    - name: Analysing the code with pylint
+      run: |
+        pylint $(git ls-files '*.py')

From 43fdc4aa4f0b2f087fe7a60710eb77c2e9e83dd6 Mon Sep 17 00:00:00 2001
From: Chao Pang
Date: Thu, 5 Sep 2024 14:48:51 -0400
Subject: [PATCH 2/5] Fixed the pylint error about the UPPER_CASE naming style

---
 src/cehrbert/config/output_names.py            | 18 +++++++++---------
 src/cehrbert/evaluations/evaluation.py         | 16 ++++++++--------
 .../generate_concept_similarity_table.py       |  4 ++--
 ...generate_hierarchical_bert_training_data.py |  4 ++--
 .../generate_included_concept_list.py          |  4 ++--
 .../spark_apps/generate_information_content.py |  2 +-
 src/cehrbert/utils/spark_utils.py              |  4 ++--
 7 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/src/cehrbert/config/output_names.py b/src/cehrbert/config/output_names.py
index 2d26aae2..ba931b55 100644
--- a/src/cehrbert/config/output_names.py
+++ b/src/cehrbert/config/output_names.py
@@ -1,9 +1,9 @@
-parquet_data_path = 'patient_sequence'
-qualified_concept_list_path = 'qualified_concept_list'
-time_attention_model_path = 'time_aware_model.h5'
-bert_model_validation_path = 'bert_model.h5'
-mortality_data_path = 'mortality'
-heart_failure_data_path = 'heart_failure'
-hospitalization_data_path = 'hospitalization'
-information_content_data_path = 'information_content'
-concept_similarity_path = 'concept_similarity'
+PARQUET_DATA_PATH = 'patient_sequence'
+QUALIFIED_CONCEPT_LIST_PATH = 'qualified_concept_list'
+TIME_ATTENTION_MODEL_PATH = 'time_aware_model.h5'
+BERT_MODEL_VALIDATION_PATH = 'bert_model.h5'
+MORTALITY_DATA_PATH = 'mortality'
+HEART_FAILURE_DATA_PATH = 'heart_failure'
+HOSPITALIZATION_DATA_PATH = 'hospitalization'
+INFORMATION_CONTENT_DATA_PATH = 'information_content'
+CONCEPT_SIMILARITY_PATH = 'concept_similarity'
diff --git a/src/cehrbert/evaluations/evaluation.py b/src/cehrbert/evaluations/evaluation.py
index d1cfe966..e407338c 100644
--- a/src/cehrbert/evaluations/evaluation.py
+++ b/src/cehrbert/evaluations/evaluation.py
@@ -56,7 +56,7 @@ def evaluate_sequence_models(args):
         time_attention_tokenizer_path = find_tokenizer_path(args.time_attention_model_folder)
         time_aware_model_path = os.path.join(
             args.time_attention_model_folder,
-            p.time_attention_model_path
+            p.TIME_ATTENTION_MODEL_PATH
         )
         BiLstmModelEvaluator(
             dataset=dataset,
@@ -83,7 +83,7 @@ def evaluate_sequence_models(args):
         validate_folder(args.vanilla_bert_model_folder)
         bert_tokenizer_path = find_tokenizer_path(args.vanilla_bert_model_folder)
         bert_model_path = os.path.join(args.vanilla_bert_model_folder,
-                                       p.bert_model_validation_path)
+                                       p.BERT_MODEL_VALIDATION_PATH)
         BertFeedForwardModelEvaluator(
             dataset=dataset,
             evaluation_folder=args.evaluation_folder,
@@ -108,7 +108,7 @@ def evaluate_sequence_models(args):
         validate_folder(args.vanilla_bert_model_folder)
         bert_tokenizer_path = find_tokenizer_path(args.vanilla_bert_model_folder)
         bert_model_path = os.path.join(args.vanilla_bert_model_folder,
-                                       p.bert_model_validation_path)
+                                       p.BERT_MODEL_VALIDATION_PATH)
         SlidingBertModelEvaluator(
             dataset=dataset,
             evaluation_folder=args.evaluation_folder,
@@ -134,7 +134,7 @@ def evaluate_sequence_models(args):
         validate_folder(args.vanilla_bert_model_folder)
         bert_tokenizer_path = find_tokenizer_path(args.vanilla_bert_model_folder)
         bert_model_path = os.path.join(args.vanilla_bert_model_folder,
-                                       p.bert_model_validation_path)
+                                       p.BERT_MODEL_VALIDATION_PATH)
         BertLstmModelEvaluator(
             dataset=dataset,
             evaluation_folder=args.evaluation_folder,
@@ -160,7 +160,7 @@ def evaluate_sequence_models(args):
     if RANDOM_VANILLA_BERT_LSTM in args.model_evaluators:
         validate_folder(args.vanilla_bert_model_folder)
         bert_model_path = os.path.join(args.vanilla_bert_model_folder,
-                                       p.bert_model_validation_path)
+                                       p.BERT_MODEL_VALIDATION_PATH)
         bert_tokenizer_path = find_tokenizer_path(args.vanilla_bert_model_folder)
         visit_tokenizer_path = find_visit_tokenizer_path(args.vanilla_bert_model_folder)

@@ -195,7 +195,7 @@ def evaluate_sequence_models(args):
     if HIERARCHICAL_BERT_LSTM in args.model_evaluators:
         validate_folder(args.vanilla_bert_model_folder)
         bert_model_path = os.path.join(args.vanilla_bert_model_folder,
-                                       p.bert_model_validation_path)
+                                       p.BERT_MODEL_VALIDATION_PATH)
         bert_tokenizer_path = find_tokenizer_path(args.vanilla_bert_model_folder)
         bert_visit_tokenizer_path = find_visit_tokenizer_path(args.vanilla_bert_model_folder)

@@ -227,7 +227,7 @@ def evaluate_sequence_models(args):
     if HIERARCHICAL_BERT_POOLING in args.model_evaluators:
         validate_folder(args.vanilla_bert_model_folder)
         bert_model_path = os.path.join(args.vanilla_bert_model_folder,
-                                       p.bert_model_validation_path)
+                                       p.BERT_MODEL_VALIDATION_PATH)
         bert_tokenizer_path = find_tokenizer_path(args.vanilla_bert_model_folder)
         bert_visit_tokenizer_path = find_visit_tokenizer_path(args.vanilla_bert_model_folder)

@@ -259,7 +259,7 @@ def evaluate_sequence_models(args):
     if RANDOM_HIERARCHICAL_BERT_LSTM in args.model_evaluators:
         validate_folder(args.vanilla_bert_model_folder)
         bert_model_path = os.path.join(args.vanilla_bert_model_folder,
-                                       p.bert_model_validation_path)
+                                       p.BERT_MODEL_VALIDATION_PATH)
         bert_tokenizer_path = find_tokenizer_path(args.vanilla_bert_model_folder)
         bert_visit_tokenizer_path = find_visit_tokenizer_path(args.vanilla_bert_model_folder)

diff --git a/src/cehrbert/spark_apps/generate_concept_similarity_table.py b/src/cehrbert/spark_apps/generate_concept_similarity_table.py
index 3448b0ee..dbe991e0 100644
--- a/src/cehrbert/spark_apps/generate_concept_similarity_table.py
+++ b/src/cehrbert/spark_apps/generate_concept_similarity_table.py
@@ -307,7 +307,7 @@ def main(
         preprocess_domain_table(
             spark,
             input_folder,
-            qualified_concept_list_path
+            QUALIFIED_CONCEPT_LIST_PATH
         )
     )

@@ -350,7 +350,7 @@ def main(
     concept_pair_similarity.write.mode('overwrite').parquet(
         os.path.join(
             output_folder,
-            concept_similarity_path
+            CONCEPT_SIMILARITY_PATH
         )
     )

diff --git a/src/cehrbert/spark_apps/generate_hierarchical_bert_training_data.py b/src/cehrbert/spark_apps/generate_hierarchical_bert_training_data.py
index 318a3831..ba7c89b9 100644
--- a/src/cehrbert/spark_apps/generate_hierarchical_bert_training_data.py
+++ b/src/cehrbert/spark_apps/generate_hierarchical_bert_training_data.py
@@ -73,7 +73,7 @@ def main(
         preprocess_domain_table(
             spark,
             input_folder,
-            qualified_concept_list_path
+            QUALIFIED_CONCEPT_LIST_PATH
         )
     )
     # The select is necessary to make sure the order of the columns is the same as the
@@ -116,7 +116,7 @@ def main(
     sequence_data.write.mode('overwrite').parquet(
         os.path.join(
             output_folder,
-            parquet_data_path
+            PARQUET_DATA_PATH
         )
     )

diff --git a/src/cehrbert/spark_apps/generate_included_concept_list.py b/src/cehrbert/spark_apps/generate_included_concept_list.py
index 61da4fe1..295ec646 100644
--- a/src/cehrbert/spark_apps/generate_included_concept_list.py
+++ b/src/cehrbert/spark_apps/generate_included_concept_list.py
@@ -5,7 +5,7 @@

 from ..utils.spark_utils import *
 from ..const.common import MEASUREMENT
-from ..config.output_names import qualified_concept_list_path
+from ..config.output_names import QUALIFIED_CONCEPT_LIST_PATH

 DOMAIN_TABLE_LIST = ['condition_occurrence', 'procedure_occurrence', 'drug_exposure']

@@ -39,7 +39,7 @@ def main(
     qualified_concepts.write.mode('overwrite').parquet(
         os.path.join(
             output_folder,
-            qualified_concept_list_path
+            QUALIFIED_CONCEPT_LIST_PATH
         )
     )

diff --git a/src/cehrbert/spark_apps/generate_information_content.py b/src/cehrbert/spark_apps/generate_information_content.py
index 4ae3110e..801cdc02 100644
--- a/src/cehrbert/spark_apps/generate_information_content.py
+++ b/src/cehrbert/spark_apps/generate_information_content.py
@@ -60,7 +60,7 @@ def main(
         .withColumn('probability', F.col('count') / total_count)

     information_content.write.mode('overwrite').parquet(
-        os.path.join(output_folder, information_content_data_path)
+        os.path.join(output_folder, INFORMATION_CONTENT_DATA_PATH)
     )


diff --git a/src/cehrbert/utils/spark_utils.py b/src/cehrbert/utils/spark_utils.py
index 180c4eef..7cdb60e8 100644
--- a/src/cehrbert/utils/spark_utils.py
+++ b/src/cehrbert/utils/spark_utils.py
@@ -9,7 +9,7 @@
 from pyspark.sql.functions import broadcast
 from pyspark.sql.pandas.functions import pandas_udf

-from ..config.output_names import qualified_concept_list_path
+from ..config.output_names import QUALIFIED_CONCEPT_LIST_PATH
 from ..const.common import PERSON, VISIT_OCCURRENCE, UNKNOWN_CONCEPT, MEASUREMENT, \
     CATEGORICAL_MEASUREMENT, REQUIRED_MEASUREMENT, CDM_TABLES
 from ..spark_apps.decorators.patient_event_decorator import (
@@ -639,7 +639,7 @@ def extract_ehr_records(spark, input_folder, domain_table_list, include_visit_ty
     qualified_concepts = preprocess_domain_table(
         spark,
         input_folder,
-        qualified_concept_list_path
+        QUALIFIED_CONCEPT_LIST_PATH
     ).select('standard_concept_id')

     patient_ehr_records = patient_ehr_records.join(

From bdbd28114e43396d8beb5c21aeb053ba1b72c232 Mon Sep 17 00:00:00 2001
From: Chao Pang
Date: Thu, 5 Sep 2024 14:57:15 -0400
Subject: [PATCH 3/5] added python-app to run the tests when the main branch is pushed or a PR is approved

---
 .github/workflows/pylint.yml              |  23 -
 .github/workflows/python-app.yml          |  39 ++
 .../models/hierachical_phenotype_model.py  | 453 ------------------
 3 files changed, 39 insertions(+), 476 deletions(-)
 delete mode 100644 .github/workflows/pylint.yml
 create mode 100644 .github/workflows/python-app.yml
 delete mode 100644 src/cehrbert/models/hierachical_phenotype_model.py

diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml
deleted file mode 100644
index 545ed1fe..00000000
--- a/.github/workflows/pylint.yml
+++ /dev/null
@@ -1,23 +0,0 @@
-name: Pylint
-
-on: [push]
-
-jobs:
-  build:
-    runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python-version: ["3.9", "3.10", "3.11"]
-    steps:
-    - uses: actions/checkout@v4
-    - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v3
-      with:
-        python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        pip install pylint
-    - name: Analysing the code with pylint
-      run: |
-        pylint $(git ls-files '*.py')
diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml
new file mode 100644
index 00000000..f456de63
--- /dev/null
+++ b/.github/workflows/python-app.yml
@@ -0,0 +1,39 @@
+# This workflow will install Python dependencies, run tests and lint with a single version of Python
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
+
+name: Python application
+
+on:
+  push:
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ]
+
+permissions:
+  contents: read
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v3
+    - name: Set up Python 3.10.0
+      uses: actions/setup-python@v3
+      with:
+        python-version: "3.10"
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install flake8 pytest
+        pip install -e .
+    - name: Lint with flake8
+      run: |
+        # stop the build if there are Python syntax errors or undefined names
+        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+    - name: Test with pytest
+      run: |
+        PYTHONPATH=./: pytest
\ No newline at end of file
diff --git a/src/cehrbert/models/hierachical_phenotype_model.py b/src/cehrbert/models/hierachical_phenotype_model.py
deleted file mode 100644
index 9eac4378..00000000
--- a/src/cehrbert/models/hierachical_phenotype_model.py
+++ /dev/null
@@ -1,453 +0,0 @@
-from .layers.custom_layers import *
-from .hierachical_bert_model_v2 import create_att_concept_mask
-
-
-def create_probabilistic_phenotype_model(
-    num_of_visits,
-    num_of_concepts,
-    concept_vocab_size,
-    embedding_size,
-    depth: int,
-    num_heads: int,
-    transformer_dropout: float = 0.1,
-    embedding_dropout: float = 0.6,
-    l2_reg_penalty: float = 1e-4,
-    time_embeddings_size: int = 16,
-    include_att_prediction: bool = False,
-    include_visit_prediction: bool = False,
-    include_readmission: bool = False,
-    include_prolonged_length_stay: bool = False,
-    visit_vocab_size: int = None,
-    num_of_phenotypes: int = 20,
-    num_of_phenotype_neighbors: int = 3,
-    num_of_concept_neighbors: int = 10
-):
-    """
-    Create a hierarchical bert model
-
-
-    :param num_of_visits:
-    :param num_of_concepts:
-    :param concept_vocab_size:
-    :param embedding_size:
-    :param depth:
-    :param num_heads:
-    :param transformer_dropout:
-    :param embedding_dropout:
-    :param l2_reg_penalty:
-    :param time_embeddings_size:
-    :param include_att_prediction:
-    :param include_visit_prediction:
-    :param include_readmission:
-    :param include_prolonged_length_stay:
-    :param visit_vocab_size:
-    :param num_of_phenotypes:
-    :param num_of_phenotype_neighbors:
-    :param num_of_concept_neighbors:
-    :return:
-    """
-    # If the second tiered learning objectives are enabled, visit_vocab_size needs to be provided
-    if include_visit_prediction and not visit_vocab_size:
-        raise RuntimeError(f'visit_vocab_size can not be null '
-                           f'when the second learning objectives are enabled')
-
-    pat_seq = tf.keras.layers.Input(
-        shape=(num_of_visits, num_of_concepts,),
-        dtype='int32',
-        name='pat_seq'
-    )
-    pat_seq_age = tf.keras.layers.Input(
-        shape=(num_of_visits, num_of_concepts,),
-        dtype='int32',
-        name='pat_seq_age'
-    )
-    pat_seq_time = tf.keras.layers.Input(
-        shape=(num_of_visits, num_of_concepts,),
-        dtype='int32',
-        name='pat_seq_time'
-    )
-    pat_mask = tf.keras.layers.Input(
-        shape=(num_of_visits, num_of_concepts,),
-        dtype='int32',
-        name='pat_mask'
-    )
-    concept_values = tf.keras.layers.Input(
-        shape=(num_of_visits, num_of_concepts,),
-        dtype='float32',
-        name='concept_values'
-    )
-    concept_value_masks = tf.keras.layers.Input(
-        shape=(num_of_visits, num_of_concepts,),
-        dtype='int32',
-        name='concept_value_masks'
-    )
-    visit_mask = tf.keras.layers.Input(
-        shape=(num_of_visits,),
-        dtype='int32',
-        name='visit_mask')
-
-    visit_time_delta_att = tf.keras.layers.Input(
-        shape=(num_of_visits - 1,),
-        dtype='int32',
-        name='visit_time_delta_att'
-    )
-    visit_rank_order = tf.keras.layers.Input(
-        shape=(num_of_visits,),
-        dtype='int32',
-        name='visit_rank_order'
-    )
-    visit_visit_type = tf.keras.layers.Input(
-        shape=(num_of_visits,),
-        dtype='int32',
-        name='masked_visit_type'
-    )
-
-    # Create a list of inputs so the model could reference these later
-    default_inputs = [pat_seq, pat_seq_age, pat_seq_time, pat_mask,
-                      concept_values, concept_value_masks, visit_mask,
-                      visit_time_delta_att, visit_rank_order, visit_visit_type]
-
-    # Expand dimensions for masking MultiHeadAttention in Concept Encoder
-    pat_concept_mask = tf.reshape(
-        pat_mask,
-        shape=(-1, num_of_concepts)
-    )[:, tf.newaxis, tf.newaxis, :]
-
-    # output the embedding_matrix:
-    l2_regularizer = (tf.keras.regularizers.l2(l2_reg_penalty) if l2_reg_penalty else None)
-    concept_embedding_layer = ReusableEmbedding(
-        concept_vocab_size,
-        embedding_size,
-        name='concept_embedding_layer',
-        embeddings_regularizer=l2_regularizer
-    )
-
-    visit_type_embedding_layer = ReusableEmbedding(
-        concept_vocab_size,
-        embedding_size,
-        name='visit_type_embedding_layer',
-        embeddings_regularizer=l2_regularizer
-    )
-
-    # Look up the embeddings for the concepts
-    concept_embeddings, embedding_matrix = concept_embedding_layer(
-        pat_seq
-    )
-
-    concept_value_transformation_layer = ConceptValueTransformationLayer(
-        embedding_size=embedding_size,
-        name='concept_value_transformation_layer'
-    )
-
-    # Transform the concept embeddings by combining their concept embeddings with the
-    # corresponding val
-    concept_embeddings = concept_value_transformation_layer(
-        concept_embeddings=concept_embeddings,
-        concept_values=concept_values,
-        concept_value_masks=concept_value_masks
-    )
-
-    # Look up the embeddings for the att tokens
-    att_embeddings, _ = concept_embedding_layer(
-        visit_time_delta_att
-    )
-
-    # Re-purpose token id 0 as the visit start embedding
-    visit_start_embeddings, _ = concept_embedding_layer(
-        tf.zeros_like(
-            visit_mask,
-            dtype=tf.int32
-        )
-    )
-
-    temporal_transformation_layer = TemporalTransformationLayer(
-        time_embeddings_size=time_embeddings_size,
-        embedding_size=embedding_size,
-        name='temporal_transformation_layer'
-    )
-
-    # (batch, num_of_visits, num_of_concepts, embedding_size)
-    concept_embeddings = temporal_transformation_layer(
-        concept_embeddings,
-        pat_seq_age,
-        pat_seq_time,
-        visit_rank_order
-    )
-
-    # (batch, num_of_visits, embedding_size)
-    # The first bert applied at the visit level
-    concept_encoder = Encoder(
-        name='concept_encoder',
-        num_layers=depth,
-        d_model=embedding_size,
-        num_heads=num_heads,
-        dropout_rate=transformer_dropout
-    )
-
-    concept_embeddings = tf.reshape(
-        concept_embeddings,
-        shape=(-1, num_of_concepts, embedding_size)
-    )
-
-    concept_embeddings, _ = concept_encoder(
-        concept_embeddings,  # be reused
-        pat_concept_mask  # not change
-    )
-
-    # (batch_size, num_of_visits, num_of_concepts, embedding_size)
-    concept_embeddings = tf.reshape(
-        concept_embeddings,
-        shape=(-1, num_of_visits, num_of_concepts, embedding_size)
-    )
-
-    # Step 2 generate visit embeddings
-    # Slice out the first contextualized embedding of each visit
-    # (batch_size, num_of_visits, embedding_size)
-    visit_embeddings = concept_embeddings[:, :, 0]
-
-    visit_type_embedding_dense_layer = tf.keras.layers.Dense(
-        embedding_size,
-        name='visit_type_embedding_dense_layer'
-    )
-
-    # (batch_size, num_of_visits, embedding_size)
-    visit_type_embeddings, visit_type_embedding_matrix = visit_type_embedding_layer(
-        visit_visit_type
-    )
-
-    # Combine visit_type_embeddings with visit_embeddings
-    visit_embeddings = visit_type_embedding_dense_layer(
-        tf.concat([
-            visit_embeddings,
-            visit_type_embeddings
-        ], axis=-1)
-    )
-
-    # (batch_size, num_of_visits, embedding_size)
-    expanded_att_embeddings = tf.concat([att_embeddings, att_embeddings[:, 0:1, :]], axis=1)
-
-    # Insert the att embeddings between visit embeddings
-    # (batch_size, num_of_visits + num_of_visits + num_of_visits - 1, embedding_size)
-    contextualized_visit_embeddings = tf.reshape(
-        tf.concat(
-            [visit_start_embeddings,
-             visit_embeddings,
-             expanded_att_embeddings],
-            axis=-1
-        ),
-        (-1, 3 * num_of_visits, embedding_size)
-    )[:, :-1, :]
-
-    # Expand dimension for masking MultiHeadAttention in Visit Encoder
-    visit_mask_with_att = tf.reshape(
-        tf.tile(visit_mask[:, :, tf.newaxis], [1, 1, 3]),
-        (-1, num_of_visits * 3)
-    )[:, tf.newaxis, tf.newaxis, 1:]
-
-    # (num_of_visits_with_att, num_of_visits_with_att)
-    look_ahead_mask_base = tf.cast(
-        tf.linalg.band_part(tf.ones((num_of_visits, num_of_visits)), -1, 0),
-        dtype=tf.int32
-    )
-    look_ahead_visit_mask_with_att = tf.reshape(
-        tf.tile(
-            look_ahead_mask_base[:, tf.newaxis, :, tf.newaxis],
-            [1, 3, 1, 3]
-        ),
-        shape=(num_of_visits * 3, num_of_visits * 3)
-    )[:-1, :-1]
-
-    look_ahead_concept_mask = tf.reshape(
-        tf.tile(
-            look_ahead_mask_base[:, tf.newaxis, :, tf.newaxis],
-            [1, num_of_concepts, 1, 1]
-        ),
-        (num_of_concepts * num_of_visits, -1)
-    )
-
-    # (batch_size, 1, num_of_visits_with_att, num_of_visits_with_att)
-    look_ahead_visit_mask_with_att = tf.minimum(
-        visit_mask_with_att,
-        look_ahead_visit_mask_with_att
-    )
-
-    # (batch_size, 1, num_of_visits * num_of_concepts, num_of_visits)
-    look_ahead_concept_mask = tf.maximum(
-        visit_mask[:, tf.newaxis, tf.newaxis, :],
-        look_ahead_concept_mask
-    )
-
-    # Second bert applied at the patient level to the visit embeddings
-    visit_encoder = Encoder(
-        name='visit_encoder',
-        num_layers=depth,
-        d_model=embedding_size,
-        num_heads=num_heads,
-        dropout_rate=transformer_dropout
-    )
-
-    # Feed augmented visit embeddings into encoders to get contextualized visit embeddings
-    contextualized_visit_embeddings, _ = visit_encoder(
-        contextualized_visit_embeddings,
-        look_ahead_visit_mask_with_att
-    )
-
-    # Pad contextualized_visit_embeddings on axis 1 with one extra visit so we can extract the
-    # visit embeddings using the reshape trick
-    expanded_contextualized_visit_embeddings = tf.concat(
-        [contextualized_visit_embeddings,
-         contextualized_visit_embeddings[:, 0:1, :]],
-        axis=1
-    )
-
-    # Extract the visit embeddings elements
-    visit_embeddings_without_att = tf.reshape(
-        expanded_contextualized_visit_embeddings, (-1, num_of_visits, 3 * embedding_size)
-    )[:, :, embedding_size: embedding_size * 2]
-
-    # Step 4: Assuming there is a generative process that generates diagnosis embeddings from a
-    # Multivariate Gaussian Distribution Declare phenotype distribution prior
-    visit_phenotype_layer = VisitPhenotypeLayer(
-        num_of_phenotypes=num_of_phenotypes,
-        num_of_phenotype_neighbors=num_of_phenotype_neighbors,
-        num_of_concept_neighbors=num_of_concept_neighbors,
-        embedding_size=embedding_size,
-        transformer_dropout=transformer_dropout,
-        name='hidden_visit_embeddings'
-    )
-
-    # (batch_size, num_of_visits, vocab_size)
-    visit_embeddings_without_att, _, = visit_phenotype_layer(
-        [visit_embeddings_without_att, visit_mask, embedding_matrix]
-    )
-
-    # # Step 3 decoder applied to patient level
-    # Reshape the data in visit view back to patient view:
-    # (batch, num_of_visits * num_of_concepts, embedding_size)
-    concept_embeddings = tf.reshape(
-        concept_embeddings,
-        shape=(-1, num_of_visits * num_of_concepts, embedding_size)
-    )
-
-    # Let local concept embeddings access the global representatives of each visit
-    global_concept_embeddings_layer = SimpleDecoderLayer(
-        d_model=embedding_size,
-        num_heads=num_heads,
-        rate=transformer_dropout,
-        dff=512,
-        name='global_concept_embeddings_layer'
-    )
-
-    global_concept_embeddings, _ = global_concept_embeddings_layer(
-        concept_embeddings,
-        visit_embeddings_without_att,
-        look_ahead_concept_mask
-    )
-
-    concept_output_layer = TiedOutputEmbedding(
-        projection_regularizer=l2_regularizer,
-        projection_dropout=embedding_dropout,
-        name='concept_prediction_logits')
-
-    concept_softmax_layer = tf.keras.layers.Softmax(
-        name='concept_predictions'
-    )
-
-    concept_predictions = concept_softmax_layer(
-        concept_output_layer([global_concept_embeddings, embedding_matrix])
-    )
-
-    outputs = [concept_predictions]
-
-    if include_att_prediction:
-        # Extract the ATT embeddings
-        contextualized_att_embeddings = tf.reshape(
-            expanded_contextualized_visit_embeddings, (-1, num_of_visits, 3 * embedding_size)
-        )[:, :-1, embedding_size * 2:]
-
-        # Create the att to concept mask ATT tokens only attend to the concepts in the
-        # neighboring visits
-        att_concept_mask = create_att_concept_mask(
-            num_of_concepts,
-            num_of_visits,
-            visit_mask
-        )
-
-        # Use the simple decoder layer to decode att embeddings using the neighboring concept
-        # embeddings
-        global_att_embeddings_layer = SimpleDecoderLayer(
-            d_model=embedding_size,
-            num_heads=num_heads,
-            rate=transformer_dropout,
-            dff=512,
-            name='global_att_embeddings_layer'
-        )
-
-        contextualized_att_embeddings, _ = global_att_embeddings_layer(
-            contextualized_att_embeddings,
-            concept_embeddings,
-            att_concept_mask
-        )
-
-        att_prediction_layer = tf.keras.layers.Softmax(
-            name='att_predictions',
-        )
-
-        att_predictions = att_prediction_layer(
-            concept_output_layer([contextualized_att_embeddings, embedding_matrix])
-        )
-        outputs.append(att_predictions)
-
-    if include_visit_prediction:
-        # Slice out the visit embeddings (CLS tokens)
-
-        visit_type_prediction_output_layer = TiedOutputEmbedding(
-            projection_regularizer=l2_regularizer,
-            projection_dropout=embedding_dropout,
-            name='visit_type_prediction_logits'
-        )
-
-        visit_softmax_layer = tf.keras.layers.Softmax(
-            name='visit_predictions'
-        )
-
-        visit_predictions = visit_softmax_layer(
-            visit_type_prediction_output_layer(
-                [visit_embeddings_without_att, visit_type_embedding_matrix]
-            )
-        )
-
-        outputs.append(visit_predictions)
-
-    if include_readmission:
-        is_readmission_layer = tf.keras.layers.Dense(
-            1,
-            activation='sigmoid',
-            name='is_readmission'
-        )
-
-        is_readmission_output = is_readmission_layer(
-            visit_embeddings_without_att
-        )
-
-        outputs.append(is_readmission_output)
-
-    if include_prolonged_length_stay:
-        visit_prolonged_stay_layer = tf.keras.layers.Dense(
-            1,
-            activation='sigmoid',
-            name='visit_prolonged_stay'
-        )
-
-        visit_prolonged_stay_output = visit_prolonged_stay_layer(
-            visit_embeddings_without_att
-        )
-
-        outputs.append(visit_prolonged_stay_output)
-
-    hierarchical_bert = tf.keras.Model(
-        inputs=default_inputs,
-        outputs=outputs
-    )
-
-    return hierarchical_bert

From b79abd96ca9f2f284033b58c9d9b924acd5d52f5 Mon Sep 17 00:00:00 2001
From: Chao Pang
Date: Thu, 5 Sep 2024 15:07:56 -0400
Subject: [PATCH 4/5] Removed unused dependencies from the project and updated README.md

---
 README.md      | 3 ++-
 pyproject.toml | 6 +-----
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index f103fcf3..4ad98d22 100644
--- a/README.md
+++ b/README.md
@@ -73,7 +73,8 @@ tar -xvf omop_synthea.tar .
 ```
 Convert the OMOP dataset to the MEDS format
 ```console
-meds_etl_omop omop_synthea synthea_meds
+pip install meds_etl==0.3.6;
+meds_etl_omop omop_synthea synthea_meds;
 ```
 Convert MEDS to the meds_reader database to get the patient level data
 ```console
diff --git a/pyproject.toml b/pyproject.toml
index 915a0683..45c9f68b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,9 +26,6 @@ dependencies = [
     "dask==2024.1.1",
     "dask[dataframe]==2024.1.1",
     "datasets==2.16.1",
-    "docarray==0.40.0",
-    "docarray[hnswlib]==0.40.0",
-    "docarray[weaviate]==0.40.0",
     "evaluate==0.4.1",
     "fast-ml==3.68",
     "fastparquet==0.8.1",
@@ -49,7 +46,7 @@ dependencies = [
     "PyYAML==6.0.1",
     "scikit-learn==1.4.0",
     "scipy==1.12.0",
-    "tensorflow==2.15.0",
+    "tensorflow==2.12.0",
     "tensorflow-metal==1.1.0; sys_platform == 'darwin'", # macOS only
     "tensorflow-datasets==4.5.2",
     "tqdm==4.66.1",
@@ -58,7 +55,6 @@ dependencies = [
     "transformers==4.39.3",
     "Werkzeug==3.0.1",
     "wandb==0.17.8",
-    "Whoosh==2.7.4",
     "xgboost==2.0.3"
 ]

From eac75868fc35d43223673492493e6cdb884d9a70 Mon Sep 17 00:00:00 2001
From: Chao Pang
Date: Thu, 5 Sep 2024 15:10:13 -0400
Subject: [PATCH 5/5] Try to fix the tensorflow reshaping error

---
 pyproject.toml                              |  2 +-
 src/cehrbert/models/layers/custom_layers.py | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 45c9f68b..d49e3e7c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,7 +46,7 @@ dependencies = [
     "PyYAML==6.0.1",
     "scikit-learn==1.4.0",
     "scipy==1.12.0",
-    "tensorflow==2.12.0",
+    "tensorflow==2.15.0",
     "tensorflow-metal==1.1.0; sys_platform == 'darwin'", # macOS only
     "tensorflow-datasets==4.5.2",
     "tqdm==4.66.1",
diff --git a/src/cehrbert/models/layers/custom_layers.py b/src/cehrbert/models/layers/custom_layers.py
index a0bc19df..b3ded591 100644
--- a/src/cehrbert/models/layers/custom_layers.py
+++ b/src/cehrbert/models/layers/custom_layers.py
@@ -211,11 +211,11 @@ def get_config(self):

     def call(self, x, enc_output, decoder_mask, encoder_mask, **kwargs):
         # The reason we are doing this is that tensorflow on Mac doesn't seem to recognize the rank correctly
-        if platform.system() == 'Darwin':
-            batch, length = tf.shape(x)[0], tf.shape(x)[1]
-            x = tf.reshape(x, (batch, -1, self.d_model))
-            decoder_mask = tf.reshape(decoder_mask, (batch, -1, length))
-            encoder_mask = tf.reshape(encoder_mask, (batch, -1, length))
+        # if platform.system() == 'Darwin':
+        batch, length = tf.shape(x)[0], tf.shape(x)[1]
+        x = tf.reshape(x, (batch, -1, self.d_model))
+        decoder_mask = tf.reshape(decoder_mask, (batch, -1, length))
+        encoder_mask = tf.reshape(encoder_mask, (batch, -1, length))

         # enc_output.shape == (batch_size, input_seq_len, d_model)
         attn1, attn_weights_block1 = self.mha1(