Skip to content

Commit

Permalink
add amortizer tests (WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed May 28, 2024
1 parent e4e6462 commit 710b8ec
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
Empty file.
47 changes: 47 additions & 0 deletions tests/test_amortizers/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@

import keras
import pytest

import bayesflow.experimental as bf


@pytest.fixture()
def summary_network():
return None


@pytest.fixture()
def inference_network():
network = keras.Sequential([
keras.layers.Dense(10)
])
network.compile(loss="mse")
return network


@pytest.fixture(params=[bf.AmortizedPosterior, bf.AmortizedLikelihood])
def amortizer(request, inference_network, summary_network):
Amortizer = request.param
return Amortizer(inference_network, summary_network)


@pytest.fixture()
def dataset():
batch_size = 16
batches_per_epoch = 4
parameter_sets = batch_size * batches_per_epoch
observations_per_parameter_set = 32

mean = keras.random.normal(mean=0.0, stddev=0.1, shape=(parameter_sets, 2))
std = keras.ops.exp(keras.random.normal(mean=0.0, stddev=0.1, shape=(parameter_sets, 2)))

mean = keras.ops.repeat(mean[:, None], observations_per_parameter_set, 1)
std = keras.ops.repeat(std[:, None], observations_per_parameter_set, 1)

noise = keras.random.normal(shape=(parameter_sets, observations_per_parameter_set, 2))

x = mean + std * noise

data = dict(observables=dict(x=x), parameters=dict(mean=mean, std=std))

return bf.datasets.OfflineDataset(data, batch_size=batch_size, batches_per_epoch=batches_per_epoch)
11 changes: 11 additions & 0 deletions tests/test_amortizers/test_fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
def test_compile(amortizer):
amortizer.compile(optimizer="AdamW")


def test_fit(amortizer, dataset):
amortizer.compile(optimizer="AdamW")
amortizer.fit(dataset)

assert amortizer.losses is not None


0 comments on commit 710b8ec

Please sign in to comment.