Skip to content

Commit

Permalink
remove factory methods
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed May 29, 2024
1 parent 1cfcbea commit cc69a38
Show file tree
Hide file tree
Showing 19 changed files with 125 additions and 318 deletions.
2 changes: 1 addition & 1 deletion bayesflow/experimental/datasets/online_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

import keras

from bayesflow.experimental.simulation.distributions import JointDistribution
from bayesflow.experimental.simulation import JointDistribution


class OnlineDataset(keras.utils.PyDataset):
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/experimental/datasets/rounds_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

import keras

from bayesflow.experimental.simulation.distributions.joint_distribution import JointDistribution
from bayesflow.experimental.simulation import JointDistribution


class RoundsDataset(keras.utils.PyDataset):
Expand Down
17 changes: 17 additions & 0 deletions bayesflow/experimental/distributions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@

from .distribution import Distribution
from .normal import Normal


def find_distribution(distribution: str | Distribution | type(Distribution)) -> Distribution:
if isinstance(distribution, Distribution):
return distribution
if isinstance(distribution, type):
return Distribution()

match distribution:
case "normal":
distribution = Normal()
case str() as unknown_distribution:
raise ValueError(f"Distribution '{unknown_distribution}' is unknown or not yet supported by name.")
case other:
raise TypeError(f"Unknown distribution type: {other}")

return distribution
109 changes: 34 additions & 75 deletions bayesflow/experimental/networks/coupling_flow/coupling_flow.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@

from typing import Sequence
from typing import Tuple, Union

import keras
from keras.saving import (
deserialize_keras_object,
register_keras_serializable,
serialize_keras_object,
)

from bayesflow.experimental.types import Shape, Tensor
from bayesflow.experimental.types import Tensor
from .actnorm import ActNorm
from .couplings import DualCoupling
from .invertible_layer import InvertibleLayer

from ...simulation.distributions import find_distribution
from ..inference_network import InferenceNetwork


@register_keras_serializable(package="bayesflow.networks")
class CouplingFlow(keras.Model, InvertibleLayer):
class CouplingFlow(InferenceNetwork):
""" Implements a coupling flow as a sequence of dual couplings with permutations and activation
normalization. Incorporates ideas from [1-4].
normalization. Incorporates ideas from [1-5].
[1] Kingma, D. P., & Dhariwal, P. (2018).
Glow: Generative flow with invertible 1x1 convolutions.
Expand All @@ -40,92 +36,55 @@ class CouplingFlow(keras.Model, InvertibleLayer):
Robust model training and generalisation with Studentising flows.
arXiv preprint arXiv:2006.06599.
"""
def __init__(self, invertible_layers: Sequence[InvertibleLayer], base_distribution: str = "normal", **kwargs):
super().__init__(**kwargs)
self.invertible_layers = list(invertible_layers)
self.base_distribution = find_distribution(base_distribution)

