Skip to content

Commit

Permalink
remove incorrect KL Divergence test, improve metric testing
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Aug 22, 2024
1 parent b083364 commit 4647fec
Showing 1 changed file with 6 additions and 17 deletions.
23 changes: 6 additions & 17 deletions tests/test_two_moons/test_two_moons.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_compile(approximator, random_samples, jit_compile):
def test_fit(approximator, train_dataset, validation_dataset, batch_size):
from bayesflow.metrics import MaximumMeanDiscrepancy

approximator.compile(inference_metrics=[keras.metrics.KLDivergence(), MaximumMeanDiscrepancy()])
approximator.compile(inference_metrics=[MaximumMeanDiscrepancy()])

mock_data = train_dataset[0]
mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
Expand All @@ -36,22 +36,11 @@ def test_fit(approximator, train_dataset, validation_dataset, batch_size):
assert isinstance(untrained_metrics, dict)
assert isinstance(trained_metrics, dict)

# test loss decreases
assert "loss" in untrained_metrics
assert "loss" in trained_metrics
assert untrained_metrics["loss"] > trained_metrics["loss"]

# test kl divergence decreases
assert "inference/kl_divergence" in untrained_metrics
assert "inference/kl_divergence" in trained_metrics
assert untrained_metrics["inference/kl_divergence"] > trained_metrics["inference/kl_divergence"]

# test mmd decreases
assert "inference/maximum_mean_discrepancy" in untrained_metrics
assert "inference/maximum_mean_discrepancy" in trained_metrics
assert (
untrained_metrics["inference/maximum_mean_discrepancy"] > trained_metrics["inference/maximum_mean_discrepancy"]
)
# test that metrics are improving
for metric in ["loss", "maximum_mean_discrepancy/inference_maximum_mean_discrepancy"]:
assert metric in untrained_metrics
assert metric in trained_metrics
assert trained_metrics[metric] <= untrained_metrics[metric]


@pytest.mark.parametrize("jit_compile", [False, True])
Expand Down

0 comments on commit 4647fec

Please sign in to comment.