# CEHR-BERT

CEHR-BERT is a large language model developed for structured EHR data; the work has been published at https://proceedings.mlr.press/v158/pang21a.html. CEHR-BERT currently supports only structured EHR data in the OMOP format, a common data model used to support observational studies and managed by the Observational Health Data Science and Informatics (OHDSI) open-science community.

There are three major components in CEHR-BERT: data generation, model pre-training, and model evaluation with fine-tuning. These components work in conjunction to provide an end-to-end model evaluation framework. The CEHR-BERT framework is designed to be extensible: users can write their own [pretraining models](trainers/README.md), [evaluation procedures](evaluations/README.md), and [downstream prediction tasks](spark_apps/README.md) by extending the abstract classes; see the links for more details. For a quick start, navigate to the [Get Started](#getting-started) section.
## Patient Representation

For each patient, all medical codes are aggregated and constructed into a sequence chronologically. To incorporate temporal information, we insert an artificial time token (ATT) between two neighboring visits based on their time interval: 1) if less than 28 days, ATTs take the form $W_n$, where n represents the week number ranging from 0-3 (e.g. $W_1$); 2) if between 28 days and 365 days, ATTs take the form $M_n$, where n represents the month number ranging from 1-11 (e.g. $M_{11}$); 3) beyond 365 days, an **LT** (Long Term) token is inserted. In addition, we add two more special tokens, **VS** and **VE**, to represent the start and the end of a visit and to explicitly define the visit segment, where all the concepts associated with the visit are subsumed by **VS** and **VE**.
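As a minimal sketch of this interval-to-ATT mapping (the function name and the exact month rounding below are illustrative assumptions, not the project's actual tokenizer):

```python
def att_token(days_between_visits: int) -> str:
    """Map the interval between two visits to an artificial time token (ATT).

    Illustrative sketch of the rules above; the exact rounding used by the
    project's tokenizer may differ.
    """
    if days_between_visits < 28:
        return f"W{days_between_visits // 7}"  # week number, 0-3
    if days_between_visits < 365:
        # month number, 1-11 (the rounding here is an assumption)
        return f"M{min(11, max(1, days_between_visits // 30))}"
    return "LT"  # Long Term token beyond one year


assert att_token(10) == "W1"
assert att_token(45) == "M1"
assert att_token(400) == "LT"
```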
!["patient_representation"](images/tokenization_att_generation.png)
## Model Architecture

Overview of our BERT architecture on structured EHR data: to distinguish visit boundaries, visit segment embeddings are added to concept embeddings. Next, both visit embeddings and concept embeddings go through a temporal transformation, where concept, age, and time embeddings are concatenated together. The concatenated embeddings are then fed into a fully connected layer. This temporal concept embedding becomes the input to BERT. We use the BERT Masked Language Model as the primary learning objective and introduce an EHR-specific secondary learning objective, visit type prediction.
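A minimal sketch of this temporal transformation (class and parameter names are illustrative assumptions, not the project's actual modules):

```python
import torch
import torch.nn as nn


class TemporalConceptEmbedding(nn.Module):
    """Illustrative sketch of the temporal transformation described above."""

    def __init__(self, vocab_size: int, embed_dim: int,
                 max_age: int = 120, num_time_buckets: int = 2048):
        super().__init__()
        self.concept = nn.Embedding(vocab_size, embed_dim)
        self.visit_segment = nn.Embedding(3, embed_dim)  # distinguishes visit boundaries
        self.age = nn.Embedding(max_age, embed_dim)
        self.time = nn.Embedding(num_time_buckets, embed_dim)
        self.fc = nn.Linear(3 * embed_dim, embed_dim)  # fully connected layer

    def forward(self, concept_ids, visit_segments, ages, time_buckets):
        # Visit segment embeddings are added to concept embeddings.
        x = self.concept(concept_ids) + self.visit_segment(visit_segments)
        # Concept, age, and time embeddings are concatenated together...
        z = torch.cat([x, self.age(ages), self.time(time_buckets)], dim=-1)
        # ...then fed into a fully connected layer; the output becomes the BERT input.
        return self.fc(z)
```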
!["cehr-bert architecture diagram"](images/cehr_bert_architecture.png)
## Pre-requisite

The project is built with Python 3.10, and the project dependencies need to be installed.

Create a new Python virtual environment:

```console
python3.10 -m venv .venv;
source .venv/bin/activate;
```
Build the project:

```console
pip install -e .[dev]
```

Download [jtds-1.3.1.jar](jtds-1.3.1.jar) into the Spark jars folder in the Python environment:

```console
cp jtds-1.3.1.jar .venv/lib/python3.10/site-packages/pyspark/jars/
```
## Instructions for Use with [MEDS](https://github.com/Medical-Event-Data-Standard/meds)

### 1. Convert MEDS to the [meds_reader](https://github.com/som-shahlab/meds_reader) database

If you don't have a MEDS dataset, you can convert an OMOP dataset to MEDS using [meds_etl](https://github.com/Medical-Event-Data-Standard/meds_etl). We have prepared a Synthea dataset with 1M patients for you to test; you can download it at [omop_synthea.tar.gz](https://drive.google.com/file/d/1k7-cZACaDNw8A1JRI37mfMAhEErxKaQJ/view?usp=share_link).

```console
tar -xvf omop_synthea.tar -C .
```

Convert the OMOP dataset to the MEDS format:

```console
meds_etl_omop omop_synthea synthea_meds
```

Convert MEDS to the meds_reader database to get the patient-level data:

```console
meds_reader_convert synthea_meds synthea_meds_reader --num_threads 4
```
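To sanity-check the converted database, here is a minimal sketch using the meds_reader Python API (this assumes the `SubjectDatabase` interface described in the meds_reader README; adjust to the version you installed):

```python
import meds_reader

# Open the converted database (the path produced by meds_reader_convert above).
database = meds_reader.SubjectDatabase("synthea_meds_reader")
print(f"number of subjects: {len(database)}")

# Print the first few events of one subject.
subject_id = next(iter(database))
subject = database[subject_id]
for i, event in enumerate(subject.events):
    print(event.time, event.code)
    if i >= 4:
        break
```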
### 2. Pretrain CEHR-BERT using the meds_reader database

```console
mkdir test_dataset_prepared;
mkdir test_synthea_results;
python -m cehrbert.runners.hf_cehrbert_pretrain_runner sample_configs/hf_cehrbert_pretrain_runner_meds_config.yaml
```

The two `mkdir` commands create the folders referenced by `dataset_prepared_path` and `output_dir` in the sample config.
## Instructions for Use with OMOP

### 1. Download OMOP tables as parquet files

We created a Spark app to download OMOP tables from SQL Server as parquet files. You need to adjust the properties in `db_properties.ini` to match your database setup.

```console
PYTHONPATH=./: spark-submit tools/download_omop_tables.py -c db_properties.ini -tc person visit_occurrence condition_occurrence procedure_occurrence drug_exposure measurement observation_period concept concept_relationship concept_ancestor -o ~/Documents/omop_test/
```
We have prepared a Synthea dataset with 1M patients for you to test; you can download it at [omop_synthea.tar.gz](https://drive.google.com/file/d/1k7-cZACaDNw8A1JRI37mfMAhEErxKaQJ/view?usp=share_link).

```console
tar -xvf omop_synthea.tar -C ~/Documents/omop_test/
```
### 2. Generate training data for CEHR-BERT

We order the patient events chronologically and put all data points in a sequence. We insert the artificial tokens VS (visit start) and VE (visit end) at the start and the end of each visit. In addition, we insert artificial time tokens (ATT) between visits to indicate the time interval between them. This approach allows us to apply BERT to structured EHR as-is. The sequence can be seen conceptually as [VS] [V1] [VE] [ATT] [VS] [V2] [VE], where [V1] and [V2] represent the lists of concepts associated with those visits.
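For illustration, a hypothetical two-visit patient might be serialized as follows (the concept IDs are example values, not a real patient record):

```python
# Hypothetical serialized sequence for two visits 45 days apart
# (concept IDs are example values only):
tokens = [
    "VS", "320128", "201826", "VE",  # visit 1 and its concepts
    "M1",                            # ATT: 45 days falls in month 1
    "VS", "9201", "VE",              # visit 2 and its concept
]
```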
```console
PYTHONPATH=./: spark-submit spark_apps/generate_training_data.py -i ~/Documents/omop_test/ -o ~/Documents/omop_test/cehr-bert -tc condition_occurrence procedure_occurrence drug_exposure -d 1985-01-01 --is_new_patient_representation -iv
```
### 3. Pre-train CEHR-BERT

If you don't have your own OMOP instance, we have provided a sample of patient sequence data generated using Synthea at `sample/patient_sequence` in the repo. CEHR-BERT expects the data folder to be named `patient_sequence`.

```console
mkdir test_dataset_prepared;
mkdir test_results;
python -m cehrbert.runners.hf_cehrbert_pretrain_runner sample_configs/hf_cehrbert_pretrain_runner_config.yaml
```

If your dataset is large, you can add `--use_dask` to the command above.
### 4. Generate hf readmission prediction task

If you don't have your own OMOP instance, we have provided a sample of patient sequence data generated using Synthea at `sample/hf_readmissioon` in the repo.

```console
PYTHONPATH=./:$PYTHONPATH spark-submit spark_apps/prediction_cohorts/hf_readmission.py -c hf_readmission -i ~/Documents/omop_test/ -o ~/Documents/omop_test/cehr-bert -dl 1985-01-01 -du 2020-12-31 -l 18 -u 100 -ow 360 -ps 0 -pw 30 --is_new_patient_representation
```
### 5. Fine-tune CEHR-BERT

```console
mkdir test_finetune_results;
python -m cehrbert.runners.hf_cehrbert_finetuning_runner sample_configs/hf_cehrbert_finetuning_runner_config.yaml
```
## Contact us

If you have any questions, feel free to contact us at [email protected]

## Citation

Please acknowledge the following work in papers:

Chao Pang, Xinzhuo Jiang, Krishna S. Kalluri, Matthew Spotnitz, RuiJun Chen, Adler Perotte, and Karthik Natarajan. "CEHR-BERT: Incorporating temporal information from structured EHR data to improve prediction tasks." In Proceedings of Machine Learning for Health, volume 158 of Proceedings of Machine Learning Research, pages 239–260. PMLR, 04 Dec 2021.
## sample_configs/hf_cehrbert_pretrain_runner_meds_config.yaml
```yaml
model_name_or_path: "test_synthea_results"
tokenizer_name_or_path: "test_synthea_results"

data_folder: "synthea_meds_reader"
dataset_prepared_path: "test_dataset_prepared"
validation_split_percentage: 0.05
validation_split_num: 1000
preprocessing_num_workers: 4
preprocessing_batch_size: 1000
streaming: false

# Tokenizer
vocab_size: 50000
min_frequency: 0

# Below is a list of Med-to-CehrBert related arguments
att_function_type: "cehrbert"
is_data_in_med: true
inpatient_att_function_type: "mix"
include_auxiliary_token: true
include_demographic_prompt: false

do_train: true
overwrite_output_dir: false
resume_from_checkpoint: # path to the checkpoint folder
seed: 42

num_hidden_layers: 6
max_position_embeddings: 512

# torch dataloader configs
dataloader_num_workers: 4
dataloader_prefetch_factor: 4

output_dir: "test_synthea_results"
evaluation_strategy: "epoch"
save_strategy: "epoch"
learning_rate: 0.00005
per_device_train_batch_size: 4
per_device_eval_batch_size: 4
gradient_accumulation_steps: 1
num_train_epochs: 1
# When streaming is set to true, max_steps needs to be provided
max_steps: 100

warmup_steps: 500
weight_decay: 0.01
logging_dir: "./logs"
logging_steps: 100
save_total_limit:
load_best_model_at_end: true
metric_for_best_model: "eval_loss"
greater_is_better: false
```