@classmethod
def new(
cls,
depth: int = 6,
subnet: str = "resnet",
transform: str = "affine",
base_distribution: str = "normal",
use_actnorm: bool = True,
**kwargs
def __init__(
self,
depth: int = 6,
subnet: str = "resnet",
transform: str = "affine",
use_actnorm: bool = True, **kwargs
):
super().__init__(**kwargs)

layers = []
for i in range(depth):
self._layers = []
for _ in range(depth):
if use_actnorm:
layers.append(ActNorm())
layers.append(DualCoupling.new(subnet, transform))

return cls(layers, base_distribution, **kwargs)

@classmethod
def from_config(cls, config, custom_objects=None):
couplings = deserialize_keras_object(config.pop("invertible_layers"))
base_distribution = deserialize_keras_object(config.pop("base_distribution"))

return cls(couplings, base_distribution, **config)

def get_config(self):
base_config = super().get_config()

config = {
"invertible_layers": serialize_keras_object(self.invertible_layers),
"base_distribution": serialize_keras_object(self.base_distribution),
}

return base_config | config
self._layers.append(ActNorm())
self._layers.append(DualCoupling(subnet, transform))

def build(self, input_shape):
# nothing to do here, since we do not know the conditions yet
self.base_distribution.build(input_shape)
super().build(input_shape)
self.call(keras.KerasTensor(input_shape))

def call(self, x: Tensor, conditions: any = None, inverse: bool = False) -> (Tensor, Tensor):
def call(self, xz: Tensor, inverse: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
if inverse:
return self._inverse(x, conditions)
return self._forward(x, conditions)
return self._inverse(xz, **kwargs)
return self._forward(xz, **kwargs)

def _forward(self, x: Tensor, conditions: any = None) -> (Tensor, Tensor):
def _forward(self, x: Tensor, jacobian: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
z = x
log_det = 0.0
for layer in self.invertible_layers:
z, det = layer(z, conditions=conditions)
for layer in self._layers:
z, det = layer(z, inverse=False, **kwargs)
log_det += det

return z, log_det
if jacobian:
return z, log_det
return z

def _inverse(self, z: Tensor, conditions: any = None) -> (Tensor, Tensor):
def _inverse(self, z: Tensor, jacobian: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
x = z
log_det = 0.0
for layer in reversed(self.invertible_layers):
x, det = layer(x, conditions=conditions, inverse=True)
for layer in reversed(self._layers):
x, det = layer(x, inverse=True, **kwargs)
log_det += det

return x, log_det

def sample(self, batch_shape: Shape, conditions=None) -> Tensor:
z = self.base_distribution.sample(batch_shape)
x, _ = self(z, conditions, inverse=True)

if jacobian:
return x, log_det
return x

def log_prob(self, x: Tensor, conditions=None) -> Tensor:
z, log_det = self(x, conditions)
log_prob = self.base_distribution.log_prob(z)

return log_prob + log_det

def compute_loss(self, x=None, y=None, y_pred=None, **kwargs):
z, log_det = y_pred
def compute_loss(self, x: Tensor = None, **kwargs):
z, log_det = self(x, inverse=False, jacobian=True, **kwargs)
log_prob = self.base_distribution.log_prob(z)
nll = -keras.ops.mean(log_prob + log_det)

return nll

def compute_metrics(self, x, y, y_pred, **kwargs):
return {}
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@

import keras
from keras.saving import (
deserialize_keras_object,
register_keras_serializable,
serialize_keras_object,
)

from bayesflow.experimental.types import Tensor
Expand All @@ -13,51 +11,15 @@

@register_keras_serializable(package="bayesflow.networks.coupling_flow")
class DualCoupling(InvertibleLayer):
def __init__(self, coupling1: SingleCoupling, coupling2: SingleCoupling, pivot: int = None, **kwargs):
super().__init__(**kwargs)
self.coupling1 = coupling1
self.coupling2 = coupling2
self.pivot = pivot

@classmethod
def new(cls, *args, **kwargs) -> "DualCoupling":
""" Construct a new DualCoupling from hyperparameters. """
coupling1 = SingleCoupling.new(*args, **kwargs)
coupling2 = SingleCoupling.new(*args, **kwargs)

return cls(coupling1, coupling2, **kwargs)

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "DualCoupling":
coupling1 = deserialize_keras_object(config.pop("coupling1"))
coupling2 = deserialize_keras_object(config.pop("coupling2"))
pivot = config.pop("pivot")

return cls(coupling1, coupling2, pivot=pivot, **config)

def get_config(self) -> dict:
base_config = super().get_config()

config = {
"coupling1": serialize_keras_object(self.coupling1),
"coupling2": serialize_keras_object(self.coupling2),
"pivot": self.pivot,
}

return base_config | config
def __init__(self, subnet: str = "resnet", transform: str = "affine"):
super().__init__()
self.coupling1 = SingleCoupling(subnet, transform)
self.coupling2 = SingleCoupling(subnet, transform)
self.pivot = None

def build(self, input_shape):
self.pivot = input_shape[-1] // 2

x1_shape = list(input_shape)
x2_shape = list(input_shape)

x1_shape[-1] = self.pivot
x2_shape[-1] = input_shape[-1] - self.pivot

self.coupling1.build((x1_shape, x2_shape))
self.coupling2.build((x2_shape, x1_shape))

def call(self, xz: Tensor, conditions: any = None, inverse: bool = False) -> (Tensor, Tensor):
if inverse:
return self._inverse(xz, conditions=conditions)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,57 +1,30 @@

import keras
from keras.saving import (
deserialize_keras_object,
register_keras_serializable,
serialize_keras_object
register_keras_serializable
)

from bayesflow.experimental.types import Tensor
from ..invertible_layer import InvertibleLayer
from ..subnets import find_subnet
from ..transforms import find_transform, Transform
from ..transforms import find_transform


@register_keras_serializable(package="bayesflow.networks.coupling_flow")
class SingleCoupling(InvertibleLayer):
def __init__(self, subnet: keras.Layer, transform: Transform, **kwargs):
super().__init__(**kwargs)
self.subnet = subnet
self.transform = transform

@classmethod
def new(cls, subnet: str = "resnet", transform: str = "affine", **kwargs) -> "SingleCoupling":
transform = find_transform(transform)
subnet = find_subnet(subnet)

return cls(subnet, transform, **kwargs)

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "SingleCoupling":
subnet = deserialize_keras_object(config.pop("subnet"))
transform = deserialize_keras_object(config.pop("transform"))

return cls(subnet, transform, **config)
"""
Implements a single coupling layer as a composition of a subnet and a transform.
def get_config(self) -> dict:
base_config = super().get_config()

config = {
"subnet": serialize_keras_object(self.subnet),
"transform": serialize_keras_object(self.transform),
}

return base_config | config
Subnet output tensors are linearly mapped to the correct dimension.
"""
def __init__(self, subnet: str = "resnet", transform: str = "affine", **kwargs):
super().__init__(**kwargs)
self.dense = keras.layers.Dense(None)
self.subnet = find_subnet(subnet)
self.transform = find_transform(transform)

def build(self, input_shape):
# TODO: this is not ideal...
if not hasattr(self.subnet, "build_output"):
return

x1_shape, x2_shape = input_shape
x2_shape = list(x2_shape)
x2_shape[-1] = self.transform.params_per_dim * x2_shape[-1]
self.subnet.build_output(x2_shape)
self.dense.units = self.transform.params_per_dim * input_shape[-1]

def call(self, x1: Tensor, x2: Tensor, conditions: any = None, inverse: bool = False) -> ((Tensor, Tensor), Tensor):
if inverse:
Expand Down Expand Up @@ -79,7 +52,7 @@ def get_parameters(self, x, conditions: any = None) -> dict[str, Tensor]:
if keras.ops.is_tensor(conditions):
x = keras.ops.concatenate([x, conditions], axis=-1)

parameters = self.subnet(x)
parameters = self.dense(self.subnet(x))
parameters = self.transform.split_parameters(parameters)
parameters = self.transform.constrain_parameters(parameters)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ def find_subnet(subnet: str | keras.Layer | Callable, **kwargs) -> keras.Layer:
case str() as name:
match name.lower():
case "resnet":
resnet = networks.ResNet.new(**kwargs)
resnet.output_layer.units = None
return resnet
return networks.ResNet(**kwargs)
case other:
raise NotImplementedError(f"Unsupported subnet name: '{other}'.")
case keras.Layer() as layer:
Expand Down
20 changes: 11 additions & 9 deletions bayesflow/experimental/networks/flow_matching/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,6 @@ def get_config(self) -> dict:
def build(self, input_shape):
self.network.build(input_shape)

def train_step(self, data):
# hack to avoid the call method in super().train_step()
# maybe you have a better idea? Seems the train_step is not backend-agnostic since it requires gradient tracking
call = self.call
self.call = lambda *args, **kwargs: None
super().train_step(data)
self.call = call

def _forward(self, x: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]:
# implement conditions = None and jacobian = False first
# then work your way up
Expand All @@ -62,7 +54,17 @@ def compute_loss(self, x=None, **kwargs):
# x should ideally contain both x0 and x1,
# where the optimal transport matching already happened in the worker process
# this is possible, but might not be super user-friendly. We will have to see.
raise NotImplementedError
x0, x1, t = x

xt = t * x1 + (1 - t) * x0

# get velocity at xt
v = ...

# target velocity:
vstar = x1 - x0

# return mse between v and vstar


# TODO: see below for reference implementation
Expand Down
Loading

0 comments on commit cc69a38

Please sign in to comment.