From 2678d07da6a08c453adaaba71b2e2ac7e86da94e Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Fri, 31 May 2024 06:54:12 -0400 Subject: [PATCH] Slight naming change --- .../networks/coupling_flow/couplings/single_coupling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py b/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py index 01e68dbb..9205eeff 100644 --- a/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py +++ b/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py @@ -19,13 +19,13 @@ class SingleCoupling(InvertibleLayer): """ def __init__(self, network: str = "resnet", transform: str = "affine", **kwargs): super().__init__(**kwargs) - self.dense = keras.layers.Dense(None, kernel_initializer="zeros", bias_initializer="zeros") + self.output_projector = keras.layers.Dense(None, kernel_initializer="zeros", bias_initializer="zeros") self.network = find_network(network) self.transform = find_transform(transform) # noinspection PyMethodOverriding def build(self, x1_shape, x2_shape): - self.dense.units = self.transform.params_per_dim * x2_shape[-1] + self.output_projector.units = self.transform.params_per_dim * x2_shape[-1] def call(self, x1: Tensor, x2: Tensor, conditions: any = None, inverse: bool = False) -> ((Tensor, Tensor), Tensor): if inverse: @@ -53,7 +53,7 @@ def get_parameters(self, x, conditions: any = None) -> dict[str, Tensor]: if keras.ops.is_tensor(conditions): x = keras.ops.concatenate([x, conditions], axis=-1) - parameters = self.dense(self.network(x)) + parameters = self.output_projector(self.network(x)) parameters = self.transform.split_parameters(parameters) parameters = self.transform.constrain_parameters(parameters)