Skip to content

Commit

Permalink
Fix sampling and get rid of tensorflow_probability for default Gaussians
Browse files Browse the repository at this point in the history
  • Loading branch information
Radev committed May 24, 2024
1 parent 08e5b31 commit f566250
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ def inverse(self, latents, conditions=None) -> (Tensor, Tensor):

return targets, log_det

def sample(self, batch_shape: Shape, conditions=None) -> Tensor:
def sample(self, batch_shape: Shape | int, conditions=None) -> Tensor:
if type(batch_shape) is int:
batch_shape = (batch_shape, )
latents = self.base_distribution.sample(batch_shape)
targets, _ = self.inverse(latents, conditions)

Expand Down
15 changes: 2 additions & 13 deletions bayesflow/experimental/simulation/distributions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,18 @@

import keras
from bayesflow.experimental.types import Distribution, Shape

from .joint_distribution import JointDistribution
from .spherical_gaussian import SphericalGaussian


def find_distribution(distribution: str | Distribution | type(Distribution), shape: Shape) -> Distribution:
if isinstance(distribution, Distribution):
return distribution
if isinstance(distribution, type):
return Distribution()

match distribution:
case "normal":
match keras.backend.backend():
case "jax" | "tensorflow":
import tensorflow as tf
import tensorflow_probability as tfp
distribution = tfp.distributions.Normal(tf.zeros(shape), tf.ones(shape))
distribution = tfp.distributions.Independent(distribution, 1)
case "torch":
import torch
import torch.distributions as D
distribution = D.Normal(torch.zeros(shape), torch.ones(shape))
distribution = D.Independent(distribution, 1)
distribution = SphericalGaussian(shape)
case str() as unknown_distribution:
raise ValueError(f"Distribution '{unknown_distribution}' is unknown or not yet supported by name.")
case other:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@

import math

import keras
from keras import ops

from bayesflow.experimental.types import Shape, Distribution, Tensor


class SphericalGaussian(Distribution):
"""Utility class for a backend-agnostic spherical Gaussian distribution.
Note:
- ``log_unnormalized_pdf`` method is used as a loss function
- ``log_pdf`` is used for density computation
"""
def __init__(self, shape: Shape):
self.shape = shape
self.dim = int(self.shape[0])
self._norm_const = 0.5 * self.dim * math.log(2.0 * math.pi)

def sample(self, batch_shape: Shape):
return keras.random.normal(shape=batch_shape + self.shape, mean=0.0, stddev=1.0)

def log_unnormalized_prob(self, tensor: Tensor):
return -0.5 * ops.sum(ops.square(tensor), axis=-1)

def log_prob(self, tensor: Tensor):
log_unnorm_pdf = self.log_unnormalized_prob(tensor)
return log_unnorm_pdf - self._norm_const

0 comments on commit f566250

Please sign in to comment.