Skip to content

Commit

Permalink
fix tests for preliminary end-to-end version
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Jun 7, 2024
1 parent 3f8a9d5 commit e105d63
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 70 deletions.
49 changes: 0 additions & 49 deletions tests/conftest.py

This file was deleted.

22 changes: 10 additions & 12 deletions tests/test_two_moons/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import keras
import pytest

import bayesflow.experimental as bf


@pytest.fixture()
def batch_size():
Expand All @@ -19,33 +17,33 @@ def sample(self, batch_shape):
alpha = keras.random.uniform(shape=batch_shape + (1,), minval=-0.5 * math.pi, maxval=0.5 * math.pi)
theta = keras.random.uniform(shape=batch_shape + (2,), minval=-1.0, maxval=1.0)

x1 = -keras.ops.abs(theta[0] + theta[1]) / keras.ops.sqrt(2.0) + r * keras.ops.cos(alpha) + 0.25
x2 = (-theta[0] + theta[1]) / keras.ops.sqrt(2.0) + r * keras.ops.sin(alpha)
x1 = -keras.ops.abs(theta[..., :1] + theta[..., 1:]) / keras.ops.sqrt(2.0) + r * keras.ops.cos(alpha) + 0.25
x2 = (-theta[..., :1] + theta[..., 1:]) / keras.ops.sqrt(2.0) + r * keras.ops.sin(alpha)

x = keras.ops.stack([x1, x2], axis=-1)
x = keras.ops.concatenate([x1, x2], axis=-1)

return dict(r=r, alpha=alpha, theta=theta, x=x)

return Simulator()


@pytest.fixture()
def dataset(joint_distribution):
return bf.datasets.OnlineDataset(joint_distribution, workers=4, use_multiprocessing=True, max_queue_size=16, batch_size=16)
def dataset(simulator):
from bayesflow.experimental.datasets import OnlineDataset
return OnlineDataset(simulator, workers=4, max_queue_size=16, batch_size=16)


@pytest.fixture()
def inference_network():
return bf.networks.CouplingFlow()
from bayesflow.experimental.networks import CouplingFlow
return CouplingFlow()


@pytest.fixture()
def approximator(inference_network):
return bf.Approximator(
from bayesflow.experimental.backend_approximators import Approximator
return Approximator(
inference_network=inference_network,
inference_variables=["theta"],
inference_conditions=["x", "r", "alpha"],
summary_network=None,
summary_variables=[],
summary_conditions=[],
)
12 changes: 3 additions & 9 deletions tests/test_two_moons/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,10 @@
from tests.utils import InterruptFitCallback, FitInterruptedError


@pytest.mark.skip(reason="not implemented")
def test_compile(amortizer):
amortizer.compile(optimizer="AdamW")


@pytest.mark.skip(reason="not implemented")
def test_fit(amortizer, dataset):
def test_fit(approximator, dataset):
# TODO: verify the model learns something by comparing a metric before and after training
amortizer.compile(optimizer="AdamW")
amortizer.fit(dataset, epochs=10, steps_per_epoch=10)
approximator.compile(optimizer="AdamW")
approximator.fit(dataset, epochs=10, steps_per_epoch=10)


@pytest.mark.skip(reason="not implemented")
Expand Down

0 comments on commit e105d63

Please sign in to comment.