# config_sepsis.yaml
##########################################################################################################
# Global PARAMETERS

# NOTE(review): presumably the compute device string used by the training code -- confirm.
device: 'cpu'

# Seed shared by the numpy and pytorch random number generators.
random_seed: 2345

# Number of experiments, executed serially. It is better to launch different
# experiments as separate runs with various random seeds than to raise this value.
num_experiments: 1

# NOTE(review): presumably results are written under <folder_location>/<folder_name> -- confirm.
folder_location: 'results/'
folder_name: 'run1'

# Data files -- produced by first running `data_process.py`, after copying in the
# MIMIC data extracted with the `microsoft/mimic_sepsis` repo.
train_data_file: 'data/sepsis_mimiciii/sepsis_final_data_K1_train.csv'
validation_data_file: 'data/sepsis_mimiciii/sepsis_final_data_K1_validation.csv'
test_data_file: 'data/sepsis_mimiciii/sepsis_final_data_K1_test.csv'
##########################################################################################################
# Experiment PARAMETERS

# Number of RL training epochs.
exp_num_epochs: 100

# The frequency with which checkpoints are saved.
exp_saving_period: 1

# Whether to resume a previously started experiment: resumes training of the
# Q_D and Q_R networks.
# `sc_resume` (resuming the State Construction network) is intentionally NOT set
# here: it is defined once in the SC-Network section below. Declaring it in both
# places creates a duplicate mapping key, which strict YAML loaders reject and
# lenient ones resolve silently by keeping only the last occurrence.
rl_resume: false

# Minibatch size for dataloading of constructed states when training the Q-networks.
minibatch_size: 64

# Constrain the values of the Q-function according to the MDP design.
# Can be 'positive' (R-Network) or 'negative' (D-Network).
sided_Q: 'negative'
##########################################################################################################
# SC-Network PARAMETERS (Training the SC-Network)

# State construction method: 'AIS' or 'Observations'.
sc_method: 'AIS'

# Resumes State Constructor training from checkpoint.pt.
sc_resume: false

# Number of epochs used for training the SC-Network. Default 200.
sc_num_epochs: 400

# Dimension of the embedded state from the SC-Network. Default 64.
embed_state_dim: 64

# Learning rate. Default 0.0001.
sc_learning_rate: 0.0001

# Number of feature dimensions extracted from the MIMIC EMR. Default 44.
obs_dim: 44

# Number of treatment options. Default 25.
num_actions: 25

# Indicator of which Generator to use for AIS encoding of observations (see `cortex.py`).
ais_gen_model: 1

# Indicator of which Predictor to use for AIS reconstruction (see `cortex.py`).
ais_pred_model: 1

# Means by which the ratio of positive or negative trajectories sampled in a
# batch can be rebalanced. 'NA' = not computed; a float value is the desired
# ratio of negative (death) trajectories in a minibatch.
sc_neg_traj_ratio: 'NA'

# The per-epoch frequency to run a validation loop and potentially save an
# improved set of model parameters.
sc_saving_period: 1

# For testing purposes: whether all checkpoints through training are saved
# (saved as `checkpointXXX.pt`).
save_all_checkpoints: false
##########################################################################################################
# RL PARAMETERS (Training the D- and R- Networks)

# Network size: can be 'small', 'large' or '2layered'.
rl_network_size: 'small'

# Must be 1 due to the dead-end theory.
gamma: 1.0

# Learning rate. Default 0.0001.
rl_learning_rate: 0.0001

# Whether to train DQN with the DDQN approach.
use_ddqn: true

# Frequency of updates to Q functions. Default 2.
update_freq: 2
##########################################################################################################