
CIF Implementation #182

Merged
1 change: 1 addition & 0 deletions bayesflow/networks/__init__.py
@@ -1,3 +1,4 @@
from .cif import CIF
from .coupling_flow import CouplingFlow
from .deep_set import DeepSet
from .flow_matching import FlowMatching
1 change: 1 addition & 0 deletions bayesflow/networks/cif/__init__.py
@@ -0,0 +1 @@
from .cif import CIF
91 changes: 91 additions & 0 deletions bayesflow/networks/cif/cif.py
@@ -0,0 +1,91 @@
import keras
from keras.saving import register_keras_serializable
from ..inference_network import InferenceNetwork
from ..coupling_flow import CouplingFlow
from .conditional_gaussian import ConditionalGaussian


@register_keras_serializable(package="bayesflow.networks")
class CIF(InferenceNetwork):
"""Implements a continuously indexed flow (CIF) with a `CouplingFlow`
bijection and `ConditionalGaussian` distributions p and q. Improves on
eliminating leaky sampling found topologically in normalizing flows.
Bulit in reference to [1].

[1] R. Cornish, A. Caterini, G. Deligiannidis, & A. Doucet (2021).
Relaxing Bijectivity Constraints with Continuously Indexed Normalising
Flows.
arXiv:1909.13833.
"""

    def __init__(self, pq_depth=4, pq_width=128, pq_activation="tanh", **kwargs):
        """Creates an instance of a `CIF` with configurable
        `ConditionalGaussian` distributions p and q, each containing an MLP
        network.

        Parameters:
        -----------
        pq_depth: int, optional, default: 4
            The number of MLP hidden layers (minimum: 1)
        pq_width: int, optional, default: 128
            The dimensionality of the MLP hidden layers
        pq_activation: str, optional, default: 'tanh'
            The MLP activation function
        """

        super().__init__(base_distribution="normal", **kwargs)
        self.bijection = CouplingFlow()
        self.p_dist = ConditionalGaussian(depth=pq_depth, width=pq_width, activation=pq_activation)
        self.q_dist = ConditionalGaussian(depth=pq_depth, width=pq_width, activation=pq_activation)

    def build(self, xz_shape, conditions_shape=None):
        super().build(xz_shape)
        self.bijection.build(xz_shape, conditions_shape=conditions_shape)
        self.p_dist.build(xz_shape)
        self.q_dist.build(xz_shape)

    def call(self, xz, conditions=None, inverse=False, **kwargs):
        if inverse:
            return self._inverse(xz, conditions=conditions, **kwargs)
        return self._forward(xz, conditions=conditions, **kwargs)

    def _forward(self, x, conditions=None, density=False, **kwargs):
        # Sample u ~ q_u
        u, log_qu = self.q_dist.sample(x, log_prob=True)

        # Bijection and log jacobian x -> z
        z, log_jac = self.bijection(x, conditions=conditions, density=True)
        if log_jac.ndim > 1:
            log_jac = keras.ops.sum(log_jac, axis=1)

        # Log prob over p on u with conditions z
        log_pu = self.p_dist.log_prob(u, z)

        # Prior log prob
        log_prior = self.base_distribution.log_prob(z)
        if log_prior.ndim > 1:
            log_prior = keras.ops.sum(log_prior, axis=1)

        # ELBO loss
        elbo = log_jac + log_pu + log_prior - log_qu

        if density:
            return z, elbo
        return z

    def _inverse(self, z, conditions=None, density=False, **kwargs):
        # Sample u ~ p(u | z), then apply the inverse bijection z -> x
        u = self.p_dist.sample(z)
        x = self.bijection(z, conditions=conditions, inverse=True)
        if density:
            log_pu = self.p_dist.log_prob(u, x)
            return x, log_pu
        return x

    def compute_metrics(self, data, stage="training"):
        base_metrics = super().compute_metrics(data, stage=stage)
        inference_variables = data["inference_variables"]
        inference_conditions = data.get("inference_conditions")
        _, elbo = self(inference_variables, conditions=inference_conditions, inverse=False, density=True)
        loss = -keras.ops.mean(elbo)
        return base_metrics | {"loss": loss}
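For reference, the per-sample quantity accumulated in `_forward` (and negated and averaged into `loss` by `compute_metrics`) is the CIF evidence lower bound, written here in the notation of the code, with `f` the coupling-flow bijection and π the standard normal base distribution:

    ELBO(x) = log |det ∂f(x)/∂x| + log p(u | z) + log π(z) − log q(u | x),   where u ~ q(· | x) and z = f(x).

Maximizing the batch mean of this bound trains the bijection and both `ConditionalGaussian` networks jointly.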
73 changes: 73 additions & 0 deletions bayesflow/networks/cif/conditional_gaussian.py
@@ -0,0 +1,73 @@
import keras
from keras.saving import register_keras_serializable
import numpy as np
from ..mlp import MLP
from bayesflow.utils import keras_kwargs


@register_keras_serializable(package="bayesflow.networks.cif")
class ConditionalGaussian(keras.Layer):
"""Implements a conditional gaussian distribution with neural networks for
the means and standard deviations respectively. Bulit in reference to [1].

[1] R. Cornish, A. Caterini, G. Deligiannidis, & A. Doucet (2021).
Relaxing Bijectivity Constraints with Continuously Indexed Normalising
Flows.
arXiv:1909.13833.
"""

    def __init__(self, depth=4, width=128, activation="tanh", **kwargs):
        """Creates an instance of a `ConditionalGaussian` with configurable
        `MLP` networks for the means and standard deviations.

        Parameters:
        -----------
        depth: int, optional, default: 4
            The number of MLP hidden layers (minimum: 1)
        width: int, optional, default: 128
            The dimensionality of the MLP hidden layers
        activation: str, optional, default: "tanh"
            The MLP activation function
        """

        super().__init__(**keras_kwargs(kwargs))
        self.means = MLP(depth=depth, width=width, activation=activation)
        self.stds = MLP(depth=depth, width=width, activation=activation)

        # The number of output units is unknown at construction time; it is
        # set in build() to the dimensionality of the conditioned variable.
        self.output_projector = keras.layers.Dense(None)

    def build(self, input_shape):
        self.means.build(input_shape)
        self.stds.build(input_shape)
        self.output_projector.units = input_shape[-1]

    def _diagonal_gaussian_log_prob(self, conditions, means, stds):
        flat_c = keras.layers.Flatten()(conditions)
        flat_means = keras.layers.Flatten()(means)
        flat_vars = keras.layers.Flatten()(stds) ** 2

        dim = keras.ops.shape(flat_c)[1]

        const_term = -0.5 * dim * np.log(2 * np.pi)
        log_det_terms = -0.5 * keras.ops.sum(keras.ops.log(flat_vars), axis=1)
        product_terms = -0.5 * keras.ops.sum((flat_c - flat_means) ** 2 / flat_vars, axis=1)

        return const_term + log_det_terms + product_terms

    def log_prob(self, x, conditions):
        means = self.output_projector(self.means(conditions))
        stds = keras.ops.exp(self.output_projector(self.stds(conditions)))
        return self._diagonal_gaussian_log_prob(x, means, stds)

    def sample(self, conditions, log_prob=False):
        means = self.output_projector(self.means(conditions))
        stds = keras.ops.exp(self.output_projector(self.stds(conditions)))

        # Reparameterize
        samples = stds * keras.random.normal(keras.ops.shape(conditions)) + means

        # Log probability
        if log_prob:
            log_p = self._diagonal_gaussian_log_prob(samples, means, stds)
            return samples, log_p

        return samples
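As a quick illustration of the interface added here, the snippet below is a minimal sketch of how `ConditionalGaussian` might be exercised on its own. The import path is assumed from the file location above (it is not re-exported in `cif/__init__.py`), and the shapes and hyperparameters are placeholders; this has not been verified against a specific BayesFlow/Keras version.

```python
import keras

# Path assumed from the file layout of this PR.
from bayesflow.networks.cif.conditional_gaussian import ConditionalGaussian

# Placeholder conditioning inputs: a batch of 32 two-dimensional vectors.
conditions = keras.random.normal((32, 2))

gaussian = ConditionalGaussian(depth=2, width=64)
gaussian.build((32, 2))

# Reparameterized samples together with their log density under q(. | conditions).
samples, log_q = gaussian.sample(conditions, log_prob=True)

# Log density of arbitrary points under the same conditional Gaussian.
log_p = gaussian.log_prob(samples, conditions)
```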
258 changes: 258 additions & 0 deletions examples/moons_cif.ipynb

Large diffs are not rendered by default.
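Since the notebook diff is not rendered here, the following is a rough, hypothetical sketch of how the new `CIF` network could be exercised on toy data of the two-moons kind. The data, shapes, and hyperparameters are placeholders and are not taken from the notebook; only the `CIF` constructor, `build`, and `call` signatures shown in the diff above are relied on, and the top-level import follows the `__init__.py` change in this PR.

```python
import keras

from bayesflow.networks import CIF

# Placeholder data: 64 two-dimensional inference variables with 4-dimensional conditions.
x = keras.random.normal((64, 2))
conditions = keras.random.normal((64, 4))

cif = CIF(pq_depth=2, pq_width=64)
cif.build(xz_shape=(64, 2), conditions_shape=(64, 4))

# Forward pass: latent representation plus the per-sample ELBO used as the training objective.
z, elbo = cif(x, conditions=conditions, density=True)
loss = -keras.ops.mean(elbo)

# Inverse pass: map base-distribution samples back to data space.
x_generated = cif(keras.random.normal((64, 2)), conditions=conditions, inverse=True)
```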
