From ae7f8c4444c1ca599614b858462be18a629a37e0 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Sun, 13 Aug 2023 16:48:22 +0200 Subject: [PATCH] Add option to aggregate outputs of c2st --- bayesflow/computational_utilities.py | 46 ++++++++++++++++++---------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/bayesflow/computational_utilities.py b/bayesflow/computational_utilities.py index 451214f8..b047edee 100644 --- a/bayesflow/computational_utilities.py +++ b/bayesflow/computational_utilities.py @@ -22,8 +22,8 @@ import tensorflow as tf from scipy import stats from sklearn.calibration import calibration_curve +from sklearn.model_selection import KFold, cross_val_score from sklearn.neural_network import MLPClassifier -from sklearn.model_selection import cross_val_score, KFold from bayesflow.default_settings import MMD_BANDWIDTH_LIST from bayesflow.exceptions import ShapeError @@ -521,30 +521,41 @@ def aggregated_rmse(x_true, x_pred): ) -def c2st(source_samples, target_samples, n_folds=5, scoring="accuracy", normalize=True, seed=123, - hidden_units_per_dim=10): - """C2ST metric [1] using an sklearn MLP classifier. +def c2st( + source_samples, + target_samples, + n_folds=5, + scoring="accuracy", + normalize=True, + seed=123, + hidden_units_per_dim=16, + aggregate_output=True, +): + """C2ST metric [1] using an sklearn neural network classifier (i.e., MLP). Code adapted from https://github.com/sbi-benchmark/sbibm/blob/main/sbibm/metrics/c2st.py [1] Lopez-Paz, D., & Oquab, M. (2016). Revisiting classifier two-sample tests. arXiv:1610.06545. Parameters ---------- - source_samples : np.ndarray or tf.Tensor + source_samples : np.ndarray or tf.Tensor Source samples (e.g., approximate posterior samples) - target_samples : np.ndarray or tf.Tensor + target_samples : np.ndarray or tf.Tensor Target samples (e.g., samples from a reference posterior) - n_folds : int, optional, default: 5 + n_folds : int, optional, default: 5 Number of folds in k-fold cross-validation for the classifier evaluation - scoring : str, optional, default: "accuracy" + scoring : str, optional, default: "accuracy" Evaluation score of the sklearn MLP classifier - normalize : bool, optional, default: True + normalize : bool, optional, default: True Whether the data shall be z-standardized relative to source_samples - seed : int, optional, default: 123 + seed : int, optional, default: 123 RNG seed for the MLP and k-fold CV - hidden_units_per_dim : int, optional, default: 10 + hidden_units_per_dim : int, optional, default: 16 Number of hidden units in the MLP, relative to the input dimensions. - Example: source samples are 5D, hidden_units_per_dim=10 -> 50 hidden units per layer + Example: source samples are 5D, hidden_units_per_dim=16 -> 80 hidden units per layer + aggregate_output : bool, optional, default: True + Whether to return a single value aggregated over all cross-validation runs + or all values from all runs. If left at default, the empirical mean will be returned Returns ------- @@ -558,9 +569,11 @@ def c2st(source_samples, target_samples, n_folds=5, scoring="accuracy", normaliz num_dims = x.shape[1] if not num_dims == y.shape[1]: - raise ShapeError(f"source_samples and target_samples can have different number of observations (1st dim)" - f"but must have the same dimensionality (2nd dim)" - f"found: source_samples {source_samples.shape[1]}, target_samples {target_samples.shape[1]}") + raise ShapeError( + f"source_samples and target_samples can have different number of observations (1st dim)" + f"but must have the same dimensionality (2nd dim)" + f"found: source_samples {source_samples.shape[1]}, target_samples {target_samples.shape[1]}" + ) if normalize: x_mean = np.mean(x, axis=0) @@ -587,5 +600,6 @@ def c2st(source_samples, target_samples, n_folds=5, scoring="accuracy", normaliz shuffle = KFold(n_splits=n_folds, shuffle=True, random_state=seed) scores = cross_val_score(clf, data, target, cv=shuffle, scoring=scoring) - c2st_score = np.asarray(np.mean(scores)).astype(np.float32) + if aggregate_output: + c2st_score = np.asarray(np.mean(scores)).astype(np.float32) return c2st_score