-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
108cf02
commit 0fc4f00
Showing
2 changed files
with
726 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,296 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 22, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from __future__ import print_function\n", | ||
"\n", | ||
"# Import libraries\n", | ||
"import numpy as np\n", | ||
"import pandas as pd\n", | ||
"import matplotlib\n", | ||
"import sklearn\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"from matplotlib.font_manager import FontProperties # for unicode fonts\n", | ||
"import sys\n", | ||
"import datetime as dt\n", | ||
"import mp_utils as mp\n", | ||
"\n", | ||
"from collections import OrderedDict\n", | ||
"\n", | ||
"# used to print out pretty pandas dataframes\n", | ||
"from IPython.display import display, HTML\n", | ||
"\n", | ||
"from sklearn.pipeline import Pipeline\n", | ||
"from sklearn.utils import shuffle\n", | ||
"from sklearn.model_selection import train_test_split\n", | ||
"\n", | ||
"# used to impute mean for data and standardize for computational stability\n", | ||
"from sklearn.preprocessing import Imputer\n", | ||
"from sklearn.preprocessing import StandardScaler\n", | ||
"\n", | ||
"# logistic regression is our favourite model ever\n", | ||
"from sklearn import linear_model\n", | ||
"from sklearn import ensemble\n", | ||
"\n", | ||
"# used to calculate AUROC/accuracy\n", | ||
"from sklearn import metrics\n", | ||
"\n", | ||
"# used to create confusion matrix\n", | ||
"from sklearn.metrics import confusion_matrix\n", | ||
"from sklearn.model_selection import cross_val_score\n", | ||
"\n", | ||
"# gradient boosting - must download package https://github.com/dmlc/xgboost\n", | ||
"import xgboost as xgb\n", | ||
"\n", | ||
"#from eli5 import show_weights\n", | ||
"\n", | ||
"# default colours for prettier plots\n", | ||
"col = [[0.9047, 0.1918, 0.1988],\n", | ||
" [0.2941, 0.5447, 0.7494],\n", | ||
" [0.3718, 0.7176, 0.3612],\n", | ||
" [1.0000, 0.5482, 0.1000],\n", | ||
" [0.4550, 0.4946, 0.4722],\n", | ||
" [0.6859, 0.4035, 0.2412],\n", | ||
" [0.9718, 0.5553, 0.7741],\n", | ||
" [0.5313, 0.3359, 0.6523]];\n", | ||
"marker = ['v','o','d','^','s','o','+']\n", | ||
"ls = ['-','-','-','-','-','s','--','--']\n", | ||
"\n", | ||
"%matplotlib inline" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 24, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"50488 observations. Outcome rate: 11.20%.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"df = pd.read_csv('X_design_matrix.csv').set_index('icustay_id')\n", | ||
"\n", | ||
"# create X by dropping idxK and the outcome\n", | ||
"y = df['death'].values\n", | ||
"idxK = df['idxK']\n", | ||
"X = df.drop(['death','idxK'],axis=1).values\n", | ||
"\n", | ||
"print('{} observations. Outcome rate: {:2.2f}%.'.format(X.shape[0],\n", | ||
" 100.0*np.mean(y)))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"source": [ | ||
" colsample_bytree - 0.7 \t[ 0.7 0.7 0.7 0.7 0.7 ]\n", | ||
" silent - 1.0 \t[ 1 1 1 1 1 ]\n", | ||
" learning_rate - 0.01 \t[ 0.01 0.01 0.01 0.01 0.01 ]\n", | ||
" n_estimators - 1000.0 \t[ 1000 1000 1000 1000 1000 ]\n", | ||
" subsample - 0.8 \t[ 0.8 0.8 0.8 0.8 0.8 ]\n", | ||
" objective - skipping as it is a string\n", | ||
" max_depth - 9.0 \t[ 9 6 9 9 6 ]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 26, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"=============== l2 ===============\n", | ||
"2017-07-06 08:44:36.522302 - Finished fold 1 of 5. AUROC 0.893.\n", | ||
"2017-07-06 08:44:44.736272 - Finished fold 2 of 5. AUROC 0.896.\n", | ||
"2017-07-06 08:44:51.897262 - Finished fold 3 of 5. AUROC 0.893.\n", | ||
"2017-07-06 08:44:58.932115 - Finished fold 4 of 5. AUROC 0.890.\n", | ||
"2017-07-06 08:45:06.853988 - Finished fold 5 of 5. AUROC 0.888.\n", | ||
"=============== xgb ===============\n", | ||
"2017-07-06 08:48:32.392924 - Finished fold 1 of 5. AUROC 0.924.\n", | ||
"2017-07-06 08:51:58.636281 - Finished fold 2 of 5. AUROC 0.920.\n", | ||
"2017-07-06 08:55:32.974751 - Finished fold 3 of 5. AUROC 0.921.\n", | ||
"2017-07-06 08:59:12.003701 - Finished fold 4 of 5. AUROC 0.918.\n", | ||
"2017-07-06 09:03:00.009820 - Finished fold 5 of 5. AUROC 0.919.\n", | ||
"=============== logreg ===============\n", | ||
"2017-07-06 09:03:03.046011 - Finished fold 1 of 5. AUROC 0.893.\n", | ||
"2017-07-06 09:03:06.030786 - Finished fold 2 of 5. AUROC 0.896.\n", | ||
"2017-07-06 09:03:08.907647 - Finished fold 3 of 5. AUROC 0.893.\n", | ||
"2017-07-06 09:03:11.748804 - Finished fold 4 of 5. AUROC 0.890.\n", | ||
"2017-07-06 09:03:15.194134 - Finished fold 5 of 5. AUROC 0.887.\n", | ||
"=============== lasso ===============\n", | ||
"2017-07-06 09:03:18.332521 - Finished fold 1 of 5. AUROC 0.890.\n", | ||
"2017-07-06 09:03:23.226927 - Finished fold 2 of 5. AUROC 0.894.\n", | ||
"2017-07-06 09:03:28.012883 - Finished fold 3 of 5. AUROC 0.888.\n", | ||
"2017-07-06 09:03:31.130018 - Finished fold 4 of 5. AUROC 0.885.\n", | ||
"2017-07-06 09:03:33.840082 - Finished fold 5 of 5. AUROC 0.882.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Rough timing info:\n", | ||
"# rf - 3 seconds per fold\n", | ||
"# xgb - 30 seconds per fold\n", | ||
"# logreg - 4 seconds per fold\n", | ||
"# lasso - 8 seconds per fold\n", | ||
"np.random.seed(7390984)\n", | ||
"\n", | ||
"# parameters from grid search\n", | ||
"xgb_mdl = xgb.XGBClassifier(colsample_bytree=0.7, silent=1,\n", | ||
" learning_rate = 0.01, n_estimators=1000,\n", | ||
" subsample=0.8, max_depth=9)\n", | ||
"\n", | ||
"models = {'xgb': xgb_mdl,\n", | ||
" 'lasso': linear_model.LassoCV(cv=5,fit_intercept=True,normalize=True,max_iter=10000),\n", | ||
" 'logreg': linear_model.LogisticRegression(fit_intercept=True),\n", | ||
" 'l2': linear_model.LogisticRegressionCV()\n", | ||
" #'rf': ensemble.RandomForestClassifier()\n", | ||
" }\n", | ||
"\n", | ||
"\n", | ||
"# create k-fold indices\n", | ||
"K = 5 # number of folds\n", | ||
"idxK = np.random.permutation(X.shape[0])\n", | ||
"idxK = np.mod(idxK,K)\n", | ||
"\n", | ||
"mdl_val = dict()\n", | ||
"results_val = dict()\n", | ||
"pred_val = dict()\n", | ||
"pred_val_merged = dict()\n", | ||
"for mdl in models:\n", | ||
" print('=============== {} ==============='.format(mdl))\n", | ||
" mdl_val[mdl] = list()\n", | ||
" results_val[mdl] = list() # initialize list for scores\n", | ||
" pred_val[mdl] = dict()\n", | ||
" pred_val_merged[mdl] = np.zeros(X.shape[0])\n", | ||
" \n", | ||
" if mdl == 'xgb':\n", | ||
" # no pre-processing of data necessary for xgb\n", | ||
" estimator = Pipeline([(mdl, models[mdl])])\n", | ||
"\n", | ||
" else:\n", | ||
" estimator = Pipeline([(\"imputer\", Imputer(missing_values='NaN',\n", | ||
" strategy=\"mean\",\n", | ||
" axis=0)),\n", | ||
" (\"scaler\", StandardScaler()),\n", | ||
" (mdl, models[mdl])]) \n", | ||
"\n", | ||
" for k in range(K):\n", | ||
" # train the model using all but the kth fold\n", | ||
" curr_mdl = sklearn.base.clone(estimator).fit(X[idxK != k, :], y[idxK != k])\n", | ||
"\n", | ||
" # get prediction on this dataset\n", | ||
" if mdl in ('lasso','ridge'):\n", | ||
" curr_prob = curr_mdl.predict(X[idxK == k, :])\n", | ||
" else:\n", | ||
" curr_prob = curr_mdl.predict_proba(X[idxK == k, :])\n", | ||
" curr_prob = curr_prob[:,1]\n", | ||
" \n", | ||
" pred_val_merged[mdl][idxK==k] = curr_prob\n", | ||
" pred_val[mdl][k] = curr_prob\n", | ||
"\n", | ||
" # calculate score (AUROC)\n", | ||
" curr_score = metrics.roc_auc_score(y[idxK == k], curr_prob)\n", | ||
"\n", | ||
" # add score to list of scores\n", | ||
" results_val[mdl].append(curr_score)\n", | ||
"\n", | ||
" # save the current model\n", | ||
" mdl_val[mdl].append(curr_mdl)\n", | ||
" \n", | ||
" print('{} - Finished fold {} of {}. AUROC {:0.3f}.'.format(dt.datetime.now(), k+1, K, curr_score))\n", | ||
" \n", | ||
"tar_val = dict()\n", | ||
"for k in range(K):\n", | ||
" tar_val[k] = y[idxK==k]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 27, | ||
"metadata": { | ||
"scrolled": true | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"l2\t0.892 [0.888, 0.896]\n", | ||
"xgb\t0.920 [0.918, 0.924]\n", | ||
"logreg\t0.892 [0.887, 0.896]\n", | ||
"lasso\t0.888 [0.882, 0.894]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# average AUROC + min/max\n", | ||
"for mdl in models:\n", | ||
" curr_score = np.zeros(K)\n", | ||
" for k in range(K):\n", | ||
" curr_score[k] = metrics.roc_auc_score(tar_val[k], pred_val[mdl][k])\n", | ||
" print('{}\\t{:0.3f} [{:0.3f}, {:0.3f}]'.format(mdl, np.mean(curr_score), np.min(curr_score), np.max(curr_score)))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 28, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"l2\t0.588 [0.568, 0.597]\n", | ||
"xgb\t0.665 [0.654, 0.669]\n", | ||
"logreg\t0.588 [0.568, 0.597]\n", | ||
"lasso\t0.579 [0.557, 0.593]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# average AUPRC + min/max\n", | ||
"for mdl in models:\n", | ||
" curr_score = np.zeros(K)\n", | ||
" for k in range(K):\n", | ||
" curr_score[k] = metrics.average_precision_score(tar_val[k], pred_val[mdl][k])\n", | ||
" print('{}\\t{:0.3f} [{:0.3f}, {:0.3f}]'.format(mdl, np.mean(curr_score), np.min(curr_score), np.max(curr_score)))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 2", | ||
"language": "python", | ||
"name": "python2" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.