From 8010296ac48e4294607ecbf0b2a054822c6b4dbf Mon Sep 17 00:00:00 2001
From: lars
Date: Wed, 5 Jun 2024 11:33:21 +0200
Subject: [PATCH 01/24] fix variable batch size test for conditions != None

---
 tests/test_networks/test_inference_networks.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py
index a615fb93..4340e758 100644
--- a/tests/test_networks/test_inference_networks.py
+++ b/tests/test_networks/test_inference_networks.py
@@ -25,14 +25,14 @@ def test_variable_batch_size(inference_network, random_samples, random_condition
 
     # run with another batch size
     batch_sizes = np.random.choice(10, replace=False, size=3)
-    for batch_size in batch_sizes:
-        new_input = keras.ops.zeros((batch_size,) + keras.ops.shape(random_samples)[1:])
+    for bs in batch_sizes:
+        new_input = keras.ops.zeros((bs,) + keras.ops.shape(random_samples)[1:])
 
         if random_conditions is None:
             new_conditions = None
         else:
-            new_conditions = keras.ops.zeros((batch_size,), + keras.ops.shape(random_conditions)[1:])
+            new_conditions = keras.ops.zeros((bs,) + keras.ops.shape(random_conditions)[1:])
 
-        inference_network(new_input)
+        inference_network(new_input, conditions=new_conditions)
         inference_network(new_input, conditions=new_conditions, inverse=True)

From cb554988af95f2cf932e8aaf76aa9f9c6100b2fb Mon Sep 17 00:00:00 2001
From: lars
Date: Wed, 5 Jun 2024 16:48:51 +0200
Subject: [PATCH 02/24] clean up amortizers

---
 bayesflow/experimental/amortizers/__init__.py |  3 ---
 .../amortizers/amortized_likelihood.py        | 23 -------------------
 .../amortizers/amortized_posterior.py         | 21 -----------------
 .../amortizers/point_amortizer.py             | 20 ----------------
 4 files changed, 67 deletions(-)
 delete mode 100644 bayesflow/experimental/amortizers/amortized_likelihood.py
 delete mode 100644 bayesflow/experimental/amortizers/amortized_posterior.py
 delete mode 100644 bayesflow/experimental/amortizers/point_amortizer.py

diff --git a/bayesflow/experimental/amortizers/__init__.py b/bayesflow/experimental/amortizers/__init__.py
index 9d1e36e7..6d997d8e 100644
--- a/bayesflow/experimental/amortizers/__init__.py
+++ b/bayesflow/experimental/amortizers/__init__.py
@@ -1,6 +1,3 @@
 
-from .amortized_likelihood import AmortizedLikelihood
-from .amortized_posterior import AmortizedPosterior
 from .amortizer import Amortizer
 from .joint_amortizer import JointAmortizer
-from .point_amortizer import AmortizedPointEstimator
diff --git a/bayesflow/experimental/amortizers/amortized_likelihood.py b/bayesflow/experimental/amortizers/amortized_likelihood.py
deleted file mode 100644
index 0ed8baa9..00000000
--- a/bayesflow/experimental/amortizers/amortized_likelihood.py
+++ /dev/null
@@ -1,23 +0,0 @@
-
-import keras
-
-from .amortizer import Amortizer
-
-
-class AmortizedLikelihood(Amortizer):
-    def configure_inferred_variables(self, data: dict):
-        return keras.ops.concatenate(list(data["observables"].values()), axis=-1)
-
-    def configure_observed_variables(self, data: dict):
-        # TODO: concatenate local context
-        return keras.ops.concatenate(list(data["parameters"].values()), axis=-1)
-
-    def configure_inference_conditions(self, data: dict, summary_outputs=None):
-        # TODO: concatenate global context
-        if summary_outputs is not None:
-            return summary_outputs
-
-        return self.configure_observed_variables(data)
-
-    def configure_summary_conditions(self, data: dict):
-        return None
diff --git a/bayesflow/experimental/amortizers/amortized_posterior.py b/bayesflow/experimental/amortizers/amortized_posterior.py
deleted file mode 100644
index 5fca58b1..00000000
--- a/bayesflow/experimental/amortizers/amortized_posterior.py
+++ /dev/null
@@ -1,21 +0,0 @@
-
-import keras
-
-from .amortizer import Amortizer
-
-
-class AmortizedPosterior(Amortizer):
-    def configure_inferred_variables(self, data: dict):
-        return keras.ops.concatenate(list(data["parameters"].values()), axis=-1)
-
-    def configure_observed_variables(self, data: dict):
-        return keras.ops.concatenate(list(data["observables"].values()), axis=-1)
-
-    def configure_inference_conditions(self, data: dict, summary_outputs=None):
-        if summary_outputs is None:
-            return self.configure_observed_variables(data)
-
-        return summary_outputs
-
-    def configure_summary_conditions(self, data: dict):
-        return None
diff --git a/bayesflow/experimental/amortizers/point_amortizer.py b/bayesflow/experimental/amortizers/point_amortizer.py
deleted file mode 100644
index f9ea5484..00000000
--- a/bayesflow/experimental/amortizers/point_amortizer.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import keras
-
-from .amortizer import Amortizer
-
-
-class AmortizedPointEstimator(Amortizer):
-    def configure_inferred_variables(self, data: dict):
-        return keras.ops.concatenate(list(data["parameters"].values()), axis=1)
-
-    def configure_observed_variables(self, data: dict):
-        return keras.ops.concatenate(list(data["observables"].values()), axis=1)
-
-    def configure_inference_conditions(self, data: dict, summary_outputs=None):
-        if summary_outputs is None:
-            return self.configure_observed_variables(data)
-
-        return summary_outputs
-
-    def configure_summary_conditions(self, data: dict):
-        return None

From b7670cc1ff2c71dd6abe30aaff783f43bda9984f Mon Sep 17 00:00:00 2001
From: lars
Date: Wed, 5 Jun 2024 16:49:23 +0200
Subject: [PATCH 03/24] remove unnecessary batches_per_epoch from datasets
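
The number of batches per epoch is now derived from the data itself instead
of a user-supplied value. A minimal usage sketch (the simulator and variable
names are illustrative, not part of this diff):

    data = simulator.sample((512,))           # dict of pre-simulated tensors
    dataset = OfflineDataset(data, batch_size=32)
    len(dataset)                              # math.ceil(512 / 32) == 16 batches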
""" # TODO: fix - def __init__(self, data: dict, batch_size: int, batches_per_epoch: int, **kwargs): + def __init__(self, data: dict, batch_size: int, **kwargs): super().__init__(**kwargs) self.batch_size = batch_size - self.batches_per_epoch = batches_per_epoch self.data = data - self.indices = keras.ops.arange(batch_size * batches_per_epoch) + self.indices = keras.ops.arange(len(data[next(iter(data.keys()))])) self.shuffle() @@ -27,7 +27,7 @@ def __getitem__(self, item: int) -> (dict, dict): return data, {} def __len__(self) -> int: - return self.batches_per_epoch + return math.ceil(len(self.indices) / self.batch_size) def on_epoch_end(self) -> None: self.shuffle() diff --git a/bayesflow/experimental/datasets/online_dataset.py b/bayesflow/experimental/datasets/online_dataset.py index a4dd830b..4a096b00 100644 --- a/bayesflow/experimental/datasets/online_dataset.py +++ b/bayesflow/experimental/datasets/online_dataset.py @@ -8,14 +8,14 @@ class OnlineDataset(keras.utils.PyDataset): """ A dataset that is generated on-the-fly. """ - def __init__(self, joint_distribution: JointDistribution, batch_size: int, **kwargs): + def __init__(self, distribution, batch_size: int, **kwargs): super().__init__(**kwargs) - self.joint_distribution = joint_distribution + self.distribution = distribution self.batch_size = batch_size def __getitem__(self, item: int) -> (dict, dict): """ Sample a batch of data from the joint distribution unconditionally """ - data = self.joint_distribution.sample((self.batch_size,)) + data = self.distribution.sample((self.batch_size,)) return data, {} @property From 988b6f294a0b29990af54dd830151b2d458095cc Mon Sep 17 00:00:00 2001 From: lars Date: Wed, 5 Jun 2024 16:51:19 +0200 Subject: [PATCH 04/24] add test_datasets for @rusty-electron --- tests/test_datasets/conftest.py | 91 ++++++++++++++++++---- tests/test_datasets/test_datasets.py | 35 +++++++++ tests/test_datasets/test_online_dataset.py | 2 - 3 files changed, 109 insertions(+), 19 deletions(-) create mode 100644 tests/test_datasets/test_datasets.py delete mode 100644 tests/test_datasets/test_online_dataset.py diff --git a/tests/test_datasets/conftest.py b/tests/test_datasets/conftest.py index f882b5a1..8be74c18 100644 --- a/tests/test_datasets/conftest.py +++ b/tests/test_datasets/conftest.py @@ -1,33 +1,90 @@ import keras -import keras.random import pytest -import bayesflow.experimental as bf + +@pytest.fixture() +def batch_size(): + return 16 -# TODO: do this last when the implementation of multiprocessing dataloading pipelines is done +@pytest.fixture(params=["online_dataset", "offline_dataset"]) +def dataset(request, online_dataset, offline_dataset): + return request.getfixturevalue(request.param) @pytest.fixture() -def joint_distribution(): - class JointDistribution: - def sample(self, batch_shape): - return dict(x=keras.random.normal(batch_shape + (2,))) +def model(): + class Model(keras.Model): + def call(self, *args, **kwargs): + pass + + def compute_loss(self, **kwargs): + return keras.ops.zeros(()) - def log_prob(self, x): - raise NotImplementedError + model = Model() + model.compile(optimizer=None) - return JointDistribution() + return model @pytest.fixture() -def online_dataset(joint_distribution): - return bf.datasets.OnlineDataset(joint_distribution, batch_size=1) +def offline_dataset(simulator, batch_size, workers, use_multiprocessing): + from bayesflow.experimental import OfflineDataset + + # TODO: there is a bug in keras where if len(dataset) == 1 batch + # fit will error because no logs are generated + # 
the single batch is then skipped entirely + data = simulator.sample((batch_size * 2,)) + return OfflineDataset(data, batch_size=batch_size, workers=workers, use_multiprocessing=use_multiprocessing) + + +@pytest.fixture() +def online_dataset(simulator, batch_size, workers, use_multiprocessing): + from bayesflow.experimental import OnlineDataset + + return OnlineDataset(simulator, batch_size=batch_size, workers=workers, use_multiprocessing=use_multiprocessing) + + +# needs to be global for pickle to work + +from bayesflow.experimental.simulation.decorators.distribution_decorator import DistributionDecorator as make_distribution + + +class Simulator: + def sample(self, batch_shape): + return dict(x=keras.random.normal(batch_shape + (2,))) + + +@make_distribution(is_batched=True) +def batched_decorated_simulator(batch_shape): + return dict(x=keras.random.normal(batch_shape + (2,))) + + +@make_distribution(is_batched=False) +def unbatched_decorated_simulator(): + return dict(x=keras.random.normal((2,))) + + +@pytest.fixture(params=["class", "batched_decorator", "unbatched_decorator"]) +def simulator(request): + if request.param == "class": + simulator = Simulator() + elif request.param == "batched_decorator": + simulator = batched_decorated_simulator + elif request.param == "unbatched_decorator": + simulator = unbatched_decorated_simulator + else: + raise NotImplementedError + + return simulator + + +@pytest.fixture(params=[True, False]) +def use_multiprocessing(request): + return request.param -# @pytest.fixture() -# def offline_dataset(tmp_path, joint_distribution): -# samples = joint_distribution.sample((32,)) -# ... # ? -# return bf.datasets.OfflineDataset(joint_distribution, batch_size=1, batches_per_epoch=None) +@pytest.fixture(params=[1, 2]) +def workers(request): + return request.param diff --git a/tests/test_datasets/test_datasets.py b/tests/test_datasets/test_datasets.py new file mode 100644 index 00000000..f1fe16bc --- /dev/null +++ b/tests/test_datasets/test_datasets.py @@ -0,0 +1,35 @@ + +import keras +import pickle +import pytest + + +@pytest.mark.skip(reason="WIP") +def test_dataset_is_picklable(dataset): + pickled = pickle.loads(pickle.dumps(dataset)) + + assert type(pickled) is type(dataset) + + samples = dataset[0] # tuple of (x, y) + samples = samples[0] # dict of {param_name: param_value} + samples = next(iter(samples.values())) # first param value + + pickled_samples = pickled[0] + pickled_samples = pickled_samples[0] + pickled_samples = next(iter(pickled_samples.values())) + + assert keras.ops.shape(samples) == keras.ops.shape(pickled_samples) + + +@pytest.mark.skip(reason="WIP") +def test_dataset_works_in_fit(model, dataset): + model.fit(dataset, epochs=1, steps_per_epoch=1) + + +@pytest.mark.skip(reason="WIP") +def test_dataset_returns_batch(dataset, batch_size): + samples = dataset[0] # tuple of (x, y) + samples = samples[0] # dict of {param_name: param_value} + samples = next(iter(samples.values())) # first param value + + assert keras.ops.shape(samples)[0] == batch_size diff --git a/tests/test_datasets/test_online_dataset.py b/tests/test_datasets/test_online_dataset.py deleted file mode 100644 index b6dbda7f..00000000 --- a/tests/test_datasets/test_online_dataset.py +++ /dev/null @@ -1,2 +0,0 @@ -# TODO - From e4d5b648184af19490ade67796dd17bfb4cfb232 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Thu, 6 Jun 2024 05:33:18 -0400 Subject: [PATCH 05/24] Semanticize, add docs, and fix logic --- .../experimental/amortizers/amortizer.py | 70 ++++++++++++++++--- 1 file 
changed, 59 insertions(+), 11 deletions(-) diff --git a/bayesflow/experimental/amortizers/amortizer.py b/bayesflow/experimental/amortizers/amortizer.py index 3b628628..5d6e44f5 100644 --- a/bayesflow/experimental/amortizers/amortizer.py +++ b/bayesflow/experimental/amortizers/amortizer.py @@ -1,5 +1,6 @@ import keras +from keras import ops from keras.saving import ( deserialize_keras_object, register_keras_serializable, @@ -13,27 +14,74 @@ @register_keras_serializable(package="bayesflow.amortizers") class Amortizer(BaseAmortizer): - def __init__(self, inferred_variables: list[str], observed_variables: list[str], inference_conditions: list[str] = None, summary_conditions: list[str] = None, **kwargs): + def __init__( + self, + inference_variables: list[str], + inference_conditions: list[str] = None, + summary_variables: list[str] = None, + summary_conditions: list[str] = None, + **kwargs + ): + """ The main workhorse for learning amortized neural approximators for distributions arising + in inverse problems and Bayesian inference (e.g., posterior distributions, likelihoods, marginal + likelihoods). + + The complete semantics of this class allow for flexible estimation of the following distribution: + + Q(inference_variables | H(summary_variables; summary_conditions), inference_conditions), + + where all quantities to the right of the "given" symbol | are optional and H refers to the optional + summary /embedding network used to compress high-dimensional data into lower-dimensional summary + vectors. Some examples are provided below. + + Parameters + ---------- + inference_variables: list[str] + A list of variable names indicating the quantities to be inferred / learned by the approximator, + e.g., model parameters when approximating the Bayesian posterior or observables when approximating + a likelihood density. + inference_conditions: list[str] + A list of variable names indicating quantities that will be used to condition (i.e., inform) the + distribution over inference variables directly, that is, without passing through the summary network. + summary_variables: list[str] + A list of variable names indicating quantities that will be used to condition (i.e., inform) the + distribution over inference variables after passing through the summary network (i.e., undergoing a + learnable transformation / dimensionality reduction). For instance, non-vector quantities (e.g., + sets or time-series) in posterior inference will typically qualify as summary variables. In addition, + these quantities may involve learnable distributions on their own. + summary_conditions: list[str] + A list of variable names indicating quantities that will be used to condition (i.e., inform) the + optional summary network, e.g., when the summary network accepts further conditions that do not + conform to the semantics of summary variable (i.e., need not be embedded or their distribution + needs not be learned). 
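
The approximator now targets Q(inference_variables | H(summary_variables;
summary_conditions), inference_conditions). A hypothetical posterior setup
under the new semantics (all variable names below are illustrative):

    approximator = Approximator(
        inference_network=inference_network,
        summary_network=summary_network,
        inference_variables=["theta"],        # quantities to infer
        summary_variables=["x"],              # embedded by the summary network H
        inference_conditions=["num_obs"],     # bypasses the summary network
    )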
---
 bayesflow/experimental/amortizers/amortizer.py | 70 ++++++++++++++++---
 1 file changed, 59 insertions(+), 11 deletions(-)

diff --git a/bayesflow/experimental/amortizers/amortizer.py b/bayesflow/experimental/amortizers/amortizer.py
index 3b628628..5d6e44f5 100644
--- a/bayesflow/experimental/amortizers/amortizer.py
+++ b/bayesflow/experimental/amortizers/amortizer.py
@@ -1,5 +1,6 @@
 
 import keras
+from keras import ops
 from keras.saving import (
     deserialize_keras_object,
     register_keras_serializable,
@@ -13,27 +14,74 @@
 
 @register_keras_serializable(package="bayesflow.amortizers")
 class Amortizer(BaseAmortizer):
-    def __init__(self, inferred_variables: list[str], observed_variables: list[str], inference_conditions: list[str] = None, summary_conditions: list[str] = None, **kwargs):
+    def __init__(
+        self,
+        inference_variables: list[str],
+        inference_conditions: list[str] = None,
+        summary_variables: list[str] = None,
+        summary_conditions: list[str] = None,
+        **kwargs
+    ):
+        """ The main workhorse for learning amortized neural approximators for distributions arising
+        in inverse problems and Bayesian inference (e.g., posterior distributions, likelihoods, marginal
+        likelihoods).
+
+        The complete semantics of this class allow for flexible estimation of the following distribution:
+
+        Q(inference_variables | H(summary_variables; summary_conditions), inference_conditions),
+
+        where all quantities to the right of the "given" symbol | are optional and H refers to the optional
+        summary / embedding network used to compress high-dimensional data into lower-dimensional summary
+        vectors. Some examples are provided below.
+
+        Parameters
+        ----------
+        inference_variables: list[str]
+            A list of variable names indicating the quantities to be inferred / learned by the approximator,
+            e.g., model parameters when approximating the Bayesian posterior or observables when approximating
+            a likelihood density.
+        inference_conditions: list[str]
+            A list of variable names indicating quantities that will be used to condition (i.e., inform) the
+            distribution over inference variables directly, that is, without passing through the summary network.
+        summary_variables: list[str]
+            A list of variable names indicating quantities that will be used to condition (i.e., inform) the
+            distribution over inference variables after passing through the summary network (i.e., undergoing a
+            learnable transformation / dimensionality reduction). For instance, non-vector quantities (e.g.,
+            sets or time-series) in posterior inference will typically qualify as summary variables. In addition,
+            these quantities may involve learnable distributions on their own.
+        summary_conditions: list[str]
+            A list of variable names indicating quantities that will be used to condition (i.e., inform) the
+            optional summary network, e.g., when the summary network accepts further conditions that do not
+            conform to the semantics of summary variables (i.e., they need not be embedded and their
+            distributions need not be learned).
+
+        #TODO add citations
+
+        Examples
+        --------
+        #TODO
+        """
+
         super().__init__(**kwargs)
 
-        self.inferred_variables = inferred_variables
-        self.observed_variables = observed_variables
+        self.inference_variables = inference_variables
         self.inference_conditions = inference_conditions or []
+        self.summary_variables = summary_variables or []
         self.summary_conditions = summary_conditions or []
 
-    def configure_inferred_variables(self, data: dict[str, Tensor]) -> Tensor:
-        return keras.ops.concatenate([data[key] for key in self.inferred_variables])
-
-    def configure_observed_variables(self, data: dict[str, Tensor]) -> Tensor:
-        return keras.ops.concatenate([data[key] for key in self.observed_variables])
+    def configure_inference_variables(self, data: dict[str, Tensor]) -> Tensor:
+        return ops.concatenate([data[key] for key in self.inference_variables], axis=-1)
 
     def configure_inference_conditions(self, data: dict[str, Tensor]) -> Tensor | None:
         if not self.inference_conditions:
             return None
+        return ops.concatenate([data[key] for key in self.inference_conditions], axis=-1)
 
-        return keras.ops.concatenate([data[key] for key in self.inference_conditions])
+    def configure_summary_variables(self, data: dict[str, Tensor]) -> Tensor | None:
+        if not self.summary_variables:
+            return None
+        return ops.concatenate([data[key] for key in self.summary_variables], axis=-1)
 
     def configure_summary_conditions(self, data: dict[str, Tensor]) -> Tensor | None:
         if not self.summary_conditions:
             return None
-
-        return keras.ops.concatenate([data[key] for key in self.summary_conditions])
+        return ops.concatenate([data[key] for key in self.summary_conditions], axis=-1)

From ed07fb3be53e5de353b505988dda08851c7c0c43 Mon Sep 17 00:00:00 2001
From: lars
Date: Thu, 6 Jun 2024 12:23:19 +0200
Subject: [PATCH 06/24] singledispatch methods as replacement for finders
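
Each finder is now a functools.singledispatch function whose overloads are
selected by argument type. A rough sketch of the intended call patterns
(the arguments shown are illustrative):

    find_network("resnet", **kwargs)   # str: resolve by name, forward kwargs
    find_network(ResNet)               # type: call the constructor
    find_network(my_layer)             # keras.Layer: returned unchanged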
---
 bayesflow/experimental/utils/__init__.py      |  6 +-
 .../experimental/utils/dispatch/__init__.py   |  2 +
 .../utils/dispatch/find_distribution.py       | 24 +++++++
 .../utils/dispatch/find_network.py            | 31 +++++++++
 .../utils/dispatch/find_pooling.py            | 38 +++++++++++
 bayesflow/experimental/utils/finders.py       | 65 -------------------
 6 files changed, 100 insertions(+), 66 deletions(-)
 create mode 100644 bayesflow/experimental/utils/dispatch/__init__.py
 create mode 100644 bayesflow/experimental/utils/dispatch/find_distribution.py
 create mode 100644 bayesflow/experimental/utils/dispatch/find_network.py
 create mode 100644 bayesflow/experimental/utils/dispatch/find_pooling.py
 delete mode 100644 bayesflow/experimental/utils/finders.py

diff --git a/bayesflow/experimental/utils/__init__.py b/bayesflow/experimental/utils/__init__.py
index cce43389..476da980 100644
--- a/bayesflow/experimental/utils/__init__.py
+++ b/bayesflow/experimental/utils/__init__.py
@@ -1,3 +1,7 @@
 
 from .dictutils import nested_getitem, keras_kwargs
-from .finders import find_distribution, find_network, find_pooling
+from .dispatch import (
+    find_distribution,
+    find_network,
+    find_pooling,
+)
diff --git a/bayesflow/experimental/utils/dispatch/__init__.py b/bayesflow/experimental/utils/dispatch/__init__.py
new file mode 100644
index 00000000..808622a5
--- /dev/null
+++ b/bayesflow/experimental/utils/dispatch/__init__.py
@@ -0,0 +1,2 @@
+
+from .find_network import find_network
diff --git a/bayesflow/experimental/utils/dispatch/find_distribution.py b/bayesflow/experimental/utils/dispatch/find_distribution.py
new file mode 100644
index 00000000..2c87a28a
--- /dev/null
+++ b/bayesflow/experimental/utils/dispatch/find_distribution.py
@@ -0,0 +1,24 @@
+
+from functools import singledispatch
+
+
+@singledispatch
+def find_distribution(arg, **kwargs):
+    raise TypeError(f"Cannot infer distribution from {arg!r}.")
+
+
+@find_distribution.register
+def _(name: str, **kwargs):
+    match name.lower():
+        case "normal":
+            from bayesflow.experimental.distributions import DiagonalNormal
+            distribution = DiagonalNormal(**kwargs)
+        case other:
+            raise ValueError(f"Unsupported distribution name '{other}'.")
+
+    return distribution
+
+
+@find_distribution.register
+def _(constructor: type, **kwargs):
+    return constructor(**kwargs)
diff --git a/bayesflow/experimental/utils/dispatch/find_network.py b/bayesflow/experimental/utils/dispatch/find_network.py
new file mode 100644
index 00000000..12d58952
--- /dev/null
+++ b/bayesflow/experimental/utils/dispatch/find_network.py
@@ -0,0 +1,31 @@
+
+import keras
+
+from functools import singledispatch
+
+
+@singledispatch
+def find_network(arg, **kwargs):
+    raise TypeError(f"Cannot infer network from {arg!r}.")
+
+
+@find_network.register
+def _(name: str, **kwargs):
+    match name.lower():
+        case "resnet":
+            from bayesflow.experimental.networks import ResNet
+            network = ResNet(**kwargs)
+        case other:
+            raise ValueError(f"Unsupported network name: '{other}'.")
+
+    return network
+
+
+@find_network.register
+def _(network: keras.Layer):
+    return network
+
+
+@find_network.register
+def _(constructor: type, **kwargs):
+    return constructor(**kwargs)
diff --git a/bayesflow/experimental/utils/dispatch/find_pooling.py b/bayesflow/experimental/utils/dispatch/find_pooling.py
new file mode 100644
index 00000000..398dbbb4
--- /dev/null
+++ b/bayesflow/experimental/utils/dispatch/find_pooling.py
@@ -0,0 +1,38 @@
+
+import keras
+
+from functools import singledispatch
+
+
+@singledispatch
+def find_pooling(arg, **kwargs):
+    raise TypeError(f"Cannot infer pooling from {arg!r}.")
+
+
+@find_pooling.register
+def _(name: str, **kwargs):
+    match name.lower():
+        case "mean" | "avg" | "average":
+            pooling = keras.layers.Lambda(lambda inp: keras.ops.mean(inp, axis=-2))
+        case "max":
+            pooling = keras.layers.Lambda(lambda inp: keras.ops.max(inp, axis=-2))
+        case "min":
+            pooling = keras.layers.Lambda(lambda inp: keras.ops.min(inp, axis=-2))
+        case "learnable" | "pma" | "attention":
+            from bayesflow.experimental.networks.set_transformer.pma import PoolingByMultiheadAttention
+            pooling = PoolingByMultiheadAttention(**kwargs)
+        case other:
+            raise ValueError(f"Unsupported pooling name: '{other}'.")
+
+    return pooling
+
+
+@find_pooling.register
+def _(constructor: type, **kwargs):
+    return constructor(**kwargs)
+
+
+@find_pooling.register
+def _(pooling: keras.Layer):
+    return pooling
+
diff --git a/bayesflow/experimental/utils/finders.py b/bayesflow/experimental/utils/finders.py
deleted file mode 100644
index ace72f4f..00000000
--- a/bayesflow/experimental/utils/finders.py
+++ /dev/null
@@ -1,65 +0,0 @@
-
-from functools import partial
-
-import keras
-
-
-def find_distribution(distribution: str | type, **kwargs):
-    # TODO -> return type
-    match distribution:
-        case str() as name:
-            match name.lower():
-                case "normal":
-                    from bayesflow.experimental.distributions import DiagonalNormal
-                    distribution = DiagonalNormal(**kwargs)
-                case other:
-                    raise ValueError(f"Unsupported distribution name: '{other}'.")
-        case type() as constructor:
-            distribution = constructor(**kwargs)
-        case other:
-            raise TypeError(f"Cannot infer distribution from {other!r}.")
-
-    return distribution
-
-
-def find_network(network: str | keras.Layer | type, **kwargs) -> keras.Layer:
-    match network:
-        case str() as name:
-            match name.lower():
-                case "resnet":
-                    from bayesflow.experimental.networks import ResNet
-                    network = ResNet(**kwargs)
-                case other:
-                    raise ValueError(f"Unsupported network name: '{other}'.")
-        case keras.Layer() as network:
-            pass
-        case type() as constructor:
-            network = constructor(**kwargs)
-        case other:
-            raise TypeError(f"Cannot infer network from {other!r}.")
-
-    return network
-
-
-def find_pooling(pooling: str | keras.Layer | type, **kwargs) -> keras.Layer:
-    match pooling:
-        case str() as name:
-            match name.lower():
-                case "mean" | "avg":
-                    pooling = keras.layers.Lambda(lambda inp: keras.ops.mean(inp, axis=-2))
-                case "max":
-                    pooling = keras.layers.Lambda(lambda inp: keras.ops.max(inp, axis=-2))
-                case "min":
-                    pooling = keras.layers.Lambda(lambda inp: keras.ops.min(inp, axis=-2))
-                case "learnable" | "pma":
-                    from bayesflow.experimental.networks.set_transformer.pma import PoolingByMultiheadAttention
-                    pooling = PoolingByMultiheadAttention(**kwargs)
-                case other:
-                    raise ValueError(f"Unsupported pooling type: '{other}'.")
-        case keras.Layer() as pooling:
-            pass
-        case type() as constructor:
-            pooling = constructor(**kwargs)
-        case other:
-            raise TypeError(f"Cannot infer pooling type from {other!r}.")
-    return pooling

From 34ec07d7730d8493563170a534671624c4bcc5dd Mon Sep 17 00:00:00 2001
From: stefanradev93
Date: Thu, 6 Jun 2024 07:18:51 -0400
Subject: [PATCH 07/24] Rename, semanticize, clean up, and ensure correct condition propagation

---
 bayesflow/experimental/amortizers/__init__.py |   3 -
 .../amortizers/joint_amortizer.py             |  31 ----
 .../experimental/approximators/__init__.py    |   3 +
 .../approximator.py}                          |  26 +++-
 .../base_approximator.py}                     | 147 ++++++++++--------
 .../approximators/joint_approximator.py       |  31 ++++
 6 files changed, 143 insertions(+), 98 deletions(-)
 delete mode 100644 bayesflow/experimental/amortizers/__init__.py
 delete mode 100644 bayesflow/experimental/amortizers/joint_amortizer.py
 create mode 100644 bayesflow/experimental/approximators/__init__.py
 rename bayesflow/experimental/{amortizers/amortizer.py => approximators/approximator.py} (84%)
 rename bayesflow/experimental/{amortizers/base_amortizer.py => approximators/base_approximator.py} (62%)
 create mode 100644 bayesflow/experimental/approximators/joint_approximator.py

diff --git a/bayesflow/experimental/amortizers/__init__.py b/bayesflow/experimental/amortizers/__init__.py
deleted file mode 100644
index 6d997d8e..00000000
--- a/bayesflow/experimental/amortizers/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-
-from .amortizer import Amortizer
-from .joint_amortizer import JointAmortizer
diff --git a/bayesflow/experimental/amortizers/joint_amortizer.py b/bayesflow/experimental/amortizers/joint_amortizer.py
deleted file mode 100644
index c7eecd69..00000000
--- a/bayesflow/experimental/amortizers/joint_amortizer.py
+++ /dev/null
@@ -1,31 +0,0 @@
-
-import keras
-
-from .amortizer import Amortizer
-
-
-class JointAmortizer(keras.Model):
-    def __init__(self, **amortizers: Amortizer):
-        super().__init__()
-        self.amortizers = amortizers
-
-    def build(self, input_shape):
-        for amortizer in self.amortizers.values():
-            amortizer.build(input_shape)
-
-    def call(self, *args, **kwargs):
-        return {name: amortizer(*args, **kwargs) for name, amortizer in self.amortizers.items()}
-
-    def compute_loss(self, *args, **kwargs):
-        losses = {name: amortizer.compute_loss(*args, **kwargs) for name, amortizer in self.amortizers.items()}
-        return keras.ops.mean(losses.values(), axis=0)
-
-    def compute_metrics(self, *args, **kwargs):
-        metrics = {}
-
-        for name, amortizer in self.amortizers.items():
-            m = amortizer.compute_metrics(*args, **kwargs)
-            m = {f"{name}/{key}": value for key, value in m.items()}
-            metrics |= m
-
-        return metrics
diff --git a/bayesflow/experimental/approximators/__init__.py b/bayesflow/experimental/approximators/__init__.py
new file mode 100644
index 00000000..1a28ff47
--- /dev/null
+++ b/bayesflow/experimental/approximators/__init__.py
@@ -0,0 +1,3 @@
+
+from .approximator import Approximator
+from .joint_approximator import JointApproximator
diff --git a/bayesflow/experimental/amortizers/amortizer.py b/bayesflow/experimental/approximators/approximator.py
similarity index 84%
rename from bayesflow/experimental/amortizers/amortizer.py
rename to bayesflow/experimental/approximators/approximator.py
index 5d6e44f5..52fdea85 100644
--- a/bayesflow/experimental/amortizers/amortizer.py
+++ b/bayesflow/experimental/approximators/approximator.py
@@ -9,11 +9,11 @@
 from bayesflow.experimental.types import Tensor
 
-from .base_amortizer import BaseAmortizer
+from .base_approximator import BaseApproximator
 
 
-@register_keras_serializable(package="bayesflow.amortizers")
-class Amortizer(BaseAmortizer):
+@register_keras_serializable(package="bayesflow.approximators")
+class Approximator(BaseApproximator):
     def __init__(
         self,
         inference_variables: list[str],
@@ -30,6 +30,8 @@ def __init__(
 
         Q(inference_variables | H(summary_variables; summary_conditions), inference_conditions),
 
+        #TODO - math notation
+
         where all quantities to the right of the "given" symbol | are optional and H refers to the optional
         summary / embedding network used to compress high-dimensional data into lower-dimensional summary
         vectors. Some examples are provided below.
@@ -68,6 +70,24 @@ def __init__(
         self.summary_variables = summary_variables or []
         self.summary_conditions = summary_conditions or []
 
+    def configure_full_conditions(
+        self,
+        summary_outputs: Tensor | None,
+        inference_conditions: Tensor | None,
+    ) -> Tensor:
+        """
+        Combine the (optional) inference conditions with the (optional) outputs
+        of the (optional) summary network.
+        """
+
+        if summary_outputs is None:
+            return inference_conditions
+        if inference_conditions is None:
+            return summary_outputs
+        return keras.ops.concatenate(
+            (summary_outputs, inference_conditions), axis=-1
+        )
+
     def configure_inference_variables(self, data: dict[str, Tensor]) -> Tensor:
         return ops.concatenate([data[key] for key in self.inference_variables], axis=-1)
 
diff --git a/bayesflow/experimental/amortizers/base_amortizer.py b/bayesflow/experimental/approximators/base_approximator.py
similarity index 62%
rename from bayesflow/experimental/amortizers/base_amortizer.py
rename to bayesflow/experimental/approximators/base_approximator.py
index ed4feacd..70711a90 100644
--- a/bayesflow/experimental/amortizers/base_amortizer.py
+++ b/bayesflow/experimental/approximators/base_approximator.py
@@ -10,37 +10,25 @@
 from bayesflow.experimental.networks import InferenceNetwork, SummaryNetwork
 
 
-class BaseAmortizer(keras.Model):
-    def __init__(self, inference_network: InferenceNetwork, summary_network: SummaryNetwork = None, **kwargs):
+class BaseApproximator(keras.Model):
+    def __init__(
+        self,
+        inference_network: InferenceNetwork,
+        summary_network: SummaryNetwork = None,
+        **kwargs
+    ):
+
         super().__init__(**kwargs)
 
         self.inference_network = inference_network
         self.summary_network = summary_network
 
     def sample(self, num_samples: int, **kwargs) -> dict[str, Tensor]:
         # TODO
-        return self.inference_network.sample(num_samples, **kwargs)
+        return {}
 
     def log_prob(self, samples: dict[str, Tensor], **kwargs) -> Tensor:
         # TODO
-        samples = self.configure_inferred_variables(samples)
-        return self.inference_network.log_prob(samples, **kwargs)
-
-    @classmethod
-    def from_config(cls, config: dict, custom_objects=None) -> "BaseAmortizer":
-        inference_network = deserialize_keras_object(config.pop("inference_network"), custom_objects=custom_objects)
-        summary_network = deserialize_keras_object(config.pop("summary_network"), custom_objects=custom_objects)
-
-        return cls(inference_network, summary_network, **config)
-
-    def get_config(self):
-        base_config = super().get_config()
-
-        config = {
-            "inference_network": serialize_keras_object(self.inference_network),
-            "summary_network": serialize_keras_object(self.summary_network),
-        }
-
-        return base_config | config
+        return {}
 
     def call(self, *, training=False, **data):
         if not training:
             return None
 
     def compute_loss(self, **data):
-        inferred_variables = self.configure_inferred_variables(data)
-        observed_variables = self.configure_observed_variables(data)
+
+        # Configure dict outputs into tensors
+        inference_variables = self.configure_inference_variables(data)
         inference_conditions = self.configure_inference_conditions(data)
+        summary_variables = self.configure_summary_variables(data)
         summary_conditions = self.configure_summary_conditions(data)
 
+        # Obtain summary outputs and summary loss (if present)
         if self.summary_network:
-            summary_loss = self.summary_network.compute_loss(
-                observed_variables=observed_variables,
-                summary_conditions=summary_conditions,
-            )
+            summary_outputs = self.summary_network(summary_variables, summary_conditions)
+            summary_loss = self.summary_network.compute_loss(summary_outputs)
         else:
+            summary_outputs = None
             summary_loss = keras.ops.zeros(())
 
+        # Combine summary outputs and inference conditions
+        full_conditions = self.configure_full_conditions(summary_outputs, inference_conditions)
+
+        # Compute inference loss
         inference_loss = self.inference_network.compute_loss(
-            inferred_variables=inferred_variables,
-            inference_conditions=inference_conditions,
+            targets=inference_variables,
+            conditions=full_conditions,
         )
 
         return inference_loss + summary_loss
 
     def compute_metrics(self, **data):
+        #TODO
         base_metrics = super().compute_metrics(**data)
+        return base_metrics
+
+        # inference_variables = self.configure_inference_variables(data)
+        # inference_conditions = self.configure_inference_conditions(data)
+        # summary_variables = self.configure_summary_variables(data)
+        # summary_conditions = self.configure_summary_conditions(data)
+        #
+        # if self.summary_network:
+        #     summary_metrics = self.summary_network.compute_metrics(
+        #         summary_variables=summary_variables,
+        #         summary_conditions=summary_conditions,
+        #     )
+        # else:
+        #     summary_metrics = {}
+        #
+        # inference_metrics = self.inference_network.compute_metrics(
+        #     inference_variables=inference_variables,
+        #     conditions=conditions,
+        # )
+        #
+        # summary_metrics = {f"summary/{key}": value for key, value in summary_metrics.items()}
+        # inference_metrics = {f"inference/{key}": value for key, value in inference_metrics.items()}
 
-        inferred_variables = self.configure_inferred_variables(data)
-        observed_variables = self.configure_observed_variables(data)
-        summary_conditions = self.configure_summary_conditions(data)
-        inference_conditions = self.configure_inference_conditions(data)
-
-        if self.summary_network:
-            summary_metrics = self.summary_network.compute_metrics(
-                observed_variables=observed_variables,
-                summary_conditions=summary_conditions,
-            )
-        else:
-            summary_metrics = {}
-
-        inference_metrics = self.inference_network.compute_metrics(
-            inferred_variables=inferred_variables,
-            inference_conditions=inference_conditions,
-        )
-
-        summary_metrics = {f"summary/{key}": value for key, value in summary_metrics.items()}
-        inference_metrics = {f"inference/{key}": value for key, value in inference_metrics.items()}
-
-        return base_metrics | inference_metrics | summary_metrics
+        # return base_metrics | inference_metrics | summary_metrics
+
+    def configure_full_conditions(
+        self,
+        summary_outputs: Tensor | None,
+        inference_conditions: Tensor | None,
+    ) -> Tensor:
+        """
+        Combine the (optional) inference conditions with the (optional) outputs
+        of the (optional) summary network.
+        """
+
+        raise NotImplementedError
 
-    def configure_inferred_variables(self, data: dict) -> any:
+    def configure_inference_variables(self, data: dict) -> any:
         """
         Return the inferred variables, given the data.
         Inferred variables are passed as input to the inference network.
 
         This method must be efficient and deterministic.
         Best practice is to prepare the output in dataset.__getitem__,
         which is run in a worker process, and then simply fetch a key from the data dictionary here.
         """
         raise NotImplementedError
 
-    def configure_observed_variables(self, data: dict) -> any:
+    def configure_inference_conditions(self, data: dict) -> any:
         """
-        Return the observed variables, given the data.
-        Observed variables are passed as input to the summary and/or inference networks.
+        Return the inference conditions, given the data.
+        Inference conditions are passed as conditional input to the inference network.
+
+        If summary outputs are provided, they should be concatenated to the return value.
 
         This method must be efficient and deterministic.
         Best practice is to prepare the output in dataset.__getitem__,
         which is run in a worker process, and then simply fetch a key from the data dictionary here.
         """
         raise NotImplementedError
 
-    def configure_inference_conditions(self, data: dict) -> any:
+    def configure_summary_variables(self, data: dict) -> any:
         """
-        Return the inference conditions, given the data.
-        Inference conditions are passed as conditional input to the inference network.
-
-        If summary outputs are provided, they should be concatenated to the return value.
+        Return the observed variables, given the data.
+        Observed variables are passed as input to the summary and/or inference networks.
 
         This method must be efficient and deterministic.
         Best practice is to prepare the output in dataset.__getitem__,
         which is run in a worker process, and then simply fetch a key from the data dictionary here.
         """
         raise NotImplementedError
 
     def configure_summary_conditions(self, data: dict) -> any:
         """
         Return the summary conditions, given the data.
         Summary conditions are passed as conditional input to the summary network.
 
         This method must be efficient and deterministic.
         Best practice is to prepare the output in dataset.__getitem__,
         which is run in a worker process, and then simply fetch a key from the data dictionary here.
         """
         raise NotImplementedError
+
+    @classmethod
+    def from_config(cls, config: dict, custom_objects=None) -> "BaseApproximator":
+        inference_network = deserialize_keras_object(config.pop("inference_network"), custom_objects=custom_objects)
+        summary_network = deserialize_keras_object(config.pop("summary_network"), custom_objects=custom_objects)
+
+        return cls(inference_network, summary_network, **config)
+
+    def get_config(self):
+        base_config = super().get_config()
+
+        config = {
+            "inference_network": serialize_keras_object(self.inference_network),
+            "summary_network": serialize_keras_object(self.summary_network),
+        }
+
+        return base_config | config
diff --git a/bayesflow/experimental/approximators/joint_approximator.py b/bayesflow/experimental/approximators/joint_approximator.py
new file mode 100644
index 00000000..034b155a
--- /dev/null
+++ b/bayesflow/experimental/approximators/joint_approximator.py
@@ -0,0 +1,31 @@
+
+import keras
+
+from .approximator import Approximator
+
+
+class JointApproximator(keras.Model):
+    def __init__(self, **approximators: Approximator):
+        super().__init__()
+        self.approximators = approximators
+
+    def build(self, input_shape):
+        for approximator in self.approximators.values():
+            approximator.build(input_shape)
+
+    def call(self, *args, **kwargs):
+        return {name: approximator(*args, **kwargs) for name, approximator in self.approximators.items()}
+
+    def compute_loss(self, *args, **kwargs):
+        losses = {name: approximator.compute_loss(*args, **kwargs) for name, approximator in self.approximators.items()}
+        return keras.ops.mean(keras.ops.stack(list(losses.values())), axis=0)
+
+    def compute_metrics(self, *args, **kwargs):
+        metrics = {}
+
+        for name, approximator in self.approximators.items():
+            m = approximator.compute_metrics(*args, **kwargs)
+            m = {f"{name}/{key}": value for key, value in m.items()}
+            metrics |= m
+
+        return metrics

From 797e1e7514837e440b2728a9e1f91e2b41f841b2 Mon Sep 17 00:00:00 2001
From: stefanradev93
Date: Thu, 6 Jun 2024 07:19:30 -0400
Subject: [PATCH 08/24] Small semantic changes

---
 .../networks/coupling_flow/actnorm.py          |  4 +---
 .../networks/coupling_flow/coupling_flow.py    | 17 +++++++++++++++--
 .../experimental/networks/inference_network.py |  4 ++--
 .../experimental/networks/summary_network.py   |  4 ++--
 4 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/bayesflow/experimental/networks/coupling_flow/actnorm.py b/bayesflow/experimental/networks/coupling_flow/actnorm.py
index 391d5f08..7a727add 100644
--- a/bayesflow/experimental/networks/coupling_flow/actnorm.py
+++ b/bayesflow/experimental/networks/coupling_flow/actnorm.py
@@ -1,8 +1,6 @@
 
 from keras import ops
-from keras.saving import (
-    register_keras_serializable,
-)
+from keras.saving import register_keras_serializable
 
 from bayesflow.experimental.types import Tensor
 from .invertible_layer import InvertibleLayer
diff --git a/bayesflow/experimental/networks/coupling_flow/coupling_flow.py b/bayesflow/experimental/networks/coupling_flow/coupling_flow.py
index 4ab9e1fc..8f64d1f6 100644
--- a/bayesflow/experimental/networks/coupling_flow/coupling_flow.py
+++ b/bayesflow/experimental/networks/coupling_flow/coupling_flow.py
@@ -69,12 +69,25 @@ def build(self, xz_shape, conditions_shape=None):
         else:
             self.call(keras.KerasTensor(xz_shape), conditions=keras.KerasTensor(conditions_shape))
 
-    def call(self, xz: Tensor, conditions: Tensor = None, inverse: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+    def call(
+        self,
+        xz: Tensor,
+        conditions: Tensor = None,
+        inverse: bool = False, **kwargs
+    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+
         if inverse:
             return self._inverse(xz, conditions=conditions, **kwargs)
         return self._forward(xz, conditions=conditions, **kwargs)
 
-    def _forward(self, x: Tensor, conditions: Tensor = None, jacobian: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+    def _forward(
+        self,
+        x: Tensor,
+        conditions: Tensor = None,
+        jacobian: bool = False,
+        **kwargs
+    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+
         z = x
         log_det = keras.ops.zeros(keras.ops.shape(x)[:-1])
         for layer in self._layers:
diff --git a/bayesflow/experimental/networks/inference_network.py b/bayesflow/experimental/networks/inference_network.py
index c178e38f..1fa10be5 100644
--- a/bayesflow/experimental/networks/inference_network.py
+++ b/bayesflow/experimental/networks/inference_network.py
@@ -36,8 +36,8 @@ def log_prob(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         log_prob = self.base_distribution.log_prob(samples)
         return log_prob + log_det
 
-    def compute_loss(self, inferred_variables: Tensor, inference_conditions: Tensor = None, **kwargs) -> Tensor:
+    def compute_loss(self, targets: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         raise NotImplementedError
 
-    def compute_metrics(self, inferred_variables: Tensor, inference_conditions: Tensor = None, **kwargs) -> dict:
+    def compute_metrics(self, targets: Tensor, conditions: Tensor = None, **kwargs) -> dict:
         return {}
diff --git a/bayesflow/experimental/networks/summary_network.py b/bayesflow/experimental/networks/summary_network.py
index e51bfe83..90a86bd5 100644
--- a/bayesflow/experimental/networks/summary_network.py
+++ b/bayesflow/experimental/networks/summary_network.py
@@ -5,8 +5,8 @@
 
 class SummaryNetwork(keras.Layer):
 
-    def compute_loss(self, observed_variables: Tensor, summary_conditions: Tensor = None, **kwargs) -> Tensor:
+    def compute_loss(self, summary_outputs: Tensor, **kwargs) -> Tensor:
         return keras.ops.zeros(())
 
-    def compute_metrics(self, observed_variables: Tensor, summary_conditions: Tensor = None, **kwargs) -> dict:
+    def compute_metrics(self, summary_variables: Tensor, summary_conditions: Tensor = None, **kwargs) -> dict:
         return {}

From 3da7be800f93032a9f96b229d9e595941f4311f4 Mon Sep 17 00:00:00 2001
From: lars
Date: Thu, 6 Jun 2024 13:43:25 +0200
Subject: [PATCH 09/24] backend-specific approximators template
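
The concrete Approximator class is resolved once, at import time, from the
active Keras backend. A sketch of how this is meant to be driven (the
environment variable must be set before keras is first imported):

    import os
    os.environ["KERAS_BACKEND"] = "torch"

    from bayesflow.experimental.backend_approximators import Approximator
    # Approximator now refers to TorchApproximator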
---
 .../backend_approximators/__init__.py         |  2 ++
 .../backend_approximators/approximator.py     | 14 ++++++++++++++
 .../base_approximator.py                      | 16 ++++++++++++++++
 .../backend_approximators/jax_approximator.py |  8 ++++++++
 .../numpy_approximator.py                     |  8 ++++++++
 .../tensorflow_approximator.py                |  8 ++++++++
 .../torch_approximator.py                     | 19 +++++++++++++++++++
 7 files changed, 75 insertions(+)
 create mode 100644 bayesflow/experimental/backend_approximators/__init__.py
 create mode 100644 bayesflow/experimental/backend_approximators/approximator.py
 create mode 100644 bayesflow/experimental/backend_approximators/base_approximator.py
 create mode 100644 bayesflow/experimental/backend_approximators/jax_approximator.py
 create mode 100644 bayesflow/experimental/backend_approximators/numpy_approximator.py
 create mode 100644 bayesflow/experimental/backend_approximators/tensorflow_approximator.py
 create mode 100644 bayesflow/experimental/backend_approximators/torch_approximator.py

diff --git a/bayesflow/experimental/backend_approximators/__init__.py b/bayesflow/experimental/backend_approximators/__init__.py
new file mode 100644
index 00000000..3e50858b
--- /dev/null
+++ b/bayesflow/experimental/backend_approximators/__init__.py
@@ -0,0 +1,2 @@
+
+from .approximator import Approximator
diff --git a/bayesflow/experimental/backend_approximators/approximator.py b/bayesflow/experimental/backend_approximators/approximator.py
new file mode 100644
index 00000000..76112757
--- /dev/null
+++ b/bayesflow/experimental/backend_approximators/approximator.py
@@ -0,0 +1,14 @@
+
+import keras
+
+match keras.backend.backend():
+    case "jax":
+        from .jax_approximator import JAXApproximator as Approximator
+    case "numpy":
+        from .numpy_approximator import NumpyApproximator as Approximator
+    case "tensorflow":
+        from .tensorflow_approximator import TensorFlowApproximator as Approximator
+    case "torch":
+        from .torch_approximator import TorchApproximator as Approximator
+    case other:
+        raise NotImplementedError(f"BayesFlow does not currently support backend '{other}'.")
diff --git a/bayesflow/experimental/backend_approximators/base_approximator.py b/bayesflow/experimental/backend_approximators/base_approximator.py
new file mode 100644
index 00000000..e17aa9ba
--- /dev/null
+++ b/bayesflow/experimental/backend_approximators/base_approximator.py
@@ -0,0 +1,16 @@
+
+import keras
+
+from bayesflow.experimental.types import Tensor
+
+
+class BaseApproximator(keras.Model):
+    def train_step(self, data):
+        raise NotImplementedError
+
+    # noinspection PyMethodOverriding
+    def compute_metrics(self, data: dict[str, Tensor], mode: str = "training") -> dict[str, Tensor]:
+        raise NotImplementedError
+
+    def compute_loss(self, *args, **kwargs):
+        raise NotImplementedError("Use compute_metrics()['loss'] instead.")
diff --git a/bayesflow/experimental/backend_approximators/jax_approximator.py b/bayesflow/experimental/backend_approximators/jax_approximator.py
new file mode 100644
index 00000000..376f597b
--- /dev/null
+++ b/bayesflow/experimental/backend_approximators/jax_approximator.py
@@ -0,0 +1,8 @@
+
+import jax
+
+from .base_approximator import BaseApproximator
+
+
+class JAXApproximator(BaseApproximator):
+    pass
diff --git a/bayesflow/experimental/backend_approximators/numpy_approximator.py b/bayesflow/experimental/backend_approximators/numpy_approximator.py
new file mode 100644
index 00000000..ebedd52c
--- /dev/null
+++ b/bayesflow/experimental/backend_approximators/numpy_approximator.py
@@ -0,0 +1,8 @@
+
+import numpy as np
+
+from .base_approximator import BaseApproximator
+
+
+class NumpyApproximator(BaseApproximator):
+    pass
diff --git a/bayesflow/experimental/backend_approximators/tensorflow_approximator.py b/bayesflow/experimental/backend_approximators/tensorflow_approximator.py
new file mode 100644
index 00000000..f6d10d41
--- /dev/null
+++ b/bayesflow/experimental/backend_approximators/tensorflow_approximator.py
@@ -0,0 +1,8 @@
+
+import tensorflow as tf
+
+from .base_approximator import BaseApproximator
+
+
+class TensorFlowApproximator(BaseApproximator):
+    pass
diff --git a/bayesflow/experimental/backend_approximators/torch_approximator.py b/bayesflow/experimental/backend_approximators/torch_approximator.py
new file mode 100644
index 00000000..78d1729d
--- /dev/null
+++ b/bayesflow/experimental/backend_approximators/torch_approximator.py
@@ -0,0 +1,19 @@
+
+import torch
+
+
+from .base_approximator import BaseApproximator
+
+
+class TorchApproximator(BaseApproximator):
+    def train_step(self, data):
+        with torch.enable_grad():
+            metrics = self.compute_metrics(data, mode="training")
+
+        loss = metrics.pop("loss")
+
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+
+        return metrics

From 96a01f4d3de153ef1011e0f7c85418bed66abf4f Mon Sep 17 00:00:00 2001
From: lars
Date: Thu, 6 Jun 2024 14:21:09 +0200
Subject: [PATCH 10/24] fix workflow auto-active-base -> auto-activate-base

how did this slip under the radar?
---
 .github/workflows/tests.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 623b8e20..dd1b62b9 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -28,7 +28,7 @@ jobs:
         activate-environment: bayesflow
         environment-file: environment.yaml
         python-version: ${{ matrix.python-version }}
-        auto-active-base: false
+        auto-activate-base: false
 
     - name: Install JAX
       if: ${{ matrix.backend == 'jax' }}

From 1e62f1433a566c6c3fde107aa7f3d5b2a894a7c2 Mon Sep 17 00:00:00 2001
From: lars
Date: Thu, 6 Jun 2024 15:21:54 +0200
Subject: [PATCH 11/24] add test-backend-case-distinction to workflow
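
Each backend job now deselects the tests marked for the other backends. For
this to take effect, backend-specific tests must opt in via markers
(assuming the markers are registered in the pytest configuration):

    @pytest.mark.torch
    def test_torch_specific_behavior():
        ...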
lars Date: Thu, 6 Jun 2024 15:35:29 +0200 Subject: [PATCH 13/24] use login mode shell for workflows --- .github/workflows/tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4ad56c05..26b2abbb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,6 +17,9 @@ jobs: os: [ubuntu-latest, windows-latest] python-version: ["3.10", "3.11"] backend: ["jax", "numpy", "tensorflow", "torch"] + defaults: + run: + shell: bash -el {0} steps: - name: Checkout code From 2653090bfbedbf1b5799c13cfe0bd3ee61986d73 Mon Sep 17 00:00:00 2001 From: lars Date: Thu, 6 Jun 2024 15:38:38 +0200 Subject: [PATCH 14/24] use conda for environment backend variables --- .github/workflows/tests.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 26b2abbb..a003d444 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,22 +35,22 @@ jobs: if: ${{ matrix.backend == 'jax' }} run: | pip install -U "jax[cpu]" - export KERAS_BACKEND=jax + conda env config vars set KERAS_BACKEND=jax - name: Install NumPy if: ${{ matrix.backend == 'numpy' }} run: | conda install numpy - export KERAS_BACKEND=numpy + conda env config vars set KERAS_BACKEND=numpy - name: Install Tensorflow if: ${{ matrix.backend == 'tensorflow' }} run: | pip install -U tensorflow - export KERAS_BACKEND=tensorflow + conda env config vars set KERAS_BACKEND=tensorflow - name: Install PyTorch if: ${{ matrix.backend == 'torch' }} run: | conda install pytorch torchvision torchaudio cpuonly -c pytorch - export KERAS_BACKEND=torch + conda env config vars set KERAS_BACKEND=torch - name: Show Environment Info run: | From 8318dee69f05c4e98b3d6bba52c7350ac03ea748 Mon Sep 17 00:00:00 2001 From: lars Date: Thu, 6 Jun 2024 15:44:56 +0200 Subject: [PATCH 15/24] set backend environment variable directly --- .github/workflows/tests.yml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a003d444..d44c1079 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -20,6 +20,7 @@ jobs: defaults: run: shell: bash -el {0} + env: ${{ matrix.backend }} steps: - name: Checkout code @@ -35,22 +36,18 @@ jobs: if: ${{ matrix.backend == 'jax' }} run: | pip install -U "jax[cpu]" - conda env config vars set KERAS_BACKEND=jax - name: Install NumPy if: ${{ matrix.backend == 'numpy' }} run: | conda install numpy - conda env config vars set KERAS_BACKEND=numpy - name: Install Tensorflow if: ${{ matrix.backend == 'tensorflow' }} run: | pip install -U tensorflow - conda env config vars set KERAS_BACKEND=tensorflow - name: Install PyTorch if: ${{ matrix.backend == 'torch' }} run: | conda install pytorch torchvision torchaudio cpuonly -c pytorch - conda env config vars set KERAS_BACKEND=torch - name: Show Environment Info run: | From b59a0f54d1df95826fca31e5815d818f00798c5d Mon Sep 17 00:00:00 2001 From: lars Date: Thu, 6 Jun 2024 15:44:56 +0200 Subject: [PATCH 16/24] set backend environment variable directly --- .github/workflows/tests.yml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a003d444..e8ea0d30 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -20,6 +20,8 @@ jobs: defaults: run: shell: bash -el {0} + env: + KERAS_BACKEND: ${{ matrix.backend }} steps: - name: Checkout 
code @@ -35,22 +37,18 @@ jobs: if: ${{ matrix.backend == 'jax' }} run: | pip install -U "jax[cpu]" - conda env config vars set KERAS_BACKEND=jax - name: Install NumPy if: ${{ matrix.backend == 'numpy' }} run: | conda install numpy - conda env config vars set KERAS_BACKEND=numpy - name: Install Tensorflow if: ${{ matrix.backend == 'tensorflow' }} run: | pip install -U tensorflow - conda env config vars set KERAS_BACKEND=tensorflow - name: Install PyTorch if: ${{ matrix.backend == 'torch' }} run: | conda install pytorch torchvision torchaudio cpuonly -c pytorch - conda env config vars set KERAS_BACKEND=torch - name: Show Environment Info run: | From faf3d096525317ec1f366704fda80441bd3ab382 Mon Sep 17 00:00:00 2001 From: lars Date: Thu, 6 Jun 2024 15:50:11 +0200 Subject: [PATCH 17/24] fix top-level imports --- bayesflow/experimental/__init__.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/bayesflow/experimental/__init__.py b/bayesflow/experimental/__init__.py index efa50cab..f5b86c10 100644 --- a/bayesflow/experimental/__init__.py +++ b/bayesflow/experimental/__init__.py @@ -1,20 +1,16 @@ from . import ( - amortizers, - datasets, - diagnostics, - networks, - simulation, + approximators, + datasets, + diagnostics, + distributions, + networks, + simulation, ) -from .amortizers import ( - AmortizedLikelihood, - AmortizedPosterior, - Amortizer, -) +from .approximators import Approximator -from .simulation import ( - distribution, - JointDistribution, +from .datasets import ( + OnlineDataset, + OfflineDataset, ) - From 3e38c7c2a300fe0d4e598c64721710ee0ca26dbf Mon Sep 17 00:00:00 2001 From: lars Date: Thu, 6 Jun 2024 18:45:09 +0200 Subject: [PATCH 18/24] update assert_layers_equal --- tests/test_networks/test_inference_networks.py | 4 ++-- tests/utils/assertions.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py index 4340e758..ede8c1ef 100644 --- a/tests/test_networks/test_inference_networks.py +++ b/tests/test_networks/test_inference_networks.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from tests.utils import allclose, assert_models_equal +from tests.utils import allclose, assert_layers_equal def test_build(inference_network, random_samples, random_conditions): @@ -111,4 +111,4 @@ def test_serialize_deserialize(tmp_path, inference_network, random_samples, rand keras.saving.save_model(inference_network, tmp_path / "model.keras") loaded = keras.saving.load_model(tmp_path / "model.keras") - assert_models_equal(inference_network, loaded) + assert_layers_equal(inference_network, loaded) diff --git a/tests/utils/assertions.py b/tests/utils/assertions.py index 9d2a669f..b121bac8 100644 --- a/tests/utils/assertions.py +++ b/tests/utils/assertions.py @@ -14,14 +14,16 @@ def assert_models_equal(model1: keras.Model, model2: keras.Model): def assert_layers_equal(layer1: keras.Layer, layer2: keras.Layer): + assert layer1.name == layer2.name assert len(layer1.variables) == len(layer2.variables), f"Layers {layer1.name} and {layer2.name} have a different number of variables ({len(layer1.variables)}, {len(layer2.variables)})." assert len(layer1.variables) > 0, f"Layers {layer1.name} and {layer2.name} have no variables." 
     for v1, v2 in zip(layer1.variables, layer2.variables):
+        assert v1.name == v2.name
-        if v1.name == "seed_generator_state" and v1.name == v2.name:
+        if v1.name == "seed_generator_state":
             # keras issue: https://github.com/keras-team/keras/issues/19796
             continue

-        v1 = keras.ops.convert_to_numpy(v1)
-        v2 = keras.ops.convert_to_numpy(v2)
-        assert keras.ops.all(keras.ops.isclose(v1, v2)), f"Variables for {layer1.name} and {layer2.name} are not equal: {v1} != {v2}"
+        x1 = keras.ops.convert_to_numpy(v1)
+        x2 = keras.ops.convert_to_numpy(v2)
+        assert keras.ops.all(keras.ops.isclose(x1, x2)), f"Variable '{v1.name}' for Layer '{layer1.name}' is not equal: {x1} != {x2}"

From 76e117b4bfa02667956b1ff7a559ab15bc3ccab4 Mon Sep 17 00:00:00 2001
From: lars
Date: Thu, 6 Jun 2024 18:45:32 +0200
Subject: [PATCH 19/24] move test_amortizers -> test_approximators

---
 .../__init__.py |  0
 .../conftest.py | 14 ++++++++++----
 .../test_fit.py |  0
 3 files changed, 10 insertions(+), 4 deletions(-)
 rename tests/{test_amortizers => test_approximators}/__init__.py (100%)
 rename tests/{test_amortizers => test_approximators}/conftest.py (77%)
 rename tests/{test_amortizers => test_approximators}/test_fit.py (100%)

diff --git a/tests/test_amortizers/__init__.py b/tests/test_approximators/__init__.py
similarity index 100%
rename from tests/test_amortizers/__init__.py
rename to tests/test_approximators/__init__.py
diff --git a/tests/test_amortizers/conftest.py b/tests/test_approximators/conftest.py
similarity index 77%
rename from tests/test_amortizers/conftest.py
rename to tests/test_approximators/conftest.py
index 09b0e0d4..d5dd6f70 100644
--- a/tests/test_amortizers/conftest.py
+++ b/tests/test_approximators/conftest.py
@@ -19,10 +19,16 @@ def inference_network():
     return network


-@pytest.fixture(params=[bf.AmortizedPosterior, bf.AmortizedLikelihood])
-def amortizer(request, inference_network, summary_network):
-    Amortizer = request.param
-    return Amortizer(inference_network, summary_network)
+@pytest.fixture()
+def approximator(inference_network, summary_network):
+    return bf.Approximator(
+        inference_network=inference_network,
+        summary_network=summary_network,
+        inference_variables=[],
+        inference_conditions=[],
+        summary_variables=[],
+        summary_conditions=[],
+    )


 @pytest.fixture()
diff --git a/tests/test_amortizers/test_fit.py b/tests/test_approximators/test_fit.py
similarity index 100%
rename from tests/test_amortizers/test_fit.py
rename to tests/test_approximators/test_fit.py

From 8e364050d57bdf8f5de21eca2ea35b8f5cef315c Mon Sep 17 00:00:00 2001
From: lars
Date: Thu, 6 Jun 2024 18:46:41 +0200
Subject: [PATCH 20/24] add find_permutation to dispatch

---
 bayesflow/experimental/utils/__init__.py     |  1 +
 .../experimental/utils/dispatch/__init__.py  |  3 ++
 .../utils/dispatch/find_permutation.py       | 32 +++++++++++++++++++
 3 files changed, 36 insertions(+)
 create mode 100644 bayesflow/experimental/utils/dispatch/find_permutation.py

diff --git a/bayesflow/experimental/utils/__init__.py b/bayesflow/experimental/utils/__init__.py
index 476da980..6f46f02b 100644
--- a/bayesflow/experimental/utils/__init__.py
+++ b/bayesflow/experimental/utils/__init__.py
@@ -3,5 +3,6 @@
 from .dispatch import (
     find_distribution,
     find_network,
+    find_permutation,
     find_pooling,
 )
diff --git a/bayesflow/experimental/utils/dispatch/__init__.py b/bayesflow/experimental/utils/dispatch/__init__.py
index 808622a5..864f6547 100644
--- a/bayesflow/experimental/utils/dispatch/__init__.py
+++ b/bayesflow/experimental/utils/dispatch/__init__.py
@@ -1,2 +1,5 @@
+from .find_distribution import find_distribution
 from .find_network import find_network
+from .find_permutation import find_permutation
+from .find_pooling import find_pooling
diff --git a/bayesflow/experimental/utils/dispatch/find_permutation.py b/bayesflow/experimental/utils/dispatch/find_permutation.py
new file mode 100644
index 00000000..c01864b9
--- /dev/null
+++ b/bayesflow/experimental/utils/dispatch/find_permutation.py
@@ -0,0 +1,32 @@
+
+import keras
+from functools import singledispatch
+
+
+@singledispatch
+def find_permutation(arg, **kwargs):
+    raise TypeError(f"Cannot infer permutation from {arg!r}.")
+
+
+@find_permutation.register
+def _(name: str, **kwargs):
+    match name.lower():
+        case "random":
+            from bayesflow.experimental.networks.coupling_flow.permutations import RandomPermutation
+            return RandomPermutation(**kwargs)
+        case "swap":
+            from bayesflow.experimental.networks.coupling_flow.permutations import Swap
+            return Swap(**kwargs)
+        case "learnable" | "orthogonal":
+            from bayesflow.experimental.networks.coupling_flow.permutations import OrthogonalPermutation
+            return OrthogonalPermutation(**kwargs)
+
+
+@find_permutation.register
+def _(permutation: keras.Layer, **kwargs):
+    return permutation
+
+
+@find_permutation.register
+def _(none: None, **kwargs):
+    return None
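For orientation between patches: the new `find_permutation` helper dispatches on the type of its first argument, so callers can pass a string name, an already-constructed layer, or `None`. A minimal usage sketch (hypothetical illustration, not part of the patch; import paths are the ones introduced above):

    from bayesflow.experimental.utils import find_permutation
    from bayesflow.experimental.networks.coupling_flow.permutations import RandomPermutation

    perm = find_permutation("random", name="Permutation0")  # str -> matching layer instance
    assert isinstance(perm, RandomPermutation)

    assert find_permutation(perm) is perm   # keras.Layer -> passed through unchanged
    assert find_permutation(None) is None   # None -> permutations disabled

One caveat worth noting: a string that matches none of the `case` arms falls out of the `match` block and implicitly returns `None` instead of raising, so a misspelled name silently disables permutations; a final `case _:` that raises a `ValueError` might be a worthwhile follow-up.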
From 110accb29c8563802685a3480374295cc061beb6 Mon Sep 17 00:00:00 2001
From: lars
Date: Thu, 6 Jun 2024 18:48:20 +0200
Subject: [PATCH 21/24] fix a bug where CouplingFlow would re-randomize weights when loaded

---
 .../networks/coupling_flow/coupling_flow.py   | 68 +++++++------------
 .../coupling_flow/couplings/dual_coupling.py  | 14 +++-
 .../couplings/single_coupling.py              | 26 ++++---
 3 files changed, 56 insertions(+), 52 deletions(-)

diff --git a/bayesflow/experimental/networks/coupling_flow/coupling_flow.py b/bayesflow/experimental/networks/coupling_flow/coupling_flow.py
index 8f64d1f6..86713af3 100644
--- a/bayesflow/experimental/networks/coupling_flow/coupling_flow.py
+++ b/bayesflow/experimental/networks/coupling_flow/coupling_flow.py
@@ -5,10 +5,9 @@
 from keras.saving import register_keras_serializable

 from bayesflow.experimental.types import Tensor
-from bayesflow.experimental.utils import keras_kwargs
+from bayesflow.experimental.utils import find_permutation
 from .actnorm import ActNorm
 from .couplings import DualCoupling
-from .permutations import OrthogonalPermutation, RandomPermutation, Swap
 from ..inference_network import InferenceNetwork

@@ -41,60 +40,48 @@ def __init__(
         depth: int = 6,
         subnet: str = "resnet",
         transform: str = "affine",
-        permutation: str = "random",
+        permutation: str | None = None,
         use_actnorm: bool = True,
         **kwargs
     ):
-        """TODO"""
+        # TODO - propagate optional keyword arguments to find_network and ResNet respectively
+        super().__init__(**kwargs)

-        super().__init__(**keras_kwargs(kwargs))
+        self.depth = depth

-        self._layers = []
+        self.invertible_layers = []
         for i in range(depth):
             if use_actnorm:
-                self._layers.append(ActNorm())
-            self._layers.append(DualCoupling(subnet, transform, **kwargs))
-            if permutation.lower() == "random":
-                self._layers.append(RandomPermutation())
-            elif permutation.lower() == "swap":
-                self._layers.append(Swap())
-            elif permutation.lower() == "learnable":
-                self._layers.append(OrthogonalPermutation())
+                self.invertible_layers.append(ActNorm(name=f"ActNorm{i}"))
+
+            self.invertible_layers.append(DualCoupling(subnet, transform, name=f"DualCoupling{i}"))
+
+            if (p := find_permutation(permutation, name=f"Permutation{i}")) is not None:
+                self.invertible_layers.append(p)

     # noinspection PyMethodOverriding
     def build(self, xz_shape, conditions_shape=None):
         super().build(xz_shape)
+
+        xz = keras.KerasTensor(xz_shape)
         if conditions_shape is None:
-            self.call(keras.KerasTensor(xz_shape))
+            conditions = None
         else:
-            self.call(keras.KerasTensor(xz_shape), conditions=keras.KerasTensor(conditions_shape))
+            conditions = keras.KerasTensor(conditions_shape)

-    def call(
-        self,
-        xz: Tensor,
-        conditions: Tensor = None,
-        inverse: bool = False, **kwargs
-    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+        # build nested layers with forward pass
+        self.call(xz, conditions=conditions)

+    def call(self, xz: Tensor, conditions: Tensor = None, inverse: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
         if inverse:
             return self._inverse(xz, conditions=conditions, **kwargs)
         return self._forward(xz, conditions=conditions, **kwargs)

-    def _forward(
-        self,
-        x: Tensor,
-        conditions: Tensor = None,
-        jacobian: bool = False,
-        **kwargs
-    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
-
+    def _forward(self, x: Tensor, conditions: Tensor = None, jacobian: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
         z = x
         log_det = keras.ops.zeros(keras.ops.shape(x)[:-1])
-        for layer in self._layers:
-            if isinstance(layer, DualCoupling):
-                z, det = layer(z, conditions=conditions, inverse=False, **kwargs)
-            else:
-                z, det = layer(z, inverse=False, **kwargs)
+        for layer in self.invertible_layers:
+            z, det = layer(z, conditions=conditions, inverse=False, **kwargs)
             log_det += det

         if jacobian:
@@ -104,19 +91,16 @@ def _forward(
     def _inverse(self, z: Tensor, conditions: Tensor = None, jacobian: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
         x = z
         log_det = keras.ops.zeros(keras.ops.shape(z)[:-1])
-        for layer in reversed(self._layers):
-            if isinstance(layer, DualCoupling):
-                x, det = layer(x, conditions=conditions, inverse=True, **kwargs)
-            else:
-                x, det = layer(x, inverse=True, **kwargs)
+        for layer in reversed(self.invertible_layers):
+            x, det = layer(x, conditions=conditions, inverse=True, **kwargs)
             log_det += det

         if jacobian:
             return x, log_det
         return x

-    def compute_loss(self, x: Tensor = None, conditions: Tensor = None, **kwargs):
-        z, log_det = self(x, conditions=conditions, inverse=False, jacobian=True, **kwargs)
+    def compute_loss(self, inference_variables: Tensor, inference_conditions: Tensor = None, **kwargs) -> Tensor:
+        z, log_det = self(inference_variables, conditions=inference_conditions, inverse=False, jacobian=True, **kwargs)
         log_prob = self.base_distribution.log_prob(z)
         nll = -keras.ops.mean(log_prob + log_det)
diff --git a/bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py b/bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py
index d1ea3354..6491d302 100644
--- a/bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py
+++ b/bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py
@@ -16,8 +16,18 @@ def __init__(self, subnet: str = "resnet", transform: str = "affine", **kwargs):
         self.coupling2 = SingleCoupling(subnet, transform, **kwargs)
         self.pivot = None

-    def build(self, input_shape):
-        self.pivot = input_shape[-1] // 2
+    # noinspection PyMethodOverriding
+    def build(self, xz_shape, conditions_shape=None):
+        self.pivot = xz_shape[-1] // 2
+
+        xz = keras.KerasTensor(xz_shape)
+        if conditions_shape is None:
+            conditions = None
+        else:
+            conditions = keras.KerasTensor(conditions_shape)
+
+        # build nested layers with forward pass
+        self.call(xz, conditions=conditions)

     def call(self, xz: Tensor, conditions: Tensor = None, inverse: bool = False, **kwargs) -> (Tensor, Tensor):
         if inverse:
diff --git a/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py b/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py
index 7d6a08c9..82027d4f 100644
--- a/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py
+++ b/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py
@@ -17,23 +17,33 @@ class SingleCoupling(InvertibleLayer):
     """
     def __init__(
         self,
-        network: str = "resnet",
+        subnet: str = "resnet",
         transform: str = "affine",
-        output_layer_kernel_init: str = "zeros",
         **kwargs
     ):
         super().__init__(**keras_kwargs(kwargs))
-        self.output_projector = keras.layers.Dense(
-            units=None,
-            kernel_initializer=output_layer_kernel_init,
-        )
-        self.network = find_network(network, **kwargs.get("subnet_kwargs", {}))
+
+        self.network = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
         self.transform = find_transform(transform, **kwargs.get("transform_kwargs", {}))

+        output_projector_kwargs = kwargs.get("output_projector_kwargs", {})
+        output_projector_kwargs.setdefault("kernel_initializer", "zeros")
+        self.output_projector = keras.layers.Dense(units=None, **output_projector_kwargs)
+
     # noinspection PyMethodOverriding
-    def build(self, x1_shape, x2_shape):
+    def build(self, x1_shape, x2_shape, conditions_shape=None):
         self.output_projector.units = self.transform.params_per_dim * x2_shape[-1]

+        x1 = keras.KerasTensor(x1_shape)
+        x2 = keras.KerasTensor(x2_shape)
+        if conditions_shape is None:
+            conditions = None
+        else:
+            conditions = keras.KerasTensor(conditions_shape)
+
+        # build nested layers with forward pass
+        self.call(x1, x2, conditions=conditions)
+
     def call(self, x1: Tensor, x2: Tensor, conditions: Tensor = None, inverse: bool = False, **kwargs) -> ((Tensor, Tensor), Tensor):
         if inverse:
             return self._inverse(x1, x2, conditions=conditions, **kwargs)
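A note on the pattern recurring in all three files above: `build()` now runs a symbolic forward pass over `keras.KerasTensor` placeholders, so every nested layer creates its variables immediately, in a deterministic order and under deterministic names (`ActNorm{i}`, `DualCoupling{i}`, `Permutation{i}`). When a saved model is loaded, the variable structure already matches the checkpoint and nothing is lazily rebuilt with fresh random weights. A standalone sketch of the idiom (hypothetical toy example, not part of the patch):

    import keras

    class Wrapper(keras.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.inner = keras.layers.Dense(4)

        def build(self, input_shape):
            # build nested layers with a forward pass; KerasTensor is only a
            # shape/dtype placeholder, so no actual computation is executed
            self.call(keras.KerasTensor(input_shape))

        def call(self, x):
            return self.inner(x)

    wrapper = Wrapper()
    wrapper.build((None, 8))
    assert len(wrapper.inner.weights) == 2  # kernel and bias already exist

After `build()`, deserialization can restore weights directly instead of re-initializing them on the first real call.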
From 9b1725e4dd3b513bdabbbafc3ac3f8f3b76bf976 Mon Sep 17 00:00:00 2001
From: lars
Date: Thu, 6 Jun 2024 18:48:38 +0200
Subject: [PATCH 22/24] clean up

---
 bayesflow/experimental/networks/inference_network.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/bayesflow/experimental/networks/inference_network.py b/bayesflow/experimental/networks/inference_network.py
index 1fa10be5..9b398f36 100644
--- a/bayesflow/experimental/networks/inference_network.py
+++ b/bayesflow/experimental/networks/inference_network.py
@@ -29,15 +29,16 @@ def _inverse(self, z: Tensor, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
     def sample(self, num_samples: int, conditions: Tensor = None, **kwargs) -> Tensor:
         samples = self.base_distribution.sample((num_samples,))
-        return self(samples, conditions=conditions, inverse=True, jacobian=False, **kwargs)
+        samples = self(samples, conditions=conditions, inverse=True, jacobian=False, **kwargs)
+        return samples

-    def log_prob(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
-        samples, log_det = self(x, conditions=conditions, inverse=False, jacobian=True, **kwargs)
+    def log_prob(self, samples: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
+        samples, log_det = self(samples, conditions=conditions, inverse=False, jacobian=True, **kwargs)
         log_prob = self.base_distribution.log_prob(samples)

         return log_prob + log_det

-    def compute_loss(self, targets: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
+    def compute_loss(self, inference_variables: Tensor, inference_conditions: Tensor = None, **kwargs) -> Tensor:
         raise NotImplementedError

-    def compute_metrics(self, targets: Tensor, conditions: Tensor = None, **kwargs) -> dict:
+    def compute_metrics(self, inference_variables: Tensor, inference_conditions: Tensor = None, **kwargs) -> dict:
         return {}
From 0a00784fc7dda9a792f4cf344245013af0867fad Mon Sep 17 00:00:00 2001
From: lars
Date: Thu, 6 Jun 2024 18:55:48 +0200
Subject: [PATCH 23/24] Simplify ResNet for now, move ConfigurableMLP into MLP

---
 bayesflow/experimental/networks/__init__.py   |  1 +
 .../experimental/networks/mlp/__init__.py     |  2 +
 .../networks/{resnet => mlp}/hidden_block.py  |  6 +-
 bayesflow/experimental/networks/mlp/mlp.py    | 80 ++++++++++++++++++
 .../experimental/networks/resnet/resnet.py    | 81 ++++---------------
 5 files changed, 103 insertions(+), 67 deletions(-)
 create mode 100644 bayesflow/experimental/networks/mlp/__init__.py
 rename bayesflow/experimental/networks/{resnet => mlp}/hidden_block.py (90%)
 create mode 100644 bayesflow/experimental/networks/mlp/mlp.py

diff --git a/bayesflow/experimental/networks/__init__.py b/bayesflow/experimental/networks/__init__.py
index 7d0448aa..85cf7695 100644
--- a/bayesflow/experimental/networks/__init__.py
+++ b/bayesflow/experimental/networks/__init__.py
@@ -3,6 +3,7 @@
 from .deep_set import DeepSet
 from .flow_matching import FlowMatching
 from .inference_network import InferenceNetwork
+from .mlp import MLP
 from .resnet import ResNet
 from .set_transformer import SetTransformer
 from .summary_network import SummaryNetwork
diff --git a/bayesflow/experimental/networks/mlp/__init__.py b/bayesflow/experimental/networks/mlp/__init__.py
new file mode 100644
index 00000000..08c36c7d
--- /dev/null
+++ b/bayesflow/experimental/networks/mlp/__init__.py
@@ -0,0 +1,2 @@
+
+from .mlp import MLP
diff --git a/bayesflow/experimental/networks/resnet/hidden_block.py b/bayesflow/experimental/networks/mlp/hidden_block.py
similarity index 90%
rename from bayesflow/experimental/networks/resnet/hidden_block.py
rename to bayesflow/experimental/networks/mlp/hidden_block.py
index e02f78f6..27c9b549 100644
--- a/bayesflow/experimental/networks/resnet/hidden_block.py
+++ b/bayesflow/experimental/networks/mlp/hidden_block.py
@@ -6,7 +6,7 @@
 from bayesflow.experimental.types import Tensor


-@register_keras_serializable(package="bayesflow.networks.resnet")
+@register_keras_serializable(package="bayesflow.networks")
 class ConfigurableHiddenBlock(keras.layers.Layer):
     def __init__(
         self,
@@ -38,8 +38,8 @@ def call(self, inputs: Tensor, training=False):
         return self.activation_fn(x)

     def build(self, input_shape):
-        super().build(input_shape)
-        self(keras.KerasTensor(input_shape))
+        # build nested layers with forward pass
+        self.call(keras.KerasTensor(input_shape))

     def get_config(self):
         config = super().get_config()
diff --git a/bayesflow/experimental/networks/mlp/mlp.py b/bayesflow/experimental/networks/mlp/mlp.py
new file mode 100644
index 00000000..6c34535d
--- /dev/null
+++ b/bayesflow/experimental/networks/mlp/mlp.py
@@ -0,0 +1,80 @@
+
+import keras
+from keras import layers
+from keras.saving import register_keras_serializable
+
+from bayesflow.experimental.types import Tensor
+from .hidden_block import ConfigurableHiddenBlock
+
+
+@register_keras_serializable(package="bayesflow.networks")
+class MLP(keras.layers.Layer):
+    """
+    Implements a simple configurable MLP with optional residual connections and dropout.
+
+    If used in conjunction with a coupling net, a diffusion model, or a flow matching model, it assumes
+    that the input and conditions are already concatenated (i.e., this is a single-input model).
+    """
+
+    def __init__(
+        self,
+        num_hidden: int = 2,
+        hidden_dim: int = 256,
+        activation: str = "mish",
+        kernel_initializer: str = "he_normal",
+        residual: bool = True,
+        dropout: float = 0.05,
+        spectral_normalization: bool = False,
+        **kwargs
+    ):
+        """
+        Creates an instance of a flexible and simple MLP with optional residual connections and dropout.
+
+        Parameters:
+        -----------
+        hidden_dim : int, optional, default: 256
+            The dimensionality of the hidden layers
+        num_hidden : int, optional, default: 2
+            The number of hidden layers (minimum: 1)
+        activation : string, optional, default: 'mish'
+            The activation function of the dense layers
+        residual : bool, optional, default: True
+            Use residual connections in the internal layers.
+        spectral_normalization : bool, optional, default: False
+            Use spectral normalization for the network weights, which can make
+            the learned function smoother and hence more robust to perturbations.
+        dropout : float, optional, default: 0.05
+            Dropout rate for the hidden layers in the internal layers.
+        """
+
+        super().__init__(**kwargs)
+
+        self.res_blocks = keras.Sequential()
+        projector = layers.Dense(
+            units=hidden_dim,
+            kernel_initializer=kernel_initializer,
+        )
+        if spectral_normalization:
+            projector = layers.SpectralNormalization(projector)
+        self.res_blocks.add(projector)
+        self.res_blocks.add(layers.Dropout(dropout))
+
+        for _ in range(num_hidden):
+            self.res_blocks.add(
+                ConfigurableHiddenBlock(
+                    units=hidden_dim,
+                    activation=activation,
+                    kernel_initializer=kernel_initializer,
+                    residual=residual,
+                    dropout=dropout,
+                    spectral_normalization=spectral_normalization
+                )
+            )
+
+    def build(self, input_shape):
+        # build nested layers with forward pass
+        self.call(keras.KerasTensor(input_shape))
+
+    def call(self, inputs: Tensor, **kwargs):
+        return self.res_blocks(inputs, training=kwargs.get("training", False))
+
diff --git a/bayesflow/experimental/networks/resnet/resnet.py b/bayesflow/experimental/networks/resnet/resnet.py
index ae019505..e8ba7b3b 100644
--- a/bayesflow/experimental/networks/resnet/resnet.py
+++ b/bayesflow/experimental/networks/resnet/resnet.py
@@ -1,76 +1,29 @@

 import keras
-from keras import layers
 from keras.saving import register_keras_serializable

 from bayesflow.experimental.types import Tensor
-from .hidden_block import ConfigurableHiddenBlock


-@register_keras_serializable(package="bayesflow.networks.resnet")
-class ResNet(keras.layers.Layer):
-    """
-    Implements a simple configurable MLP with optional residual connections and dropout.
-
-    If used in conjunction with a coupling net, a diffusion model, or a flow matching model, it assumes
-    that the input and conditions are already concatenated (i.e., this is a single-input model).
-    """
-
-    def __init__(
-        self,
-        num_hidden: int = 2,
-        hidden_dim: int = 256,
-        activation: str = "mish",
-        kernel_initializer: str = "he_normal",
-        residual: bool = True,
-        dropout: float = 0.05,
-        spectral_normalization: bool = False,
-        **kwargs
-    ):
-        """
-        Creates an instance of a flexible and simple MLP with optional residual connections and dropout.
-
-        Parameters:
-        -----------
-        hidden_dim : int, optional, default: 256
-            The dimensionality of the hidden layers
-        num_hidden : int, optional, default: 2
-            The number of hidden layers (minimum: 1)
-        activation : string, optional, default: 'gelu'
-            The activation function of the dense layers
-        residual : bool, optional, default: True
-            Use residual connections in the internal layers.
-        spectral_normalization : bool, optional, default: False
-            Use spectral normalization for the network weights, which can make
-            the learned function smoother and hence more robust to perturbations.
-        dropout : float, optional, default: 0.05
-            Dropout rate for the hidden layers in the internal layers.
-        """
-
+@register_keras_serializable(package="bayesflow.networks")
+class ResNet(keras.Layer):
+    """ Implements a super-simple ResNet """
+    def __init__(self, depth: int = 6, width: int = 2, activation: str = "gelu", **kwargs):
         super().__init__(**kwargs)

-        self.res_blocks = keras.Sequential()
-        projector = layers.Dense(
-            units=hidden_dim,
-            kernel_initializer=kernel_initializer,
-        )
-        if spectral_normalization:
-            projector = layers.SpectralNormalization(projector)
-        self.res_blocks.add(projector)
-        self.res_blocks.add(layers.Dropout(dropout))
+        self.input_layer = keras.layers.Dense(width)
+        self.output_layer = keras.layers.Dense(width)
+        self.hidden_layers = [keras.layers.Dense(width, activation) for _ in range(depth)]
+
+    def build(self, input_shape):
+        # build nested layers with forward pass
+        self.call(keras.KerasTensor(input_shape))

-        for _ in range(num_hidden):
-            self.res_blocks.add(
-                ConfigurableHiddenBlock(
-                    units=hidden_dim,
-                    activation=activation,
-                    kernel_initializer=kernel_initializer,
-                    residual=residual,
-                    dropout=dropout,
-                    spectral_normalization=spectral_normalization
-                )
-            )
+    def call(self, x: Tensor, **kwargs) -> Tensor:
+        x = self.input_layer(x)
+        for layer in self.hidden_layers:
+            x = x + layer(x)

-    def call(self, inputs: Tensor, **kwargs):
-        return self.res_blocks(inputs, training=kwargs.get("training", False))
+        x = x + self.output_layer(x)
+        return x
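A short usage sketch of the new `MLP` (hypothetical illustration; the import path follows the `networks/__init__.py` change in this patch):

    import keras
    from bayesflow.experimental.networks import MLP

    mlp = MLP(num_hidden=3, hidden_dim=128, dropout=0.1)

    x = keras.random.normal((32, 16))  # batch of 32 inputs with 16 features
    y = mlp(x)                         # the initial projector maps to hidden_dim, so y has shape (32, 128)

Dropout only becomes active when `training=True` is passed through the call kwargs, matching the `kwargs.get("training", False)` default in `call()`.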
From 8bf60e1d9e4f0b75c02d9b6a0b50bc17b4c39572 Mon Sep 17 00:00:00 2001
From: lars
Date: Thu, 6 Jun 2024 19:04:15 +0200
Subject: [PATCH 24/24] slight update to tests

---
 tests/conftest.py                | 49 ++++++++++++++++++++++
 tests/test_two_moons/conftest.py | 66 +++++++++++---------------------
 tests/test_two_moons/test_fit.py |  2 +-
 3 files changed, 73 insertions(+), 44 deletions(-)
 create mode 100644 tests/conftest.py
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..e55f97a9
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,49 @@
+
+import keras
+import pytest
+
+
+@pytest.fixture()
+def amortizer(inference_network, summary_network):
+    from bayesflow.experimental.amortizers import Amortizer
+
+    return Amortizer(
+        inference_network=inference_network,
+        summary_network=summary_network,
+    )
+
+
+@pytest.fixture()
+def coupling_flow():
+    from bayesflow.experimental.networks import CouplingFlow
+    return CouplingFlow()
+
+
+@pytest.fixture()
+def flow_matching():
+    from bayesflow.experimental.networks import FlowMatching
+    return FlowMatching()
+
+
+@pytest.fixture(params=["coupling_flow"])
+def inference_network(request):
+    return request.getfixturevalue(request.param)
+
+
+@pytest.fixture(params=["inference_network", "summary_network"])
+def network(request):
+    return request.getfixturevalue(request.param)
+
+
+@pytest.fixture()
+def resnet():
+    from bayesflow.experimental.networks import ResNet
+    return ResNet()
+
+
+@pytest.fixture(params=[None])
+def summary_network(request):
+    if request.param is None:
+        return None
+
+    return request.getfixturevalue(request.param)
diff --git a/tests/test_two_moons/conftest.py b/tests/test_two_moons/conftest.py
index 40bf110d..c5badfc9 100644
--- a/tests/test_two_moons/conftest.py
+++ b/tests/test_two_moons/conftest.py
@@ -1,71 +1,51 @@

 import math

+import keras
 import pytest

-from keras import ops as K
-from keras import random as R

 import bayesflow.experimental as bf


 @pytest.fixture()
-def context():
-    class ContextPrior:
-        def sample(self, batch_shape):
-            r = R.normal(shape=batch_shape + (1,), mean=0.1, stddev=0.01)
-            alpha = R.uniform(shape=batch_shape + (1,), minval=-0.5 * math.pi, maxval=0.5 * math.pi)
-
-            return dict(r=r, alpha=alpha)
+def batch_size():
+    return 32


 @pytest.fixture()
-def prior():
-    class Prior:
+def simulator():
+    class Simulator:
         def sample(self, batch_shape):
-            theta = R.uniform(shape=batch_shape + (2,), minval=-1.0, maxval=1.0)
-
-            return dict(theta=theta)
+            r = keras.random.normal(shape=batch_shape + (1,), mean=0.1, stddev=0.01)
+            alpha = keras.random.uniform(shape=batch_shape + (1,), minval=-0.5 * math.pi, maxval=0.5 * math.pi)
+            theta = keras.random.uniform(shape=batch_shape + (2,), minval=-1.0, maxval=1.0)

-    return Prior()
-
-
-@pytest.fixture()
-def likelihood():
-    class Likelihood:
-        def sample(self, batch_shape, r, alpha, theta):
-            x1 = -K.abs(theta[0] + theta[1]) / K.sqrt(2.0) + r * K.cos(alpha) + 0.25
-            x2 = (-theta[0] + theta[1]) / K.sqrt(2.0) + r * K.sin(alpha)
-            return dict(x=K.stack([x1, x2], axis=-1))
+            x1 = -keras.ops.abs(theta[0] + theta[1]) / keras.ops.sqrt(2.0) + r * keras.ops.cos(alpha) + 0.25
+            x2 = (-theta[0] + theta[1]) / keras.ops.sqrt(2.0) + r * keras.ops.sin(alpha)

-    return Likelihood()
+            x = keras.ops.stack([x1, x2], axis=-1)

+            return dict(r=r, alpha=alpha, theta=theta, x=x)

-@pytest.fixture()
-def joint_distribution(context, prior, likelihood):
-    return bf.simulation.JointDistribution(context, prior, likelihood)
+    return Simulator()


 @pytest.fixture()
-def dataset(joint_distribution):
-    # TODO: do not use hard-coded batch size
-    return bf.datasets.OnlineDataset(joint_distribution, workers=4, use_multiprocessing=True, max_queue_size=16, batch_size=16)
+def dataset(simulator):
+    return bf.datasets.OnlineDataset(simulator, workers=4, use_multiprocessing=True, max_queue_size=16, batch_size=16)


 @pytest.fixture()
 def inference_network():
-    return bf.networks.CouplingFlow.all_in_one(
-        subnet_builder="default",
-        target_dim=2,
-        num_layers=2,
-        transform="affine",
-        base_distribution="normal",
-    )
-
-
-@pytest.fixture()
-def summary_network():
-    return None
+    return bf.networks.CouplingFlow()


 @pytest.fixture()
-def amortizer(inference_network, summary_network):
-    return bf.AmortizedPosterior(inference_network, summary_network)
+def approximator(inference_network):
+    return bf.Approximator(
+        inference_network=inference_network,
+        inference_variables=["theta"],
+        inference_conditions=["x", "r", "alpha"],
+        summary_network=None,
+        summary_variables=[],
+        summary_conditions=[],
+    )
diff --git a/tests/test_two_moons/test_fit.py b/tests/test_two_moons/test_fit.py
index 9c0d4397..d0687127 100644
--- a/tests/test_two_moons/test_fit.py
+++ b/tests/test_two_moons/test_fit.py
@@ -14,7 +14,7 @@ def test_compile(amortizer):
 def test_fit(amortizer, dataset):
     # TODO: verify the model learns something by comparing a metric before and after training
     amortizer.compile(optimizer="AdamW")
-    amortizer.fit(dataset, epochs=10, steps_per_epoch=10, batch_size=32)
+    amortizer.fit(dataset, epochs=10, steps_per_epoch=10)


 @pytest.mark.skip(reason="not implemented")
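Read together, the new fixtures sketch the intended end-to-end workflow of the refactored API. A hypothetical usage example assembled from the fixtures above (`Simulator` is the class defined in the two-moons conftest; argument names and order are taken from the `dataset` and `approximator` fixtures and from `test_fit`):

    import bayesflow.experimental as bf

    simulator = Simulator()  # any object with a `sample(batch_shape) -> dict` method works here
    dataset = bf.datasets.OnlineDataset(simulator, workers=4, use_multiprocessing=True, max_queue_size=16, batch_size=16)

    approximator = bf.Approximator(
        inference_network=bf.networks.CouplingFlow(),
        inference_variables=["theta"],
        inference_conditions=["x", "r", "alpha"],
        summary_network=None,
        summary_variables=[],
        summary_conditions=[],
    )

    approximator.compile(optimizer="AdamW")
    approximator.fit(dataset, epochs=10, steps_per_epoch=10)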