diff --git a/bayesflow/experimental/networks/coupling_flow/couplings/coupling.py b/bayesflow/experimental/networks/coupling_flow/couplings/coupling.py
deleted file mode 100644
index d09a7ed3..00000000
--- a/bayesflow/experimental/networks/coupling_flow/couplings/coupling.py
+++ /dev/null
@@ -1,59 +0,0 @@
-
-import keras
-from keras import ops
-
-from ..transforms import find_transform
-from ..subnets import find_subnet
-
-
-class Coupling(keras.Layer):
-    """ Implements a single coupling layer that transforms half of its input through a coupling transform."""
-    def __init__(
-        self,
-        subnet_builder: str,
-        half_dim: int,
-        transform: str,
-        **kwargs
-    ):
-        super().__init__()
-
-        self.transform = find_transform(transform, **kwargs.pop("transform_settings", {}))
-        self.half_dim = half_dim
-        self.subnet = find_subnet(
-            subnet=subnet_builder,
-            transform=self.transform,
-            output_dim=half_dim,
-            **kwargs.pop("subnet_settings", {})
-        )
-
-    def call(self, x, c=None, forward=True, **kwargs):
-        if forward:
-            return self.forward(x, c, **kwargs)
-        return self.inverse(x, c)
-
-    def forward(self, x, c=None, **kwargs):
-
-        x1, x2 = x[..., :self.half_dim], x[..., self.half_dim:]
-        z2 = x2
-        parameters = self.get_parameters(x2, c, **kwargs)
-        z1, log_det = self.transform.forward(x1, parameters)
-        z = ops.concatenate([z1, z2], axis=-1)
-        return z, log_det
-
-    def inverse(self, z, c=None):
-        z1, z2 = z[..., :self.half_dim], z[..., self.half_dim:]
-        x2 = z2
-        parameters = self.get_parameters(x2, c)
-        x1, log_det = self.transform.inverse(z1, parameters)
-        x = ops.concatenate([x1, x2], axis=-1)
-        return x, log_det
-
-    def get_parameters(self, x, c=None, **kwargs):
-        if c is not None:
-            x = ops.concatenate([x, c], axis=-1)
-
-        parameters = self.subnet(x, **kwargs)
-        parameters = self.transform.split_parameters(parameters)
-        parameters = self.transform.constrain_parameters(parameters)
-
-        return parameters
diff --git a/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py b/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py
new file mode 100644
index 00000000..01e68dbb
--- /dev/null
+++ b/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py
@@ -0,0 +1,60 @@
+
+import keras
+from keras.saving import (
+    register_keras_serializable
+)
+
+from bayesflow.experimental.types import Tensor
+from bayesflow.experimental.utils import find_network
+from ..invertible_layer import InvertibleLayer
+from ..transforms import find_transform
+
+
+@register_keras_serializable(package="bayesflow.networks.coupling_flow")
+class SingleCoupling(InvertibleLayer):
+    """
+    Implements a single coupling layer as a composition of a subnet and a transform.
+
+    Subnet output tensors are linearly mapped to the correct dimension.
+    """
+    def __init__(self, network: str = "resnet", transform: str = "affine", **kwargs):
+        super().__init__(**kwargs)
+        self.dense = keras.layers.Dense(None, kernel_initializer="zeros", bias_initializer="zeros")
+        self.network = find_network(network)
+        self.transform = find_transform(transform)
+
+    # noinspection PyMethodOverriding
+    def build(self, x1_shape, x2_shape):
+        self.dense.units = self.transform.params_per_dim * x2_shape[-1]
+
+    def call(self, x1: Tensor, x2: Tensor, conditions: any = None, inverse: bool = False) -> ((Tensor, Tensor), Tensor):
+        if inverse:
+            return self._inverse(x1, x2, conditions=conditions)
+        return self._forward(x1, x2, conditions=conditions)
+
+    def _forward(self, x1: Tensor, x2: Tensor, conditions: any = None) -> ((Tensor, Tensor), Tensor):
+        """ Transform (x1, x2) -> (x1, f(x2; x1)) """
+        z1 = x1
+        parameters = self.get_parameters(x1, conditions)
+        z2, log_det = self.transform(x2, parameters=parameters)
+
+        return (z1, z2), log_det
+
+    def _inverse(self, z1: Tensor, z2: Tensor, conditions: any = None) -> ((Tensor, Tensor), Tensor):
+        """ Transform (x1, f(x2; x1)) -> (x1, x2) """
+        x1 = z1
+        parameters = self.get_parameters(x1, conditions)
+        x2, log_det = self.transform(z2, parameters=parameters, inverse=True)
+
+        return (x1, x2), log_det
+
+    def get_parameters(self, x, conditions: any = None) -> dict[str, Tensor]:
+        # TODO: pass conditions to subnet via kwarg if possible
+        if keras.ops.is_tensor(conditions):
+            x = keras.ops.concatenate([x, conditions], axis=-1)
+
+        parameters = self.dense(self.network(x))
+        parameters = self.transform.split_parameters(parameters)
+        parameters = self.transform.constrain_parameters(parameters)
+
+        return parameters
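Editor's note: the new `SingleCoupling` composes a subnet with a pluggable transform and zero-initializes the final dense layer, so every coupling starts as the identity map. Below is a minimal, self-contained sketch of the mechanics (illustrative only, not code from this PR), with an explicit affine transform standing in for the pluggable `transform` object:

```python
import numpy as np
import keras

def affine_coupling(x1, x2, subnet, inverse=False):
    # subnet maps x1 -> (log_scale, shift), each as wide as x2
    log_scale, shift = keras.ops.split(subnet(x1), 2, axis=-1)
    log_scale = keras.ops.tanh(log_scale)  # soft clamp for stability
    if inverse:
        x2 = (x2 - shift) * keras.ops.exp(-log_scale)
        log_det = -keras.ops.sum(log_scale, axis=-1)
    else:
        x2 = x2 * keras.ops.exp(log_scale) + shift
        log_det = keras.ops.sum(log_scale, axis=-1)
    return (x1, x2), log_det

# Zero-initialized parameter head, mirroring the PR's Dense(..., "zeros"):
# the coupling then starts out as the identity transform.
subnet = keras.layers.Dense(4, kernel_initializer="zeros", bias_initializer="zeros")
x1 = keras.ops.convert_to_tensor(np.random.randn(8, 2).astype("float32"))
x2 = keras.ops.convert_to_tensor(np.random.randn(8, 2).astype("float32"))
(z1, z2), log_det = affine_coupling(x1, x2, subnet)                  # forward
(_, x2_restored), _ = affine_coupling(z1, z2, subnet, inverse=True)  # recovers x2
```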
+ """ + def __init__(self, network: str = "resnet", transform: str = "affine", **kwargs): + super().__init__(**kwargs) + self.dense = keras.layers.Dense(None, kernel_initializer="zeros", bias_initializer="zeros") + self.network = find_network(network) + self.transform = find_transform(transform) + + # noinspection PyMethodOverriding + def build(self, x1_shape, x2_shape): + self.dense.units = self.transform.params_per_dim * x2_shape[-1] + + def call(self, x1: Tensor, x2: Tensor, conditions: any = None, inverse: bool = False) -> ((Tensor, Tensor), Tensor): + if inverse: + return self._inverse(x1, x2, conditions=conditions) + return self._forward(x1, x2, conditions=conditions) + + def _forward(self, x1: Tensor, x2: Tensor, conditions: any = None) -> ((Tensor, Tensor), Tensor): + """ Transform (x1, x2) -> (x1, f(x2; x1)) """ + z1 = x1 + parameters = self.get_parameters(x1, conditions) + z2, log_det = self.transform(x2, parameters=parameters) + + return (z1, z2), log_det + + def _inverse(self, z1: Tensor, z2: Tensor, conditions: any = None) -> ((Tensor, Tensor), Tensor): + """ Transform (x1, f(x2; x1)) -> (x1, x2) """ + x1 = z1 + parameters = self.get_parameters(x1, conditions) + x2, log_det = self.transform(z2, parameters=parameters, inverse=True) + + return (x1, x2), log_det + + def get_parameters(self, x, conditions: any = None) -> dict[str, Tensor]: + # TODO: pass conditions to subnet via kwarg if possible + if keras.ops.is_tensor(conditions): + x = keras.ops.concatenate([x, conditions], axis=-1) + + parameters = self.dense(self.network(x)) + parameters = self.transform.split_parameters(parameters) + parameters = self.transform.constrain_parameters(parameters) + + return parameters diff --git a/bayesflow/experimental/networks/coupling_flow/permutations/__init__.py b/bayesflow/experimental/networks/coupling_flow/permutations/__init__.py index 6bb339ae..6e6bbd3c 100644 --- a/bayesflow/experimental/networks/coupling_flow/permutations/__init__.py +++ b/bayesflow/experimental/networks/coupling_flow/permutations/__init__.py @@ -1,3 +1,4 @@ from .orthogonal import OrthogonalPermutation +from .fixed_permutation import FixedPermutation from .swap import Swap diff --git a/bayesflow/experimental/networks/coupling_flow/permutations/fixed_permutation.py b/bayesflow/experimental/networks/coupling_flow/permutations/fixed_permutation.py index c9ef95d4..29ecb8c6 100644 --- a/bayesflow/experimental/networks/coupling_flow/permutations/fixed_permutation.py +++ b/bayesflow/experimental/networks/coupling_flow/permutations/fixed_permutation.py @@ -15,7 +15,7 @@ def __init__(self, forward_indices=None, inverse_indices=None, **kwargs): self.forward_indices = forward_indices self.inverse_indices = inverse_indices - def call(self, xz: Tensor, inverse: bool = False): + def call(self, xz: Tensor, inverse: bool = False, **kwargs): if inverse: return self._inverse(xz) return self._forward(xz) @@ -25,10 +25,10 @@ def build(self, input_shape: Shape) -> None: def _forward(self, x: Tensor) -> (Tensor, Tensor): z = keras.ops.take(x, self.forward_indices, axis=-1) - log_det = keras.ops.zeros(keras.ops.shape(x)[0]) + log_det = 0. return z, log_det def _inverse(self, z: Tensor) -> (Tensor, Tensor): x = keras.ops.take(z, self.inverse_indices, axis=-1) - log_det = keras.ops.zeros(keras.ops.shape(z)[0]) + log_det = 0. 
diff --git a/bayesflow/experimental/networks/coupling_flow/permutations/orthogonal.py b/bayesflow/experimental/networks/coupling_flow/permutations/orthogonal.py
index e4b458a4..b38cd51e 100644
--- a/bayesflow/experimental/networks/coupling_flow/permutations/orthogonal.py
+++ b/bayesflow/experimental/networks/coupling_flow/permutations/orthogonal.py
@@ -1,10 +1,14 @@
 
 from keras import ops
+from keras.saving import (
+    register_keras_serializable,
+)
 
 from bayesflow.experimental.types import Shape, Tensor
 from ..invertible_layer import InvertibleLayer
 
 
+@register_keras_serializable(package="bayesflow.networks.coupling_flow")
 class OrthogonalPermutation(InvertibleLayer):
     """Implements a learnable orthogonal transformation according to [1]. Can be used
     as an alternative to a fixed ``Permutation`` layer.
@@ -25,7 +29,7 @@ def build(self, input_shape: Shape) -> None:
             trainable=True
         )
 
-    def call(self, xz: Tensor, inverse: bool = False):
+    def call(self, xz: Tensor, inverse: bool = False, **kwargs):
         if inverse:
             return self._inverse(xz)
         return self._forward(xz)
diff --git a/bayesflow/experimental/networks/coupling_flow/permutations/random.py b/bayesflow/experimental/networks/coupling_flow/permutations/random.py
index 2aadcfbd..02f01f00 100644
--- a/bayesflow/experimental/networks/coupling_flow/permutations/random.py
+++ b/bayesflow/experimental/networks/coupling_flow/permutations/random.py
@@ -14,13 +14,13 @@ def build(self, input_shape: Shape) -> None:
         forward_indices = keras.random.shuffle(keras.ops.arange(input_shape[-1]))
         inverse_indices = keras.ops.argsort(forward_indices)
 
-        self.forward_indices = self.add_variable(
+        self.forward_indices = self.add_weight(
            shape=(input_shape[-1],),
            initializer=keras.initializers.Constant(forward_indices),
            trainable=False
        )
 
-        self.inverse_indices = self.add_variable(
+        self.inverse_indices = self.add_weight(
            shape=(input_shape[-1],),
            initializer=keras.initializers.Constant(inverse_indices),
            trainable=False
diff --git a/bayesflow/experimental/networks/coupling_flow/permutations/swap.py b/bayesflow/experimental/networks/coupling_flow/permutations/swap.py
index c8d5908e..d3a4ba2a 100644
--- a/bayesflow/experimental/networks/coupling_flow/permutations/swap.py
+++ b/bayesflow/experimental/networks/coupling_flow/permutations/swap.py
@@ -1,5 +1,8 @@
 
 import keras
+from keras.saving import (
+    register_keras_serializable,
+)
 
 from bayesflow.experimental.types import Shape
 from .fixed_permutation import FixedPermutation
diff --git a/bayesflow/experimental/networks/coupling_flow/subnets/__init__.py b/bayesflow/experimental/networks/coupling_flow/subnets/__init__.py
deleted file mode 100644
index 4dbee6d6..00000000
--- a/bayesflow/experimental/networks/coupling_flow/subnets/__init__.py
+++ /dev/null
@@ -1,23 +0,0 @@
-
-from typing import Callable
-
-import keras
-
-import bayesflow.experimental.networks as networks
-
-
-def find_subnet(subnet: str | keras.Layer | Callable, **kwargs) -> keras.Layer:
-    """ Find subnetworks by name and configure them to use lazy in- and output dimensions. """
-    match subnet:
-        case str() as name:
-            match name.lower():
-                case "resnet":
-                    return networks.ResNet(**kwargs)
-                case other:
-                    raise NotImplementedError(f"Unsupported subnet name: '{other}'.")
-        case keras.Layer() as layer:
-            return layer
-        case callable() as constructor:
-            return constructor(**kwargs)
-        case other:
-            raise NotImplementedError(f"Cannot infer subnet from {other!r}.")
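Editor's note: `OrthogonalPermutation`, now registered for serialization, replaces a hard index shuffle with a learnable square matrix. Assuming its forward pass is a matrix product over the feature axis (the usual construction; the class body is not shown in this hunk), the semantics in plain numpy:

```python
import numpy as np

rng = np.random.default_rng(0)
weight = rng.standard_normal((4, 4)).astype("float32")  # learned in practice
x = rng.standard_normal((2, 4)).astype("float32")

z = x @ weight                             # forward "soft permutation"
sign, log_det = np.linalg.slogdet(weight)  # log|det W| enters the log-density
x_restored = z @ np.linalg.inv(weight)     # exact inverse of the mixing
```

Unlike the fixed permutations above, this layer contributes a non-zero log-determinant, which is why it cannot simply return `0.`.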
""" - match subnet: - case str() as name: - match name.lower(): - case "resnet": - return networks.ResNet(**kwargs) - case other: - raise NotImplementedError(f"Unsupported subnet name: '{other}'.") - case keras.Layer() as layer: - return layer - case callable() as constructor: - return constructor(**kwargs) - case other: - raise NotImplementedError(f"Cannot infer subnet from {other!r}.") diff --git a/bayesflow/experimental/networks/flow_matching/flow_matching.py b/bayesflow/experimental/networks/flow_matching/flow_matching.py index 71c81ded..54d93097 100644 --- a/bayesflow/experimental/networks/flow_matching/flow_matching.py +++ b/bayesflow/experimental/networks/flow_matching/flow_matching.py @@ -3,139 +3,179 @@ import keras from keras.saving import ( - deserialize_keras_object, register_keras_serializable, - serialize_keras_object, ) from scipy.integrate import solve_ivp -from bayesflow.experimental.types import Shape, Tensor +from bayesflow.experimental.types import Tensor +from bayesflow.experimental.utils import find_network from ..inference_network import InferenceNetwork @register_keras_serializable(package="bayesflow.networks") class FlowMatching(InferenceNetwork): - def __init__(self, network: keras.Layer, **kwargs): + def __init__(self, network: str = "resnet", **kwargs): super().__init__(**kwargs) - self.network = network - - @classmethod - def new(cls, network: str = "resnet", base_distribution: str = "normal"): - # TODO: we probably want to provide a factory method like this, since the other networks use it - # for high-level input parameters - # network = find_network(network) - return cls(network, base_distribution=base_distribution) - - @classmethod - def from_config(cls, config: dict, custom_objects=None) -> "FlowMatching": - # TODO: the base distribution must be savable and loadable - # ideally we also don't want to have to manually deserialize it in every subclass of InferenceNetwork - base_distribution = deserialize_keras_object(config.pop("base_distribution")) - network = deserialize_keras_object(config.pop("network")) - return cls(network, base_distribution=base_distribution, **config) - - def get_config(self) -> dict: - base_config = super().get_config() - config = {"network": serialize_keras_object(self.network)} - return base_config | config - - def build(self, input_shape): - self.network.build(input_shape) + self.network = find_network(network) - def _forward(self, x: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]: - # implement conditions = None and jacobian = False first - # then work your way up - raise NotImplementedError - - def _inverse(self, z: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]: - raise NotImplementedError - - def compute_loss(self, x=None, **kwargs): - # x should ideally contain both x0 and x1, - # where the optimal transport matching already happened in the worker process - # this is possible, but might not be super user-friendly. We will have to see. - x0, x1, t = x - - xt = t * x1 + (1 - t) * x0 - - # get velocity at xt - v = ... 
- - # target velocity: - vstar = x1 - x0 - - # return mse between v and vstar - - -# TODO: see below for reference implementation - - -class FlowMatching(keras.Model): - def __init__(self, network: keras.Layer, base_distribution): - super().__init__() - self.network = network - self.base_distribution = find_distribution(base_distribution) - - def call(self, inferred_variables, inference_conditions): - return self.network(keras.ops.concatenate([inferred_variables, inference_conditions], axis=1)) - - def compute_loss(self, x=None, y=None, y_pred=None, **kwargs): - return keras.losses.mean_squared_error(y, y_pred) - - def velocity(self, x: Tensor, t: Tensor, c: Tensor = None): - if c is None: + def velocity(self, x: Tensor, t: Tensor, conditions: any = None): + if conditions is None: xtc = keras.ops.concatenate([x, t], axis=1) else: - xtc = keras.ops.concatenate([x, t, c], axis=1) + xtc = keras.ops.concatenate([x, t, conditions], axis=1) return self.network(xtc) - def forward(self, x, c=None, method="RK45") -> Tensor: - def f(t, x): + def _forward(self, x: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]: + def dfdt(t: float, x: Tensor): t = keras.ops.full((keras.ops.shape(x)[0], 1), t) - return self.velocity(x, t, c) - - bunch = solve_ivp(f, t_span=(1.0, 0.0), y0=x, method=method, vectorized=True) + return self.velocity(x, t, conditions) - return bunch[1] + return solve_ivp(dfdt, t_span=(1.0, 0.0), y0=x, method=method, vectorized=True)[1] - def inverse(self, x, c=None, method="RK45") -> Tensor: - def f(t, x): + def _inverse(self, z: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]: + def dfdt(t: float, x: Tensor): t = keras.ops.full((keras.ops.shape(x)[0], 1), t) - return self.velocity(x, t, c) - - bunch = solve_ivp(f, t_span=(0.0, 1.0), y0=x, method=method, vectorized=True) - - return bunch[1] + return self.velocity(x, t, conditions) - def sample(self, batch_shape: Shape) -> Tensor: - z = self.base_distribution.sample(batch_shape) - return self.inverse(z) + return solve_ivp(dfdt, t_span=(0.0, 1.0), y0=z, method=method, vectorized=True)[1] - def log_prob(self, x: Tensor, c: Tensor = None) -> Tensor: - raise NotImplementedError(f"Keras does not yet support backend-agnostic Vector-Jacobian Products.") - - -def hutchinson_trace(f: callable, x: Tensor) -> (Tensor, Tensor): - # TODO: test this for all 3 backends - noise = keras.random.normal(keras.ops.shape(x)) - - match keras.backend.backend(): - case "jax": - import jax - fx, jvp = jax.jvp(f, (x,), (noise,)) - case "tensorflow": - import tensorflow as tf - with tf.GradientTape(persistent=True) as tape: - tape.watch(x) - fx = f(x) - jvp = tape.gradient(fx, x, output_gradients=noise) - case "torch": - import torch - fx, jvp = torch.autograd.functional.jvp(f, x, noise, create_graph=True) - case other: - raise NotImplementedError(f"Backend {other} is not supported for trace estimation.") - - trace = keras.ops.sum(jvp * noise, axis=1) - - return fx, trace + def compute_loss(self, x=None, **kwargs): + x0, x1, *conditions = x + t = keras.random.uniform((keras.ops.shape(x0)[0], 1)) + + x = t * x1 + (1 - t) * x0 + xtc = keras.ops.concatenate([x, t, *conditions], axis=-1) + + predicted_velocity = self.network(xtc) + target_velocity = x1 - x0 + + return keras.losses.mean_squared_error(predicted_velocity, target_velocity) + + +# 
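Editor's note: the rewritten `compute_loss` above is standard flow matching: sample a time `t`, form the straight-line interpolant between a base draw `x0` and a data point `x1`, and regress the network's velocity at the interpolant onto the constant target `x1 - x0`. A minimal numpy sketch of that objective (the toy network is an invented stand-in for `self.network`):

```python
import numpy as np

rng = np.random.default_rng(1)
x0 = rng.standard_normal((64, 2))        # base-distribution draws
x1 = rng.standard_normal((64, 2)) + 3.0  # stand-in "data" draws
t = rng.uniform(size=(64, 1))            # one random time per pair

x_t = t * x1 + (1 - t) * x0              # straight-line interpolant
target_velocity = x1 - x0                # time derivative of the interpolant

def toy_velocity_net(x, t):
    # stand-in for self.network applied to concatenate([x, t, *conditions])
    return np.zeros_like(x)

predicted_velocity = toy_velocity_net(x_t, t)
loss = np.mean((predicted_velocity - target_velocity) ** 2)
```

Note that `keras.losses.mean_squared_error` takes `(y_true, y_pred)` while the PR passes `(predicted, target)`; the squared difference is symmetric, so the swapped order is harmless here.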
@register_keras_serializable(package="bayesflow.networks") +# class FlowMatching(InferenceNetwork): +# def __init__(self, network: keras.Layer, **kwargs): +# super().__init__(**kwargs) +# self.network = network +# +# @classmethod +# def new(cls, network: str = "resnet", base_distribution: str = "normal"): +# # TODO: we probably want to provide a factory method like this, since the other networks use it +# # for high-level input parameters +# # network = find_network(network) +# return cls(network, base_distribution=base_distribution) +# +# @classmethod +# def from_config(cls, config: dict, custom_objects=None) -> "FlowMatching": +# # TODO: the base distribution must be savable and loadable +# # ideally we also don't want to have to manually deserialize it in every subclass of InferenceNetwork +# base_distribution = deserialize_keras_object(config.pop("base_distribution")) +# network = deserialize_keras_object(config.pop("network")) +# return cls(network, base_distribution=base_distribution, **config) +# +# def get_config(self) -> dict: +# base_config = super().get_config() +# config = {"network": serialize_keras_object(self.network)} +# return base_config | config +# +# def build(self, input_shape): +# self.network.build(input_shape) +# +# def _forward(self, x: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]: +# # implement conditions = None and jacobian = False first +# # then work your way up +# raise NotImplementedError +# +# def _inverse(self, z: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]: +# raise NotImplementedError +# +# def compute_loss(self, x=None, **kwargs): +# # x should ideally contain both x0 and x1, +# # where the optimal transport matching already happened in the worker process +# # this is possible, but might not be super user-friendly. We will have to see. +# x0, x1, t = x +# +# xt = t * x1 + (1 - t) * x0 +# +# # get velocity at xt +# v = ... 
+# +# # target velocity: +# vstar = x1 - x0 +# +# # return mse between v and vstar +# +# +# # TODO: see below for reference implementation +# +# +# class FlowMatching(keras.Model): +# def __init__(self, network: keras.Layer, base_distribution): +# super().__init__() +# self.network = network +# self.base_distribution = find_distribution(base_distribution) +# +# def call(self, inferred_variables, inference_conditions): +# return self.network(keras.ops.concatenate([inferred_variables, inference_conditions], axis=1)) +# +# def compute_loss(self, x=None, y=None, y_pred=None, **kwargs): +# return keras.losses.mean_squared_error(y, y_pred) +# +# def velocity(self, x: Tensor, t: Tensor, c: Tensor = None): +# if c is None: +# xtc = keras.ops.concatenate([x, t], axis=1) +# else: +# xtc = keras.ops.concatenate([x, t, c], axis=1) +# +# return self.network(xtc) +# +# def forward(self, x, c=None, method="RK45") -> Tensor: +# def f(t, x): +# t = keras.ops.full((keras.ops.shape(x)[0], 1), t) +# return self.velocity(x, t, c) +# +# bunch = solve_ivp(f, t_span=(1.0, 0.0), y0=x, method=method, vectorized=True) +# +# return bunch[1] +# +# def inverse(self, x, c=None, method="RK45") -> Tensor: +# def f(t, x): +# t = keras.ops.full((keras.ops.shape(x)[0], 1), t) +# return self.velocity(x, t, c) +# +# bunch = solve_ivp(f, t_span=(0.0, 1.0), y0=x, method=method, vectorized=True) +# +# return bunch[1] +# +# def sample(self, batch_shape: Shape) -> Tensor: +# z = self.base_distribution.sample(batch_shape) +# return self.inverse(z) +# +# def log_prob(self, x: Tensor, c: Tensor = None) -> Tensor: +# raise NotImplementedError(f"Keras does not yet support backend-agnostic Vector-Jacobian Products.") +# +# +# def hutchinson_trace(f: callable, x: Tensor) -> (Tensor, Tensor): +# # TODO: test this for all 3 backends +# noise = keras.random.normal(keras.ops.shape(x)) +# +# match keras.backend.backend(): +# case "jax": +# import jax +# fx, jvp = jax.jvp(f, (x,), (noise,)) +# case "tensorflow": +# import tensorflow as tf +# with tf.GradientTape(persistent=True) as tape: +# tape.watch(x) +# fx = f(x) +# jvp = tape.gradient(fx, x, output_gradients=noise) +# case "torch": +# import torch +# fx, jvp = torch.autograd.functional.jvp(f, x, noise, create_graph=True) +# case other: +# raise NotImplementedError(f"Backend {other} is not supported for trace estimation.") +# +# trace = keras.ops.sum(jvp * noise, axis=1) +# +# return fx, trace diff --git a/bayesflow/experimental/networks/inference_network.py b/bayesflow/experimental/networks/inference_network.py index ca9e4825..3f29a718 100644 --- a/bayesflow/experimental/networks/inference_network.py +++ b/bayesflow/experimental/networks/inference_network.py @@ -6,8 +6,8 @@ register_keras_serializable, ) -from bayesflow.experimental.distributions import find_distribution from bayesflow.experimental.types import Tensor +from bayesflow.experimental.utils import find_distribution @register_keras_serializable(package="bayesflow.networks") diff --git a/bayesflow/experimental/rectifiers.py b/bayesflow/experimental/rectifiers.py deleted file mode 100644 index be59aa73..00000000 --- a/bayesflow/experimental/rectifiers.py +++ /dev/null @@ -1,431 +0,0 @@ -# Copyright (c) 2022 The BayesFlow Developers - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, 
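Editor's note: sampling from a trained velocity field, both in the new `FlowMatching._inverse` and in the `rectifiers.py` module deleted below, amounts to integrating the ODE dx/dt = v(x, t) from t = 0 (noise) to t = 1 (data). `_inverse` delegates to scipy's adaptive `solve_ivp`; a fixed-step Euler loop like the deleted `_solve_euler` is the simplest alternative. A hedged, self-contained sketch with an invented velocity field:

```python
import numpy as np

def velocity(x, t):
    # invented stand-in for the trained network: drifts mass toward 3.0
    return 3.0 - x

def euler_sample(z, n_steps=100):
    x, dt = z.copy(), 1.0 / n_steps
    for k in range(n_steps):
        x = x + velocity(x, k * dt) * dt  # one explicit Euler step
    return x

z = np.random.default_rng(2).standard_normal((1000, 2))  # base samples
samples = euler_sample(z)  # moved toward 3.0 relative to the base samples
```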
diff --git a/bayesflow/experimental/rectifiers.py b/bayesflow/experimental/rectifiers.py
deleted file mode 100644
index be59aa73..00000000
--- a/bayesflow/experimental/rectifiers.py
+++ /dev/null
@@ -1,431 +0,0 @@
-# Copyright (c) 2022 The BayesFlow Developers
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-from functools import partial
-
-import tensorflow as tf
-import tensorflow_probability as tfp
-
-import bayesflow.default_settings as defaults
-from bayesflow.computational_utilities import compute_jacobian_trace
-from bayesflow.exceptions import SummaryStatsError
-from bayesflow.helper_networks import MCDropout
-from bayesflow.losses import mmd_summary_space
-
-
-class DriftNetwork(tf.keras.Model):
-    """Implements a learnable velocity field for a neural ODE. Will typically be used
-    in conjunction with a ``RectifyingFlow`` instance, as proposed by [1] in the context
-    of unconditional image generation.
-
-    [1] Liu, X., Gong, C., & Liu, Q. (2022).
-    Flow straight and fast: Learning to generate and transfer data with rectified flow.
-    arXiv preprint arXiv:2209.03003.
-    """
-
-    def __init__(
-        self, target_dim, num_dense=3, dense_args=None, dropout=True, mc_dropout=False, dropout_prob=0.05, **kwargs
-    ):
-        """Creates a learnable velocity field instance to be used in the context of rectifying
-        flows or neural ODEs.
-
-        [1] Liu, X., Gong, C., & Liu, Q. (2022).
-        Flow straight and fast: Learning to generate and transfer data with rectified flow.
-        arXiv preprint arXiv:2209.03003.
-
-        Parameters
-        ----------
-        target_dim : int
-            The problem dimensionality (e.g., in parameter estimation, the number of parameters)
-        num_dense : int, optional, default: 3
-            The number of hidden layers for the inner fully-connected network
-        dense_args : dict or None, optional, default: None
-            The arguments to be passed to ``tf.keras.layers.Dense`` constructor. If None, default settings
-            will be fetched from ``bayesflow.default_settings``.
-        dropout : bool, optional, default: True
-            Whether to use dropout in-between the hidden layers.
-        mc_dropout : bool, optional, default: False
-            Whether to use dropout Monte Carlo dropout (i.e., Bayesian approximation) during inference
-        dropout_prob : float in (0, 1), optional, default: 0.05
-            The dropout probability. Only has effecft if ``dropout=True`` or ``mc_dropout=True``
-        **kwargs : dict, optional, default: {}
-            Optional keyword arguments passed to the ``tf.keras.Model.__init__`` method.
-        """
-
-        super().__init__(**kwargs)
-
-        self.latent_dim = target_dim
-        if dense_args is None:
-            dense_args = defaults.DEFAULT_SETTING_DENSE_RECT
-        self.net = tf.keras.Sequential()
-        for _ in range(num_dense):
-            self.net.add(tf.keras.layers.Dense(**dense_args))
-            if mc_dropout:
-                self.net.add(MCDropout(dropout_prob))
-            elif dropout:
-                self.net.add(tf.keras.layers.Dropout(dropout_prob))
-            else:
-                pass
-        self.net.add(tf.keras.layers.Dense(self.latent_dim))
-        self.net.build(input_shape=())
-
-    def call(self, target_vars, latent_vars, time, condition, **kwargs):
-        """Performs a linear interpolation between target and latent variables
-        over time (i.e., a single ODE step during training).
-
-        Parameters
-        ----------
-        target_vars : tf.Tensor of shape (batch_size, ..., num_targets)
-            The variables of interest (e.g., parameters) over which we perform inference.
-        latent_vars : tf.Tensor of shape (batch_size, ..., num_targets)
-            The sampled random variates from the base distribution.
-        time : tf.Tensor of shape (batch_size, ..., 1)
-            A vector of time indices in (0, 1)
-        condition : tf.Tensor of shape (batch_size, ..., condition_dim)
-            The optional conditioning variables (e.g., as returned by a summary network)
-        **kwargs : dict, optional, default: {}
-            Optional keyword arguments passed to the ``tf.keras.Model`` call() method
-        """
-
-        diff = target_vars - latent_vars
-        wdiff = time * target_vars + (1 - time) * latent_vars
-        drift = self.drift(wdiff, time, condition, **kwargs)
-        return diff, drift
-
-    def drift(self, target_t, time, condition, **kwargs):
-        """Returns the drift at target_t time given optional condition(s).
-
-        Parameters
-        ----------
-        target_t : tf.Tensor of shape (batch_size, ..., num_targets)
-            The variables of interest (e.g., parameters) over which we perform inference.
-        time : tf.Tensor of shape (batch_size, ..., 1)
-            A vector of time indices in (0, 1)
-        condition : tf.Tensor of shape (batch_size, ..., condition_dim)
-            The optional conditioning variables (e.g., as returned by a summary network)
-        **kwargs : dict, optional, default: {}
-            Optional keyword arguments passed to the drift network.
-        """
-
-        if condition is not None:
-            inp = tf.concat([target_t, condition, time], axis=-1)
-        else:
-            inp = tf.concat([target_t, time], axis=-1)
-        return self.net(inp, **kwargs)
-
-
-class RectifiedDistribution(tf.keras.Model):
-    """Implements a rectifying flows according to [1]. To be used as an alternative
-    to a normalizing flow in a BayesFlow pipeline.
-
-    [1] Liu, X., Gong, C., & Liu, Q. (2022).
-    Flow straight and fast: Learning to generate and transfer data with rectified flow.
-    arXiv preprint arXiv:2209.03003.
-    """
-
-    def __init__(self, drift_net, summary_net=None, latent_dist=None, loss_fun=None, summary_loss_fun=None, **kwargs):
-        """Initializes a composite neural network to represent an amortized approximate posterior through
-        for a rectifying flow.
-
-        Parameters
-        ----------
-        drift_net : tf.keras.Model
-            A neural network for the velocity field (drift) of the learnable ODE
-        summary_net : tf.keras.Model or None, optional, default: None
-            An optional summary network to compress non-vector data structures.
-        latent_dist : callable or None, optional, default: None
-            The latent distribution towards which to optimize the networks. Defaults to
-            a multivariate unit Gaussian.
-        loss_fun : callable or None, optional, default: None
-            The loss function for "rectifying" the velocity field. If ``None``, defaults
-            to tf.keras.losses.logcosh. Sensible alternatives are MSE (as in [])
-        summary_loss_fun : callable, str, or None, optional, default: None
-            The loss function which accepts the outputs of the summary network. If ``None``, no loss is provided
-            and the summary space will not be shaped according to a known distribution (see [2]).
-            If ``summary_loss_fun='MMD'``, the default loss from [2] will be used.
-        **kwargs : dict, optional, default: {}
-            Additional keyword arguments passed to the ``__init__`` method of a ``tf.keras.Model`` instance.
-
-        Important
-        ----------
-        - If no ``summary_net`` is provided, then the output dictionary of your generative model should not contain
-        any ``summary_conditions``, i.e., ``summary_conditions`` should be set to ``None``, otherwise these will be ignored.
-        """
-
-        super().__init__(**kwargs)
-
-        self.drift_net = drift_net
-        self.summary_net = summary_net
-        self.latent_dim = drift_net.latent_dim
-        self.latent_dist = self._determine_latent_dist(latent_dist)
-        self.loss_fun = self._determine_loss(loss_fun)
-        self.summary_loss = self._determine_summary_loss(summary_loss_fun)
-
-    def call(self, input_dict, return_summary=False, num_eval_points=1, **kwargs):
-        """Performs a forward pass through the summary and drift network given an input dictionary.
-
-        Parameters
-        ----------
-        input_dict : dict
-            Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
-            ``targets`` - the latent model parameters over which a condition density is learned
-            ``summary_conditions`` - the conditioning variables (including data) that are first passed through a summary network
-            ``direct_conditions`` - the conditioning variables that the directly passed to the inference network
-        return_summary : bool, optional, default: False
-            A flag which determines whether the learnable data summaries (representations) are returned or not.
-        num_eval_points : int, optional, default: 1
-            The number of time points for evaluating the noisy estimator. Values larger than the default 1
-            may reduce the variance of the estimator, but may lead to increased memory demands, since an
-            additional dimension is added at axis 1 of all tensors.
-        **kwargs : dict, optional, default: {}
-            Additional keyword arguments passed to the networks
-            For instance, ``kwargs={'training': True}`` is passed automatically during training.
-
-        Returns
-        -------
-        net_out or (net_out, summary_out)
-        """
-
-        # Concatenate conditions, if given
-        summary_out, full_cond = self._compute_summary_condition(
-            input_dict.get(defaults.DEFAULT_KEYS["summary_conditions"]),
-            input_dict.get(defaults.DEFAULT_KEYS["direct_conditions"]),
-            **kwargs,
-        )
-
-        # Extract target variables
-        target_vars = input_dict[defaults.DEFAULT_KEYS["parameters"]]
-
-        # Extract batch size (autograph friendly)
-        batch_size = tf.shape(target_vars)[0]
-
-        # Sample latent variables
-        latent_vars = self.latent_dist.sample(batch_size)
-
-        # Do a little trick for less noisy estimator, if evals > 1
-        if num_eval_points > 1:
-            target_vars = tf.stack([target_vars] * num_eval_points, axis=1)
-            latent_vars = tf.stack([latent_vars] * num_eval_points, axis=1)
-            full_cond = tf.stack([full_cond] * num_eval_points, axis=1)
-            # Sample time
-            time = tf.random.uniform((batch_size, num_eval_points, 1))
-        else:
-            time = tf.random.uniform((batch_size, 1))
-
-        # Compute drift
-        net_out = self.drift_net(target_vars, latent_vars, time, full_cond, **kwargs)
-
-        # Return summary outputs or not, depending on parameter
-        if return_summary:
-            return net_out, summary_out
-        return net_out
-
-    def compute_loss(self, input_dict, **kwargs):
-        """Computes the loss of the posterior amortizer given an input dictionary, which will
-        typically be the output of a Bayesian ``GenerativeModel`` instance.
-
-        Parameters
-        ----------
-        input_dict : dict
-            Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
-            ``targets`` - the latent variables over which a condition density is learned
-            ``summary_conditions`` - the conditioning variables that are first passed through a summary network
-            ``direct_conditions`` - the conditioning variables that the directly passed to the inference network
-        **kwargs : dict, optional, default: {}
-            Additional keyword arguments passed to the networks
-            For instance, ``kwargs={'training': True}`` is passed automatically during training.
-
-        Returns
-        -------
-        total_loss : tf.Tensor of shape (1,)
-            the total computed loss given input variables
-        """
-
-        net_out, sum_out = self(input_dict, return_summary=True, **kwargs)
-        diff, drift = net_out
-        loss = self.loss_fun(diff, drift)
-
-        # Case summary loss should be computed
-        if self.summary_loss is not None:
-            sum_loss = self.summary_loss(sum_out)
-        # Case no summary loss, simply add 0 for convenience
-        else:
-            sum_loss = 0.0
-
-        # Compute and return total loss
-        total_loss = tf.reduce_mean(loss) + sum_loss
-        return total_loss
-
-    def sample(self, input_dict, n_samples, to_numpy=True, step_size=1e-3, **kwargs):
-        """Generates random draws from the approximate posterior given a dictionary with conditonal variables.
-
-        Parameters
-        ----------
-        input_dict : dict
-            Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
-            ``summary_conditions`` : the conditioning variables (including data) that are first passed through a summary network
-            ``direct_conditions`` : the conditioning variables that the directly passed to the inference network
-        n_samples : int
-            The number of posterior draws (samples) to obtain from the approximate posterior
-        to_numpy : bool, optional, default: True
-            Flag indicating whether to return the samples as a ``np.ndarray`` or a ``tf.Tensor``
-        step_size : float, optional, default: 0.01
-            The step size for the stochastic Euler solver.
-        **kwargs : dict, optional, default: {}
-            Additional keyword arguments passed to the networks
-
-        Returns
-        -------
-        post_samples : tf.Tensor or np.ndarray of shape (n_data_sets, n_samples, n_params)
-            The sampled parameters from the approximate posterior of each data set
-        """
-
-        # Compute condition (direct, summary, or both)
-        _, conditions = self._compute_summary_condition(
-            input_dict.get(defaults.DEFAULT_KEYS["summary_conditions"]),
-            input_dict.get(defaults.DEFAULT_KEYS["direct_conditions"]),
-            training=False,
-            **kwargs,
-        )
-        n_data_sets = tf.shape(conditions)[0]
-
-        # Sample initial latent variables -> shape (n_data_sets, n_samples, latent_dim)
-        latent_vars = self.latent_dist.sample((n_data_sets, n_samples))
-
-        # Replicate conditions and solve ODEs simulatenously
-        conditions = tf.stack([conditions] * n_samples, axis=1)
-        post_samples = self._solve_euler(latent_vars, conditions, step_size, **kwargs)
-
-        # Remove trailing first dimension in the single data case
-        if n_data_sets == 1:
-            post_samples = tf.squeeze(post_samples, axis=0)
-
-        # Return numpy version of tensor or tensor itself
-        if to_numpy:
-            return post_samples.numpy()
-        return post_samples
-
-    def log_density(self, input_dict, to_numpy=True, step_size=1e-3, **kwargs):
-        """Computes the log density..."""
-
-        # Compute condition (direct, summary, or both)
-        _, conditions = self._compute_summary_condition(
-            input_dict.get(defaults.DEFAULT_KEYS["summary_conditions"]),
-            input_dict.get(defaults.DEFAULT_KEYS["direct_conditions"]),
-            training=False,
-            **kwargs,
-        )
-
-        # Extract targets
-        target_vars = input_dict[defaults.DEFAULT_KEYS["parameters"]]
-
-        # Reverse ODE and log pdf computation with the trace method
-        latents, trace = self._solve_euler_inv(target_vars, conditions, step_size, **kwargs)
-        lpdf = self.latent_dist.log_prob(latents) + trace
-
-        # Return numpy version of tensor or tensor itself
-        if to_numpy:
-            return lpdf.numpy()
-        return lpdf
-
-    def _solve_euler(self, latent_vars, condition, dt=1e-3, **kwargs):
-        """Simple stochastic parallel Euler solver."""
-
-        num_steps = int(1 / dt)
-        time_vec = tf.zeros((tf.shape(latent_vars)[0], tf.shape(latent_vars)[1], 1))
-        target = tf.identity(latent_vars)
-        for _ in range(num_steps + 1):
-            target += self.drift_net.drift(target, time_vec, condition, **kwargs) * dt
-            time_vec += dt
-        return target
-
-    def _solve_euler_inv(self, targets, condition, dt=1e-3, **kwargs):
-        """Solves the reverse ODE (negative direction of drift) and returns the trace."""
-
-        def velocity(latents, drift, time_vec, condition, **kwargs):
-            v = drift(latents, time_vec, condition, **kwargs)
-            return v
-
-        batch_size = tf.shape(targets)[0]
-        num_samples = tf.shape(targets)[1]
-        num_steps = int(1 / dt)
-        time_vec = tf.ones((batch_size, num_samples, 1))
-        trace = tf.zeros((batch_size, num_samples))
-        latents = tf.identity(targets)
-        for _ in range(num_steps + 1):
-            f = partial(velocity, drift=self.drift_net.drift, time_vec=time_vec, condition=condition)
-            drift_t, trace_t = compute_jacobian_trace(f, latents, **kwargs)
-            latents -= drift_t * dt
-            trace -= trace_t * dt
-            time_vec -= dt
-        return latents, trace
-
-    def _compute_summary_condition(self, summary_conditions, direct_conditions, **kwargs):
-        """Determines how to concatenate the provided conditions."""
-
-        # Compute learnable summaries, if given
-        if self.summary_net is not None:
-            sum_condition = self.summary_net(summary_conditions, **kwargs)
-        else:
-            sum_condition = None
-
-        # Concatenate learnable summaries with fixed summaries
-        if sum_condition is not None and direct_conditions is not None:
-            full_cond = tf.concat([sum_condition, direct_conditions], axis=-1)
-        elif sum_condition is not None:
-            full_cond = sum_condition
-        elif direct_conditions is not None:
-            full_cond = direct_conditions
-        else:
-            raise SummaryStatsError("Could not concatenarte or determine conditioning inputs...")
-        return sum_condition, full_cond
-
-    def _determine_latent_dist(self, latent_dist):
-        """Determines which latent distribution to use and defaults to unit normal if ``None`` provided."""
-
-        if latent_dist is None:
-            return tfp.distributions.MultivariateNormalDiag(loc=[0.0] * self.latent_dim)
-        else:
-            return latent_dist
-
-    def _determine_summary_loss(self, loss_fun):
-        """Determines which summary loss to use if default `None` argument provided, otherwise return identity."""
-
-        # If callable, return provided loss
-        if loss_fun is None or callable(loss_fun):
-            return loss_fun
-
-        # If string, check for MMD or mmd
-        elif type(loss_fun) is str:
-            if loss_fun.lower() == "mmd":
-                return mmd_summary_space
-            else:
-                raise NotImplementedError("For now, only 'mmd' is supported as a string argument for summary_loss_fun!")
-        # Throw if loss type unexpected
-        else:
-            raise NotImplementedError(
-                "Could not infer summary_loss_fun, argument should be of type (None, callable, or str)!"
-            )
-
-    def _determine_loss(self, loss_fun):
-        """Determines which summary loss to use if default ``None`` argument provided, otherwise return identity."""
-
-        if loss_fun is None:
-            return tf.keras.losses.log_cosh
-        return loss_fun
diff --git a/bayesflow/experimental/types.py b/bayesflow/experimental/types.py
index bcceafd6..b58f9a08 100644
--- a/bayesflow/experimental/types.py
+++ b/bayesflow/experimental/types.py
@@ -1,7 +1,3 @@
-
-from typing import Protocol, runtime_checkable
-
-
 Shape = tuple[int, ...]
 
 # this is ugly, but:
@@ -17,12 +13,3 @@
 except ModuleNotFoundError:
     import torch
     Tensor: type(torch.Tensor) = torch.Tensor
-
-
-@runtime_checkable
-class Distribution(Protocol):
-    def sample(self, *args, **kwargs):
-        raise NotImplementedError
-
-    def log_prob(self, *args, **kwargs):
-        raise NotImplementedError
diff --git a/bayesflow/experimental/utils.py b/bayesflow/experimental/utils.py
deleted file mode 100644
index cfe7daaa..00000000
--- a/bayesflow/experimental/utils.py
+++ /dev/null
@@ -1,50 +0,0 @@
-
-import keras
-
-from bayesflow.experimental.types import Tensor
-
-
-def nested_getitem(data: dict, item: int) -> dict:
-    """ Get the item-th element from a nested dictionary """
-    result = {}
-    for key, value in data.items():
-        if isinstance(value, dict):
-            result[key] = nested_getitem(value, item)
-        else:
-            result[key] = value[item]
-    return result
-
-
-def nested_merge(a: dict, b: dict) -> dict:
-    """ Merge a nested dictionary A into another nested dictionary B """
-    for key, value in a.items():
-        if isinstance(value, dict):
-            b[key] = nested_merge(value, b.get(key, {}))
-        else:
-            b[key] = value
-    return b
-
-
-def apply_nested(fn: callable, data: dict) -> dict:
-    """ Apply a function to all non-dictionaries in a nested dictionary """
-    result = {}
-    for key, value in data.items():
-        if isinstance(value, dict):
-            result[key] = apply_nested(fn, value)
-        else:
-            # TODO: consuming version? this is not memory efficient
-            result[key] = fn(value)
-
-    return result
-
-
-def expand_left(tensor: Tensor, n: int) -> Tensor:
-    """ Expand a tensor to the left n times """
-    idx = [None] * n + [slice(None)] * keras.ops.ndim(tensor)
-    return tensor[idx]
-
-
-def expand_right(tensor: Tensor, n: int) -> Tensor:
-    """ Expand a tensor to the right n times """
-    idx = [slice(None)] * keras.ops.ndim(tensor) + [None] * n
-    return tensor[idx]
diff --git a/bayesflow/experimental/utils/__init__.py b/bayesflow/experimental/utils/__init__.py
new file mode 100644
index 00000000..d36419ba
--- /dev/null
+++ b/bayesflow/experimental/utils/__init__.py
@@ -0,0 +1,3 @@
+
+from .dictutils import nested_getitem
+from .finders import find_distribution, find_network
diff --git a/bayesflow/experimental/utils/dictutils.py b/bayesflow/experimental/utils/dictutils.py
new file mode 100644
index 00000000..605db752
--- /dev/null
+++ b/bayesflow/experimental/utils/dictutils.py
@@ -0,0 +1,10 @@
+
+def nested_getitem(data: dict, item: int) -> dict:
+    """ Get the item-th element from a nested dictionary """
+    result = {}
+    for key, value in data.items():
+        if isinstance(value, dict):
+            result[key] = nested_getitem(value, item)
+        else:
+            result[key] = value[item]
+    return result
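Editor's note: `nested_getitem` survives the utils split and now lives in `utils/dictutils.py`. A quick usage sketch with invented inputs, indexing one simulated batch element out of a nested dictionary of batched values:

```python
def nested_getitem(data: dict, item: int) -> dict:
    # same logic as bayesflow.experimental.utils.dictutils.nested_getitem
    result = {}
    for key, value in data.items():
        if isinstance(value, dict):
            result[key] = nested_getitem(value, item)
        else:
            result[key] = value[item]
    return result

batch = {
    "parameters": {"theta": [0.1, 0.2, 0.3]},
    "observables": {"x": [[1, 2], [3, 4], [5, 6]]},
}
print(nested_getitem(batch, 1))
# {'parameters': {'theta': 0.2}, 'observables': {'x': [3, 4]}}
```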
diff --git a/bayesflow/experimental/utils/finders.py b/bayesflow/experimental/utils/finders.py
new file mode 100644
index 00000000..9e97fa30
--- /dev/null
+++ b/bayesflow/experimental/utils/finders.py
@@ -0,0 +1,42 @@
+from typing import Callable
+
+import keras
+
+import bayesflow.experimental.distributions as D
+import bayesflow.experimental.networks as N
+
+
+def find_distribution(distribution: str | D.Distribution | Callable, **kwargs) -> D.Distribution:
+    match distribution:
+        case str() as name:
+            match name.lower():
+                case "normal":
+                    distribution = D.Normal()
+                case other:
+                    raise NotImplementedError(f"Unsupported distribution name: '{other}'.")
+        case D.Distribution() as distribution:
+            pass
+        case Callable() as constructor:
+            distribution = constructor(**kwargs)
+        case other:
+            raise TypeError(f"Cannot infer distribution from {other!r}.")
+
+    return distribution
+
+
+def find_network(network: str | keras.Layer | Callable, **kwargs) -> keras.Layer:
+    match network:
+        case str() as name:
+            match name.lower():
+                case "resnet":
+                    network = N.ResNet(**kwargs)
+                case other:
+                    raise NotImplementedError(f"Unsupported network name: '{other}'.")
+        case keras.Layer() as network:
+            pass
+        case Callable() as constructor:
+            network = constructor(**kwargs)
+        case other:
+            raise TypeError(f"Cannot infer network from {other!r}.")
+
+    return network
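Editor's note: the new finders dispatch on the argument type with `match`: a string selects a registered implementation, an existing instance passes through unchanged, and any other callable is treated as a factory. Illustrative calls (assuming this PR's package layout is importable; not part of the diff):

```python
import keras
from bayesflow.experimental.utils import find_network

net_by_name = find_network("resnet")                   # looked up by string
net_by_instance = find_network(keras.layers.Dense(8))  # passed through as-is
net_by_factory = find_network(lambda **kw: keras.layers.Dense(8, **kw))
```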
or numpy or tensorflow)"