From 9adabc6b66f5589fb82d758aaf2c8a19e589e353 Mon Sep 17 00:00:00 2001 From: lars Date: Mon, 3 Jun 2024 16:21:57 +0200 Subject: [PATCH] clean up and add conditions=True case --- tests/test_networks/conftest.py | 2 +- .../test_networks/test_inference_networks.py | 20 +++++-------------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index dd0727b5..06470102 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -40,7 +40,7 @@ def num_features(request): return request.param -@pytest.fixture(params=[False]) +@pytest.fixture(params=[True, False]) def random_conditions(request, batch_size, num_conditions): if not request.param: return None diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py index 330a65a1..a615fb93 100644 --- a/tests/test_networks/test_inference_networks.py +++ b/tests/test_networks/test_inference_networks.py @@ -27,7 +27,11 @@ def test_variable_batch_size(inference_network, random_samples, random_condition batch_sizes = np.random.choice(10, replace=False, size=3) for batch_size in batch_sizes: new_input = keras.ops.zeros((batch_size,) + keras.ops.shape(random_samples)[1:]) - new_conditions = None if random_conditions is None else keras.ops.zeros((batch_size,) + keras.ops.shape(random_conditions)[1:]) + if random_conditions is None: + new_conditions = None + else: + new_conditions = keras.ops.zeros((batch_size,), + keras.ops.shape(random_conditions)[1:]) + inference_network(new_input) inference_network(new_input, conditions=new_conditions, inverse=True) @@ -107,18 +111,4 @@ def test_serialize_deserialize(tmp_path, inference_network, random_samples, rand keras.saving.save_model(inference_network, tmp_path / "model.keras") loaded = keras.saving.load_model(tmp_path / "model.keras") - print(f"{inference_network._layers=}") - print(f"{loaded._layers=}") - print() - dual_coupling1 = inference_network._layers[1] - dual_coupling2 = loaded._layers[1] - print(f"{dual_coupling1.pivot=}") - print(f"{dual_coupling2.pivot=}") - print() - print(f"{dual_coupling1.coupling1.variables=}") - print(f"{dual_coupling1.coupling2.variables=}") - print() - print(f"{dual_coupling2.coupling1.variables=}") - print(f"{dual_coupling2.coupling2.variables=}") - assert_models_equal(inference_network, loaded)