Skip to content

Commit

Permalink
Merge branch 'streamlined-backend' of https://github.com/stefanradev9…
Browse files Browse the repository at this point in the history
…3/BayesFlow into streamlined-backend
  • Loading branch information
stefanradev93 committed May 31, 2024
2 parents 4745c86 + 493c522 commit 9d35227
Show file tree
Hide file tree
Showing 18 changed files with 292 additions and 702 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@

import keras
from keras.saving import (
register_keras_serializable
)

from bayesflow.experimental.types import Tensor
from bayesflow.experimental.utils import find_network
from ..invertible_layer import InvertibleLayer
from ..transforms import find_transform


@register_keras_serializable(package="bayesflow.networks.coupling_flow")
class SingleCoupling(InvertibleLayer):
"""
Implements a single coupling layer as a composition of a subnet and a transform.
Subnet output tensors are linearly mapped to the correct dimension.
"""
def __init__(self, network: str = "resnet", transform: str = "affine", **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(None, kernel_initializer="zeros", bias_initializer="zeros")
self.network = find_network(network)
self.transform = find_transform(transform)

# noinspection PyMethodOverriding
def build(self, x1_shape, x2_shape):
self.dense.units = self.transform.params_per_dim * x2_shape[-1]

def call(self, x1: Tensor, x2: Tensor, conditions: any = None, inverse: bool = False) -> ((Tensor, Tensor), Tensor):
if inverse:
return self._inverse(x1, x2, conditions=conditions)
return self._forward(x1, x2, conditions=conditions)

def _forward(self, x1: Tensor, x2: Tensor, conditions: any = None) -> ((Tensor, Tensor), Tensor):
""" Transform (x1, x2) -> (x1, f(x2; x1)) """
z1 = x1
parameters = self.get_parameters(x1, conditions)
z2, log_det = self.transform(x2, parameters=parameters)

return (z1, z2), log_det

def _inverse(self, z1: Tensor, z2: Tensor, conditions: any = None) -> ((Tensor, Tensor), Tensor):
""" Transform (x1, f(x2; x1)) -> (x1, x2) """
x1 = z1
parameters = self.get_parameters(x1, conditions)
x2, log_det = self.transform(z2, parameters=parameters, inverse=True)

return (x1, x2), log_det

def get_parameters(self, x, conditions: any = None) -> dict[str, Tensor]:
# TODO: pass conditions to subnet via kwarg if possible
if keras.ops.is_tensor(conditions):
x = keras.ops.concatenate([x, conditions], axis=-1)

parameters = self.dense(self.network(x))
parameters = self.transform.split_parameters(parameters)
parameters = self.transform.constrain_parameters(parameters)

return parameters
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

from .orthogonal import OrthogonalPermutation
from .fixed_permutation import FixedPermutation
from .swap import Swap
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, forward_indices=None, inverse_indices=None, **kwargs):
self.forward_indices = forward_indices
self.inverse_indices = inverse_indices

def call(self, xz: Tensor, inverse: bool = False):
def call(self, xz: Tensor, inverse: bool = False, **kwargs):
if inverse:
return self._inverse(xz)
return self._forward(xz)
Expand All @@ -25,10 +25,10 @@ def build(self, input_shape: Shape) -> None:

def _forward(self, x: Tensor) -> (Tensor, Tensor):
z = keras.ops.take(x, self.forward_indices, axis=-1)
log_det = keras.ops.zeros(keras.ops.shape(x)[0])
log_det = 0.
return z, log_det

def _inverse(self, z: Tensor) -> (Tensor, Tensor):
x = keras.ops.take(z, self.inverse_indices, axis=-1)
log_det = keras.ops.zeros(keras.ops.shape(z)[0])
log_det = 0.
return x, log_det
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@

from keras import ops
from keras.saving import (
register_keras_serializable,
)

from bayesflow.experimental.types import Shape, Tensor
from ..invertible_layer import InvertibleLayer


@register_keras_serializable(package="bayesflow.networks.coupling_flow")
class OrthogonalPermutation(InvertibleLayer):
"""Implements a learnable orthogonal transformation according to [1]. Can be
used as an alternative to a fixed ``Permutation`` layer.
Expand All @@ -25,7 +29,7 @@ def build(self, input_shape: Shape) -> None:
trainable=True
)

def call(self, xz: Tensor, inverse: bool = False):
def call(self, xz: Tensor, inverse: bool = False, **kwargs):
if inverse:
return self._inverse(xz)
return self._forward(xz)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ def build(self, input_shape: Shape) -> None:
forward_indices = keras.random.shuffle(keras.ops.arange(input_shape[-1]))
inverse_indices = keras.ops.argsort(forward_indices)

self.forward_indices = self.add_variable(
self.forward_indices = self.add_weight(
shape=(input_shape[-1],),
initializer=keras.initializers.Constant(forward_indices),
trainable=False
)

self.inverse_indices = self.add_variable(
self.inverse_indices = self.add_weight(
shape=(input_shape[-1],),
initializer=keras.initializers.Constant(inverse_indices),
trainable=False
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@

import keras
from keras.saving import (
register_keras_serializable,
)

from bayesflow.experimental.types import Shape
from .fixed_permutation import FixedPermutation
Expand Down
23 changes: 0 additions & 23 deletions bayesflow/experimental/networks/coupling_flow/subnets/__init__.py

This file was deleted.

Loading

0 comments on commit 9d35227

Please sign in to comment.