diff --git a/bayesflow/experimental/datasets/online_dataset.py b/bayesflow/experimental/datasets/online_dataset.py index a9b26eab..a4dd830b 100644 --- a/bayesflow/experimental/datasets/online_dataset.py +++ b/bayesflow/experimental/datasets/online_dataset.py @@ -1,7 +1,7 @@ import keras -from bayesflow.experimental.simulation.distributions import JointDistribution +from bayesflow.experimental.simulation import JointDistribution class OnlineDataset(keras.utils.PyDataset): diff --git a/bayesflow/experimental/datasets/rounds_dataset.py b/bayesflow/experimental/datasets/rounds_dataset.py index 2aac51ac..86694b65 100644 --- a/bayesflow/experimental/datasets/rounds_dataset.py +++ b/bayesflow/experimental/datasets/rounds_dataset.py @@ -1,7 +1,7 @@ import keras -from bayesflow.experimental.simulation.distributions.joint_distribution import JointDistribution +from bayesflow.experimental.simulation import JointDistribution class RoundsDataset(keras.utils.PyDataset): diff --git a/bayesflow/experimental/distributions/__init__.py b/bayesflow/experimental/distributions/__init__.py index 580d84e4..b05ff1b4 100644 --- a/bayesflow/experimental/distributions/__init__.py +++ b/bayesflow/experimental/distributions/__init__.py @@ -1,3 +1,20 @@ from .distribution import Distribution from .normal import Normal + + +def find_distribution(distribution: str | Distribution | type(Distribution)) -> Distribution: + if isinstance(distribution, Distribution): + return distribution + if isinstance(distribution, type): + return Distribution() + + match distribution: + case "normal": + distribution = Normal() + case str() as unknown_distribution: + raise ValueError(f"Distribution '{unknown_distribution}' is unknown or not yet supported by name.") + case other: + raise TypeError(f"Unknown distribution type: {other}") + + return distribution diff --git a/bayesflow/experimental/networks/coupling_flow/coupling_flow.py b/bayesflow/experimental/networks/coupling_flow/coupling_flow.py index 66fe2de4..9172101e 100644 --- a/bayesflow/experimental/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/experimental/networks/coupling_flow/coupling_flow.py @@ -1,25 +1,21 @@ -from typing import Sequence +from typing import Tuple, Union import keras from keras.saving import ( - deserialize_keras_object, register_keras_serializable, - serialize_keras_object, ) -from bayesflow.experimental.types import Shape, Tensor +from bayesflow.experimental.types import Tensor from .actnorm import ActNorm from .couplings import DualCoupling -from .invertible_layer import InvertibleLayer - -from ...simulation.distributions import find_distribution +from ..inference_network import InferenceNetwork @register_keras_serializable(package="bayesflow.networks") -class CouplingFlow(keras.Model, InvertibleLayer): +class CouplingFlow(InferenceNetwork): """ Implements a coupling flow as a sequence of dual couplings with permutations and activation - normalization. Incorporates ideas from [1-4]. + normalization. Incorporates ideas from [1-5]. [1] Kingma, D. P., & Dhariwal, P. (2018). Glow: Generative flow with invertible 1x1 convolutions. @@ -40,92 +36,55 @@ class CouplingFlow(keras.Model, InvertibleLayer): Robust model training and generalisation with Studentising flows. arXiv preprint arXiv:2006.06599. """ - def __init__(self, invertible_layers: Sequence[InvertibleLayer], base_distribution: str = "normal", **kwargs): - super().__init__(**kwargs) - self.invertible_layers = list(invertible_layers) - self.base_distribution = find_distribution(base_distribution) - - @classmethod - def new( - cls, - depth: int = 6, - subnet: str = "resnet", - transform: str = "affine", - base_distribution: str = "normal", - use_actnorm: bool = True, - **kwargs + def __init__( + self, + depth: int = 6, + subnet: str = "resnet", + transform: str = "affine", + use_actnorm: bool = True, **kwargs ): + super().__init__(**kwargs) - layers = [] - for i in range(depth): + self._layers = [] + for _ in range(depth): if use_actnorm: - layers.append(ActNorm()) - layers.append(DualCoupling.new(subnet, transform)) - - return cls(layers, base_distribution, **kwargs) - - @classmethod - def from_config(cls, config, custom_objects=None): - couplings = deserialize_keras_object(config.pop("invertible_layers")) - base_distribution = deserialize_keras_object(config.pop("base_distribution")) - - return cls(couplings, base_distribution, **config) - - def get_config(self): - base_config = super().get_config() - - config = { - "invertible_layers": serialize_keras_object(self.invertible_layers), - "base_distribution": serialize_keras_object(self.base_distribution), - } - - return base_config | config + self._layers.append(ActNorm()) + self._layers.append(DualCoupling(subnet, transform)) def build(self, input_shape): - # nothing to do here, since we do not know the conditions yet - self.base_distribution.build(input_shape) + super().build(input_shape) + self.call(keras.KerasTensor(input_shape)) - def call(self, x: Tensor, conditions: any = None, inverse: bool = False) -> (Tensor, Tensor): + def call(self, xz: Tensor, inverse: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]: if inverse: - return self._inverse(x, conditions) - return self._forward(x, conditions) + return self._inverse(xz, **kwargs) + return self._forward(xz, **kwargs) - def _forward(self, x: Tensor, conditions: any = None) -> (Tensor, Tensor): + def _forward(self, x: Tensor, jacobian: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]: z = x log_det = 0.0 - for layer in self.invertible_layers: - z, det = layer(z, conditions=conditions) + for layer in self._layers: + z, det = layer(z, inverse=False, **kwargs) log_det += det - return z, log_det + if jacobian: + return z, log_det + return z - def _inverse(self, z: Tensor, conditions: any = None) -> (Tensor, Tensor): + def _inverse(self, z: Tensor, jacobian: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]: x = z log_det = 0.0 - for layer in reversed(self.invertible_layers): - x, det = layer(x, conditions=conditions, inverse=True) + for layer in reversed(self._layers): + x, det = layer(x, inverse=True, **kwargs) log_det += det - return x, log_det - - def sample(self, batch_shape: Shape, conditions=None) -> Tensor: - z = self.base_distribution.sample(batch_shape) - x, _ = self(z, conditions, inverse=True) - + if jacobian: + return x, log_det return x - def log_prob(self, x: Tensor, conditions=None) -> Tensor: - z, log_det = self(x, conditions) - log_prob = self.base_distribution.log_prob(z) - - return log_prob + log_det - - def compute_loss(self, x=None, y=None, y_pred=None, **kwargs): - z, log_det = y_pred + def compute_loss(self, x: Tensor = None, **kwargs): + z, log_det = self(x, inverse=False, jacobian=True, **kwargs) log_prob = self.base_distribution.log_prob(z) nll = -keras.ops.mean(log_prob + log_det) return nll - - def compute_metrics(self, x, y, y_pred, **kwargs): - return {} diff --git a/bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py b/bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py index 0896a74a..f58b1d4e 100644 --- a/bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py +++ b/bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py @@ -1,9 +1,7 @@ import keras from keras.saving import ( - deserialize_keras_object, register_keras_serializable, - serialize_keras_object, ) from bayesflow.experimental.types import Tensor @@ -13,51 +11,15 @@ @register_keras_serializable(package="bayesflow.networks.coupling_flow") class DualCoupling(InvertibleLayer): - def __init__(self, coupling1: SingleCoupling, coupling2: SingleCoupling, pivot: int = None, **kwargs): - super().__init__(**kwargs) - self.coupling1 = coupling1 - self.coupling2 = coupling2 - self.pivot = pivot - - @classmethod - def new(cls, *args, **kwargs) -> "DualCoupling": - """ Construct a new DualCoupling from hyperparameters. """ - coupling1 = SingleCoupling.new(*args, **kwargs) - coupling2 = SingleCoupling.new(*args, **kwargs) - - return cls(coupling1, coupling2, **kwargs) - - @classmethod - def from_config(cls, config: dict, custom_objects=None) -> "DualCoupling": - coupling1 = deserialize_keras_object(config.pop("coupling1")) - coupling2 = deserialize_keras_object(config.pop("coupling2")) - pivot = config.pop("pivot") - - return cls(coupling1, coupling2, pivot=pivot, **config) - - def get_config(self) -> dict: - base_config = super().get_config() - - config = { - "coupling1": serialize_keras_object(self.coupling1), - "coupling2": serialize_keras_object(self.coupling2), - "pivot": self.pivot, - } - - return base_config | config + def __init__(self, subnet: str = "resnet", transform: str = "affine"): + super().__init__() + self.coupling1 = SingleCoupling(subnet, transform) + self.coupling2 = SingleCoupling(subnet, transform) + self.pivot = None def build(self, input_shape): self.pivot = input_shape[-1] // 2 - x1_shape = list(input_shape) - x2_shape = list(input_shape) - - x1_shape[-1] = self.pivot - x2_shape[-1] = input_shape[-1] - self.pivot - - self.coupling1.build((x1_shape, x2_shape)) - self.coupling2.build((x2_shape, x1_shape)) - def call(self, xz: Tensor, conditions: any = None, inverse: bool = False) -> (Tensor, Tensor): if inverse: return self._inverse(xz, conditions=conditions) diff --git a/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py b/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py index 13b0fea0..f40d92cf 100644 --- a/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py +++ b/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py @@ -1,57 +1,30 @@ import keras from keras.saving import ( - deserialize_keras_object, - register_keras_serializable, - serialize_keras_object + register_keras_serializable ) from bayesflow.experimental.types import Tensor from ..invertible_layer import InvertibleLayer from ..subnets import find_subnet -from ..transforms import find_transform, Transform +from ..transforms import find_transform @register_keras_serializable(package="bayesflow.networks.coupling_flow") class SingleCoupling(InvertibleLayer): - def __init__(self, subnet: keras.Layer, transform: Transform, **kwargs): - super().__init__(**kwargs) - self.subnet = subnet - self.transform = transform - - @classmethod - def new(cls, subnet: str = "resnet", transform: str = "affine", **kwargs) -> "SingleCoupling": - transform = find_transform(transform) - subnet = find_subnet(subnet) - - return cls(subnet, transform, **kwargs) - - @classmethod - def from_config(cls, config: dict, custom_objects=None) -> "SingleCoupling": - subnet = deserialize_keras_object(config.pop("subnet")) - transform = deserialize_keras_object(config.pop("transform")) - - return cls(subnet, transform, **config) + """ + Implements a single coupling layer as a composition of a subnet and a transform. - def get_config(self) -> dict: - base_config = super().get_config() - - config = { - "subnet": serialize_keras_object(self.subnet), - "transform": serialize_keras_object(self.transform), - } - - return base_config | config + Subnet output tensors are linearly mapped to the correct dimension. + """ + def __init__(self, subnet: str = "resnet", transform: str = "affine", **kwargs): + super().__init__(**kwargs) + self.dense = keras.layers.Dense(None) + self.subnet = find_subnet(subnet) + self.transform = find_transform(transform) def build(self, input_shape): - # TODO: this is not ideal... - if not hasattr(self.subnet, "build_output"): - return - - x1_shape, x2_shape = input_shape - x2_shape = list(x2_shape) - x2_shape[-1] = self.transform.params_per_dim * x2_shape[-1] - self.subnet.build_output(x2_shape) + self.dense.units = self.transform.params_per_dim * input_shape[-1] def call(self, x1: Tensor, x2: Tensor, conditions: any = None, inverse: bool = False) -> ((Tensor, Tensor), Tensor): if inverse: @@ -79,7 +52,7 @@ def get_parameters(self, x, conditions: any = None) -> dict[str, Tensor]: if keras.ops.is_tensor(conditions): x = keras.ops.concatenate([x, conditions], axis=-1) - parameters = self.subnet(x) + parameters = self.dense(self.subnet(x)) parameters = self.transform.split_parameters(parameters) parameters = self.transform.constrain_parameters(parameters) diff --git a/bayesflow/experimental/networks/coupling_flow/subnets/__init__.py b/bayesflow/experimental/networks/coupling_flow/subnets/__init__.py index ede5381f..4dbee6d6 100644 --- a/bayesflow/experimental/networks/coupling_flow/subnets/__init__.py +++ b/bayesflow/experimental/networks/coupling_flow/subnets/__init__.py @@ -12,9 +12,7 @@ def find_subnet(subnet: str | keras.Layer | Callable, **kwargs) -> keras.Layer: case str() as name: match name.lower(): case "resnet": - resnet = networks.ResNet.new(**kwargs) - resnet.output_layer.units = None - return resnet + return networks.ResNet(**kwargs) case other: raise NotImplementedError(f"Unsupported subnet name: '{other}'.") case keras.Layer() as layer: diff --git a/bayesflow/experimental/networks/flow_matching/flow_matching.py b/bayesflow/experimental/networks/flow_matching/flow_matching.py index 5a26da4e..71c81ded 100644 --- a/bayesflow/experimental/networks/flow_matching/flow_matching.py +++ b/bayesflow/experimental/networks/flow_matching/flow_matching.py @@ -42,14 +42,6 @@ def get_config(self) -> dict: def build(self, input_shape): self.network.build(input_shape) - def train_step(self, data): - # hack to avoid the call method in super().train_step() - # maybe you have a better idea? Seems the train_step is not backend-agnostic since it requires gradient tracking - call = self.call - self.call = lambda *args, **kwargs: None - super().train_step(data) - self.call = call - def _forward(self, x: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]: # implement conditions = None and jacobian = False first # then work your way up @@ -62,7 +54,17 @@ def compute_loss(self, x=None, **kwargs): # x should ideally contain both x0 and x1, # where the optimal transport matching already happened in the worker process # this is possible, but might not be super user-friendly. We will have to see. - raise NotImplementedError + x0, x1, t = x + + xt = t * x1 + (1 - t) * x0 + + # get velocity at xt + v = ... + + # target velocity: + vstar = x1 - x0 + + # return mse between v and vstar # TODO: see below for reference implementation diff --git a/bayesflow/experimental/networks/inference_network.py b/bayesflow/experimental/networks/inference_network.py index 44865d5f..ca9e4825 100644 --- a/bayesflow/experimental/networks/inference_network.py +++ b/bayesflow/experimental/networks/inference_network.py @@ -4,9 +4,9 @@ import keras from keras.saving import ( register_keras_serializable, - serialize_keras_object, ) +from bayesflow.experimental.distributions import find_distribution from bayesflow.experimental.types import Tensor @@ -14,14 +14,11 @@ class InferenceNetwork(keras.Model): def __init__(self, base_distribution: str = "normal", **kwargs): super().__init__(**kwargs) - # TODO: get the actual distribution object from the string representation - self.base_distribution = base_distribution + self.base_distribution = find_distribution(base_distribution) - def get_config(self) -> dict: - base_config = super().get_config() - # TODO: get the string representation of the distribution object - config = {"base_distribution": serialize_keras_object(self.base_distribution)} - return base_config | config + def build(self, input_shape): + super().build(input_shape) + self.base_distribution.build(input_shape) def call(self, xz: Tensor, inverse: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]: if inverse: @@ -42,3 +39,12 @@ def log_prob(self, x: Tensor, **kwargs) -> Tensor: samples, log_det = self(x, inverse=False, jacobian=True, **kwargs) log_prob = self.base_distribution.log_prob(samples) return log_prob + log_det + + def train_step(self, data): + # hack to avoid the call method in super().train_step() + call = self.call + self.call = lambda *args, **kwargs: None + rv = super().train_step(data) + self.call = call + + return rv diff --git a/bayesflow/experimental/networks/resnet/resnet.py b/bayesflow/experimental/networks/resnet/resnet.py index 504c6dc1..76be5405 100644 --- a/bayesflow/experimental/networks/resnet/resnet.py +++ b/bayesflow/experimental/networks/resnet/resnet.py @@ -1,11 +1,6 @@ - -from typing import Sequence - import keras from keras.saving import ( - deserialize_keras_object, register_keras_serializable, - serialize_keras_object, ) from bayesflow.experimental.types import Tensor @@ -15,55 +10,19 @@ class ResNet(keras.Layer): """ Implements a simple, fully-connected residual network. - Input tensors are linearly mapped to the width of the network to ensure shape-compatibility. + Input tensors are linearly mapped to the correct dimension. """ - def __init__(self, input_layer: keras.Layer, hidden_layers: Sequence[keras.Layer], output_layer: keras.Layer, **kwargs): + def __init__(self, depth: int = 6, width: int = 256, activation: str = "relu", **kwargs): super().__init__(**kwargs) - self.input_layer = input_layer - self.hidden_layers = list(hidden_layers) - self.output_layer = output_layer - - @classmethod - def new(cls, depth: int = 4, width: int = 256, activation: str = "relu", **kwargs) -> "ResNet": - input_layer = keras.layers.Dense(width) - hidden_layers = [keras.layers.Dense(width, activation=activation) for _ in range(depth)] - output_layer = keras.layers.Dense(width) - - return cls(input_layer, hidden_layers, output_layer, **kwargs) - - @classmethod - def from_config(cls, config: dict, custom_objects=None) -> "ResNet": - input_layer = deserialize_keras_object(config.pop("input_layer")) - hidden_layers = deserialize_keras_object(config.pop("hidden_layers")) - output_layer = deserialize_keras_object(config.pop("output_layer")) - - return cls(input_layer, hidden_layers, output_layer, **config) - - def get_config(self) -> dict: - base_config = super().get_config() - - config = { - "input_layer": serialize_keras_object(self.input_layer), - "hidden_layers": serialize_keras_object(self.hidden_layers), - "output_layer": serialize_keras_object(self.output_layer), - } - - return base_config | config - - def build(self, input_shape): - self.call(keras.KerasTensor(input_shape)) - def build_output(self, output_shape): - match self.output_layer: - case keras.layers.Dense() as dense: - dense.units = output_shape[-1] - case other: - raise NotImplementedError(f"Cannot build output for layer {other!r}") + self.input_layer = keras.layers.Dense(width) + self.hidden_layers = [keras.layers.Dense(width, activation=activation) for _ in range(depth - 1)] + self.output_layer = keras.layers.Dense(width, activation=activation) def call(self, x: Tensor) -> Tensor: x = self.input_layer(x) for layer in self.hidden_layers: x = x + layer(x) - x = self.output_layer(x) + x = x + self.output_layer(x) return x diff --git a/bayesflow/experimental/simulation/__init__.py b/bayesflow/experimental/simulation/__init__.py index cfb1d6e2..168bf8e4 100644 --- a/bayesflow/experimental/simulation/__init__.py +++ b/bayesflow/experimental/simulation/__init__.py @@ -1,3 +1,3 @@ from .decorators import DistributionDecorator as distribution -from .distributions import Distribution, find_distribution, JointDistribution +from .distributions import JointDistribution diff --git a/bayesflow/experimental/simulation/decorators/distribution_decorator.py b/bayesflow/experimental/simulation/decorators/distribution_decorator.py index 2296efa0..6737ebca 100644 --- a/bayesflow/experimental/simulation/decorators/distribution_decorator.py +++ b/bayesflow/experimental/simulation/decorators/distribution_decorator.py @@ -1,14 +1,9 @@ import functools -import inspect -import keras - -from functools import wraps -from bayesflow.experimental.types import Shape -from bayesflow.experimental.simulation.distributions import Distribution +import keras -from bayesflow.experimental import utils +from bayesflow.experimental.distributions import Distribution class DistributionDecorator: diff --git a/bayesflow/experimental/simulation/distributions/__init__.py b/bayesflow/experimental/simulation/distributions/__init__.py index 13bdf3d5..df2ac477 100644 --- a/bayesflow/experimental/simulation/distributions/__init__.py +++ b/bayesflow/experimental/simulation/distributions/__init__.py @@ -1,21 +1 @@ - -from bayesflow.experimental.types import Distribution, Shape - from .joint_distribution import JointDistribution -from .spherical_gaussian import SphericalGaussian - - -def find_distribution(distribution: str | Distribution | type(Distribution)) -> Distribution: - if isinstance(distribution, Distribution): - return distribution - if isinstance(distribution, type): - return Distribution() - match distribution: - case "normal": - distribution = SphericalGaussian() - case str() as unknown_distribution: - raise ValueError(f"Distribution '{unknown_distribution}' is unknown or not yet supported by name.") - case other: - raise TypeError(f"Unknown distribution type: {other}") - - return distribution diff --git a/bayesflow/experimental/simulation/distributions/joint_distribution.py b/bayesflow/experimental/simulation/distributions/joint_distribution.py index f6d64c02..77f5b8b9 100644 --- a/bayesflow/experimental/simulation/distributions/joint_distribution.py +++ b/bayesflow/experimental/simulation/distributions/joint_distribution.py @@ -4,8 +4,8 @@ #TODO - Make the distribution a layer! class JointDistribution: def __init__(self, prior, likelihood): - self.prior = prior - self.likelihood = likelihood + self.prior = utils.make_distribution(prior) + self.likelihood = utils.make_distribution(likelihood) def sample(self, batch_shape: Shape) -> dict: parameters = self.prior.sample(batch_shape) diff --git a/bayesflow/experimental/simulation/distributions/spherical_gaussian.py b/bayesflow/experimental/simulation/distributions/spherical_gaussian.py deleted file mode 100644 index 59415e6e..00000000 --- a/bayesflow/experimental/simulation/distributions/spherical_gaussian.py +++ /dev/null @@ -1,34 +0,0 @@ - -import math - -import keras -from keras import ops - -from bayesflow.experimental.types import Shape, Distribution, Tensor - - -@keras.saving.register_keras_serializable(package="bayesflow.simulation") -class SphericalGaussian(Distribution): - """Utility class for a backend-agnostic spherical Gaussian distribution. - - Note: - - ``log_unnormalized_prob`` method is used as a loss function - - ``log_prob`` is used for density computation - """ - def __init__(self): - self.dim = None - self.log_norm_const = None - - def sample(self, batch_shape: Shape): - return keras.random.normal(shape=batch_shape + (self.dim,), mean=0.0, stddev=1.0) - - def log_unnormalized_prob(self, tensor: Tensor): - return -0.5 * ops.sum(ops.square(tensor), axis=-1) - - def log_prob(self, tensor: Tensor): - log_unnorm_pdf = self.log_unnormalized_prob(tensor) - return log_unnorm_pdf - self.log_norm_const - - def build(self, input_shape): - self.dim = int(input_shape[-1]) - self.log_norm_const = 0.5 * self.dim * math.log(2.0 * math.pi) diff --git a/tests/test_keras.py b/tests/test_keras.py deleted file mode 100644 index 2b66bba9..00000000 --- a/tests/test_keras.py +++ /dev/null @@ -1,10 +0,0 @@ - -import pytest - - -def test_import(): - import keras - - -def test_py_dataset_exists(): - from keras.utils import PyDataset diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py index 60e1ace2..04e41b8c 100644 --- a/tests/test_networks/conftest.py +++ b/tests/test_networks/conftest.py @@ -11,16 +11,16 @@ def batch_size(request): @pytest.fixture() def coupling_flow(): from bayesflow.experimental.networks import CouplingFlow - return CouplingFlow.new() + return CouplingFlow() @pytest.fixture() def flow_matching(): from bayesflow.experimental.networks import FlowMatching - return FlowMatching.new() + return FlowMatching() -@pytest.fixture(params=["coupling_flow"]) +@pytest.fixture(params=["coupling_flow", "flow_matching"]) def inference_network(request): return request.getfixturevalue(request.param) @@ -37,7 +37,7 @@ def num_features(request): @pytest.fixture() def random_samples(batch_size, num_features): - return keras.random.normal() + return keras.random.normal((batch_size, num_features)) @pytest.fixture() @@ -48,7 +48,7 @@ def random_set(batch_size, set_size, num_features): @pytest.fixture() def resnet(): from bayesflow.experimental.networks import ResNet - return ResNet.new() + return ResNet() @pytest.fixture(params=[2, 3]) diff --git a/tests/test_networks/test_coupling_flow/conftest.py b/tests/test_networks/test_coupling_flow/conftest.py index 51d3d4f7..2a1a398b 100644 --- a/tests/test_networks/test_coupling_flow/conftest.py +++ b/tests/test_networks/test_coupling_flow/conftest.py @@ -17,7 +17,7 @@ def batch_size(request): @pytest.fixture() def dual_coupling(): from bayesflow.experimental.networks.coupling_flow.couplings import DualCoupling - return DualCoupling.new() + return DualCoupling() @pytest.fixture(params=["actnorm", "dual_coupling"]) @@ -38,4 +38,4 @@ def random_input(batch_size, num_features): @pytest.fixture() def single_coupling(): from bayesflow.experimental.networks.coupling_flow.couplings import SingleCoupling - return SingleCoupling.new() + return SingleCoupling() diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py index f9a6aaa5..0ecfcb63 100644 --- a/tests/test_networks/test_inference_networks.py +++ b/tests/test_networks/test_inference_networks.py @@ -35,8 +35,8 @@ def test_variable_batch_size(inference_network, random_samples): inference_network(new_input, inverse=True) -def test_output_structure(invertible_layer, random_input): - output = invertible_layer(random_input) +def test_output_structure(inference_network, random_input): + output = inference_network(random_input) assert isinstance(output, tuple) assert len(output) == 2 @@ -47,13 +47,13 @@ def test_output_structure(invertible_layer, random_input): assert keras.ops.is_tensor(forward_log_det) -def test_output_shape(invertible_layer, random_input): - forward_output, forward_log_det = invertible_layer(random_input) +def test_output_shape(inference_network, random_input): + forward_output, forward_log_det = inference_network(random_input) assert keras.ops.shape(forward_output) == keras.ops.shape(random_input) assert keras.ops.shape(forward_log_det) == (keras.ops.shape(random_input)[0],) - inverse_output, inverse_log_det = invertible_layer(random_input, inverse=True) + inverse_output, inverse_log_det = inference_network(random_input, inverse=True) assert keras.ops.shape(inverse_output) == keras.ops.shape(random_input) assert keras.ops.shape(inverse_log_det) == (keras.ops.shape(random_input)[0],) @@ -69,11 +69,11 @@ def test_cycle_consistency(inference_network, random_samples): @pytest.mark.torch -def test_jacobian_numerically(invertible_layer, random_input): +def test_jacobian_numerically(inference_network, random_input): import torch - forward_output, forward_log_det = invertible_layer(random_input, jacobian=True) - numerical_forward_jacobian, _ = torch.autograd.functional.jacobian(invertible_layer, random_input, vectorize=True) + forward_output, forward_log_det = inference_network(random_input, jacobian=True) + numerical_forward_jacobian, _ = torch.autograd.functional.jacobian(inference_network, random_input, vectorize=True) # TODO: torch is somehow permuted wrt keras numerical_forward_log_det = [keras.ops.log(keras.ops.abs(keras.ops.det(numerical_forward_jacobian[i, :, i, :]))) for i in range(keras.ops.shape(random_input)[0])] @@ -81,9 +81,9 @@ def test_jacobian_numerically(invertible_layer, random_input): assert keras.ops.all(keras.ops.isclose(forward_log_det, numerical_forward_log_det)) - inverse_output, inverse_log_det = invertible_layer(random_input, inverse=True, jacobian=True) + inverse_output, inverse_log_det = inference_network(random_input, inverse=True, jacobian=True) - numerical_inverse_jacobian, _ = torch.autograd.functional.jacobian(functools.partial(invertible_layer, inverse=True), random_input, vectorize=True) + numerical_inverse_jacobian, _ = torch.autograd.functional.jacobian(functools.partial(inference_network, inverse=True), random_input, vectorize=True) # TODO: torch is somehow permuted wrt keras numerical_inverse_log_det = [keras.ops.log(keras.ops.abs(keras.ops.det(numerical_inverse_jacobian[i, :, i, :]))) for i in range(keras.ops.shape(random_input)[0])]