Skip to content

Commit

Permalink
clean up and add conditions=True case
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Jun 3, 2024
1 parent 9a0fc43 commit 9adabc6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 16 deletions.
2 changes: 1 addition & 1 deletion tests/test_networks/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def num_features(request):
return request.param


@pytest.fixture(params=[False])
@pytest.fixture(params=[True, False])
def random_conditions(request, batch_size, num_conditions):
if not request.param:
return None
Expand Down
20 changes: 5 additions & 15 deletions tests/test_networks/test_inference_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ def test_variable_batch_size(inference_network, random_samples, random_condition
batch_sizes = np.random.choice(10, replace=False, size=3)
for batch_size in batch_sizes:
new_input = keras.ops.zeros((batch_size,) + keras.ops.shape(random_samples)[1:])
new_conditions = None if random_conditions is None else keras.ops.zeros((batch_size,) + keras.ops.shape(random_conditions)[1:])
if random_conditions is None:
new_conditions = None
else:
new_conditions = keras.ops.zeros((batch_size,), + keras.ops.shape(random_conditions)[1:])

inference_network(new_input)
inference_network(new_input, conditions=new_conditions, inverse=True)

Expand Down Expand Up @@ -107,18 +111,4 @@ def test_serialize_deserialize(tmp_path, inference_network, random_samples, rand
keras.saving.save_model(inference_network, tmp_path / "model.keras")
loaded = keras.saving.load_model(tmp_path / "model.keras")

print(f"{inference_network._layers=}")
print(f"{loaded._layers=}")
print()
dual_coupling1 = inference_network._layers[1]
dual_coupling2 = loaded._layers[1]
print(f"{dual_coupling1.pivot=}")
print(f"{dual_coupling2.pivot=}")
print()
print(f"{dual_coupling1.coupling1.variables=}")
print(f"{dual_coupling1.coupling2.variables=}")
print()
print(f"{dual_coupling2.coupling1.variables=}")
print(f"{dual_coupling2.coupling2.variables=}")

assert_models_equal(inference_network, loaded)

0 comments on commit 9adabc6

Please sign in to comment.