Skip to content

Commit

Permalink
Merge pull request #91 from elseml/Development
Browse files Browse the repository at this point in the history
Fixes #90 (axis labeling)
  • Loading branch information
stefanradev93 authored Jul 25, 2023
2 parents d0c9f1e + 298d4b6 commit 60cb270
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 5 additions & 3 deletions bayesflow/computational_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@ def expected_calibration_error(m_true, m_pred, num_bins=10):
# Extract number of models and prepare containers
n_models = m_true.shape[1]
cal_errs = []
probs = []
probs_true = []
probs_pred = []

# Loop for each model and compute calibration errs per bin
for k in range(n_models):
Expand All @@ -295,8 +296,9 @@ def expected_calibration_error(m_true, m_pred, num_bins=10):
cal_err = np.sum(np.abs(prob_true - prob_pred) * (bin_total[nonzero] / len(y_true)))

cal_errs.append(cal_err)
probs.append((prob_true, prob_pred))
return cal_errs, probs
probs_true.append(prob_true)
probs_pred.append(prob_pred)
return cal_errs, probs_true, probs_pred


def maximum_mean_discrepancy(source_samples, target_samples, kernel="gaussian", mmd_weight=1.0, minimum=0.0):
Expand Down
4 changes: 2 additions & 2 deletions bayesflow/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,7 +1057,7 @@ def plot_calibration_curves(
# Determine n_subplots dynamically
n_row = int(np.ceil(num_models / 6))
n_col = int(np.ceil(num_models / n_row))
cal_errs, cal_probs = expected_calibration_error(true_models, pred_models, num_bins)
cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins)

# Initialize figure
if fig_size is None:
Expand All @@ -1073,7 +1073,7 @@ def plot_calibration_curves(
ax = axarr
for j in range(num_models):
# Plot calibration curve
ax[j].plot(cal_probs[j][0], cal_probs[j][1], color=color)
ax[j].plot(probs_pred[j], probs_true[j], color=color)

# Plot AB line
ax[j].plot(ax[j].get_xlim(), ax[j].get_xlim(), "--", color="darkgrey")
Expand Down

0 comments on commit 60cb270

Please sign in to comment.