Skip to content

Commit

Permalink
Slight naming change
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanradev93 committed May 31, 2024
1 parent a303d87 commit 2678d07
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ class SingleCoupling(InvertibleLayer):
"""
def __init__(self, network: str = "resnet", transform: str = "affine", **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(None, kernel_initializer="zeros", bias_initializer="zeros")
self.output_projector = keras.layers.Dense(None, kernel_initializer="zeros", bias_initializer="zeros")
self.network = find_network(network)
self.transform = find_transform(transform)

# noinspection PyMethodOverriding
def build(self, x1_shape, x2_shape):
self.dense.units = self.transform.params_per_dim * x2_shape[-1]
self.output_projector.units = self.transform.params_per_dim * x2_shape[-1]

def call(self, x1: Tensor, x2: Tensor, conditions: any = None, inverse: bool = False) -> ((Tensor, Tensor), Tensor):
if inverse:
Expand Down Expand Up @@ -53,7 +53,7 @@ def get_parameters(self, x, conditions: any = None) -> dict[str, Tensor]:
if keras.ops.is_tensor(conditions):
x = keras.ops.concatenate([x, conditions], axis=-1)

parameters = self.dense(self.network(x))
parameters = self.output_projector(self.network(x))
parameters = self.transform.split_parameters(parameters)
parameters = self.transform.constrain_parameters(parameters)

Expand Down

0 comments on commit 2678d07

Please sign in to comment.