Skip to content

Commit

Permalink
Merge pull request #203 from stefanradev93/distributions-student-t
Browse files Browse the repository at this point in the history
Implement DiagonalStudentT.sample method natively in keras
  • Loading branch information
paul-buerkner authored Oct 10, 2024
2 parents ee0e9f1 + 715c553 commit f27da60
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions bayesflow/distributions/diagonal_student_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import math
import numpy as np

from scipy.stats import t as scipy_student_t

from bayesflow.types import Shape, Tensor
from bayesflow.utils import expand_tile

from .distribution import Distribution


Expand All @@ -20,6 +21,7 @@ def __init__(
loc: int | float | np.ndarray | Tensor = 0.0,
scale: int | float | np.ndarray | Tensor = 1.0,
use_learnable_parameters: bool = False,
seed_generator: keras.random.SeedGenerator = None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -33,6 +35,11 @@ def __init__(

self.use_learnable_parameters = use_learnable_parameters

if seed_generator is None:
seed_generator = keras.random.SeedGenerator()

self.seed_generator = seed_generator

def build(self, input_shape: Shape) -> None:
self.dim = int(input_shape[-1])

Expand Down Expand Up @@ -78,9 +85,15 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
return result

def sample(self, batch_shape: Shape) -> Tensor:
# TODO: use reparameterization trick instead of scipy
# TODO: use the seed generator state
dist = scipy_student_t(df=self.df, loc=self.loc, scale=self.scale)
samples = dist.rvs(size=batch_shape + (self.dim,))
# As of writing this code, keras does not support the chi-square distribution
# nor does it support a scale or rate parameter in Gamma. Hence, we use the relation:
# chi-square(df) = Gamma(shape = 0.5 * df, scale = 2) = Gamma(shape = 0.5 * df, scale = 1) * 2
chi2_samples = keras.random.gamma(batch_shape, alpha=0.5 * self.df, seed=self.seed_generator) * 2.0

# The chi-quare samples need to be repeated across self.dim
# since for each element of batch_shape only one sample is created.
chi2_samples = expand_tile(chi2_samples, n=self.dim, axis=-1)

normal_samples = keras.random.normal(batch_shape + (self.dim,), seed=self.seed_generator)

return keras.ops.convert_to_tensor(samples)
return self.loc + self.scale * normal_samples * keras.ops.sqrt(self.df / chi2_samples)

0 comments on commit f27da60

Please sign in to comment.