Skip to content

Commit

Permalink
Merge branch 'streamlined-backend' of https://github.com/stefanradev9…
Browse files Browse the repository at this point in the history
…3/BayesFlow into streamlined-backend
  • Loading branch information
stefanradev93 committed Jun 3, 2024
2 parents de0b91b + 1111802 commit 76c225f
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 16 deletions.
5 changes: 5 additions & 0 deletions tests/test_amortizers/test_fit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@

import pytest

@pytest.mark.skip(reason="not implemented")
def test_compile(amortizer):
amortizer.compile(optimizer="AdamW")


@pytest.mark.skip(reason="not implemented")
def test_fit(amortizer, dataset):
amortizer.compile(optimizer="AdamW")
amortizer.fit(dataset)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_networks/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def num_features(request):
return request.param


@pytest.fixture(params=[False])
@pytest.fixture(params=[True, False])
def random_conditions(request, batch_size, num_conditions):
if not request.param:
return None
Expand Down
20 changes: 5 additions & 15 deletions tests/test_networks/test_inference_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ def test_variable_batch_size(inference_network, random_samples, random_condition
batch_sizes = np.random.choice(10, replace=False, size=3)
for batch_size in batch_sizes:
new_input = keras.ops.zeros((batch_size,) + keras.ops.shape(random_samples)[1:])
new_conditions = None if random_conditions is None else keras.ops.zeros((batch_size,) + keras.ops.shape(random_conditions)[1:])
if random_conditions is None:
new_conditions = None
else:
new_conditions = keras.ops.zeros((batch_size,), + keras.ops.shape(random_conditions)[1:])

inference_network(new_input)
inference_network(new_input, conditions=new_conditions, inverse=True)

Expand Down Expand Up @@ -107,18 +111,4 @@ def test_serialize_deserialize(tmp_path, inference_network, random_samples, rand
keras.saving.save_model(inference_network, tmp_path / "model.keras")
loaded = keras.saving.load_model(tmp_path / "model.keras")

print(f"{inference_network._layers=}")
print(f"{loaded._layers=}")
print()
dual_coupling1 = inference_network._layers[1]
dual_coupling2 = loaded._layers[1]
print(f"{dual_coupling1.pivot=}")
print(f"{dual_coupling2.pivot=}")
print()
print(f"{dual_coupling1.coupling1.variables=}")
print(f"{dual_coupling1.coupling2.variables=}")
print()
print(f"{dual_coupling2.coupling1.variables=}")
print(f"{dual_coupling2.coupling2.variables=}")

assert_models_equal(inference_network, loaded)
3 changes: 3 additions & 0 deletions tests/test_two_moons/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@
from tests.utils import InterruptFitCallback, FitInterruptedError


@pytest.mark.skip(reason="not implemented")
def test_compile(amortizer):
amortizer.compile(optimizer="AdamW")


@pytest.mark.skip(reason="not implemented")
def test_fit(amortizer, dataset):
# TODO: verify the model learns something by comparing a metric before and after training
amortizer.compile(optimizer="AdamW")
amortizer.fit(dataset, epochs=10, steps_per_epoch=10, batch_size=32)


@pytest.mark.skip(reason="not implemented")
def test_interrupt_and_resume_fit(tmp_path, amortizer, dataset):
# TODO: test the InterruptFitCallback
amortizer.compile(optimizer="AdamW")
Expand Down
2 changes: 2 additions & 0 deletions tests/test_two_moons/test_saving.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@

import keras
import pytest

from tests.utils import assert_layers_equal


@pytest.mark.skip(reason="not implemented")
def test_save_and_load(tmp_path, amortizer):
amortizer.save(tmp_path / "amortizer.keras")
loaded_amortizer = keras.saving.load_model(tmp_path / "amortizer.keras")
Expand Down

0 comments on commit 76c225f

Please sign in to comment.