diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 623b8e20..e8ea0d30 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,6 +17,11 @@ jobs: os: [ubuntu-latest, windows-latest] python-version: ["3.10", "3.11"] backend: ["jax", "numpy", "tensorflow", "torch"] + defaults: + run: + shell: bash -el {0} + env: + KERAS_BACKEND: ${{ matrix.backend }} steps: - name: Checkout code @@ -25,31 +30,25 @@ jobs: - name: Set up Conda uses: conda-incubator/setup-miniconda@v3 with: - activate-environment: bayesflow environment-file: environment.yaml python-version: ${{ matrix.python-version }} - auto-active-base: false - name: Install JAX if: ${{ matrix.backend == 'jax' }} run: | pip install -U "jax[cpu]" - export KERAS_BACKEND=jax - name: Install NumPy if: ${{ matrix.backend == 'numpy' }} run: | conda install numpy - export KERAS_BACKEND=numpy - name: Install Tensorflow if: ${{ matrix.backend == 'tensorflow' }} run: | pip install -U tensorflow - export KERAS_BACKEND=tensorflow - name: Install PyTorch if: ${{ matrix.backend == 'torch' }} run: | conda install pytorch torchvision torchaudio cpuonly -c pytorch - export KERAS_BACKEND=torch - name: Show Environment Info run: | @@ -59,6 +58,22 @@ jobs: conda config --show printenv | sort - - name: Run tests + - name: Run JAX Tests + if: ${{ matrix.backend == 'jax' }} + run: | + python -m pytest tests/ -n auto -v -m "not (numpy or tensorflow or torch)" + + - name: Run NumPy Tests + if: ${{ matrix.backend == 'numpy' }} + run: | + python -m pytest tests/ -n auto -v -m "not (jax or tensorflow or torch)" + + - name: Run TensorFlow Tests + if: ${{ matrix.backend == 'tensorflow' }} + run: | + python -m pytest tests/ -n auto -v -m "not (jax or numpy or torch)" + + - name: Run PyTorch Tests + if: ${{ matrix.backend == 'torch' }} run: | - python -m pytest tests/ -n auto -v + python -m pytest tests/ -n auto -v -m "not (jax or numpy or tensorflow)" diff --git a/bayesflow/experimental/__init__.py b/bayesflow/experimental/__init__.py index efa50cab..f5b86c10 100644 --- a/bayesflow/experimental/__init__.py +++ b/bayesflow/experimental/__init__.py @@ -1,20 +1,16 @@ from . 
import ( - amortizers, - datasets, - diagnostics, - networks, - simulation, + approximators, + datasets, + diagnostics, + distributions, + networks, + simulation, ) -from .amortizers import ( - AmortizedLikelihood, - AmortizedPosterior, - Amortizer, -) +from .approximators import Approximator -from .simulation import ( - distribution, - JointDistribution, +from .datasets import ( + OnlineDataset, + OfflineDataset, ) - diff --git a/bayesflow/experimental/amortizers/__init__.py b/bayesflow/experimental/amortizers/__init__.py deleted file mode 100644 index 9d1e36e7..00000000 --- a/bayesflow/experimental/amortizers/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ - -from .amortized_likelihood import AmortizedLikelihood -from .amortized_posterior import AmortizedPosterior -from .amortizer import Amortizer -from .joint_amortizer import JointAmortizer -from .point_amortizer import AmortizedPointEstimator diff --git a/bayesflow/experimental/amortizers/amortized_likelihood.py b/bayesflow/experimental/amortizers/amortized_likelihood.py deleted file mode 100644 index 0ed8baa9..00000000 --- a/bayesflow/experimental/amortizers/amortized_likelihood.py +++ /dev/null @@ -1,23 +0,0 @@ - -import keras - -from .amortizer import Amortizer - - -class AmortizedLikelihood(Amortizer): - def configure_inferred_variables(self, data: dict): - return keras.ops.concatenate(list(data["observables"].values()), axis=-1) - - def configure_observed_variables(self, data: dict): - # TODO: concatenate local context - return keras.ops.concatenate(list(data["parameters"].values()), axis=-1) - - def configure_inference_conditions(self, data: dict, summary_outputs=None): - # TODO: concatenate global context - if summary_outputs is not None: - return summary_outputs - - return self.configure_observed_variables(data) - - def configure_summary_conditions(self, data: dict): - return None diff --git a/bayesflow/experimental/amortizers/amortized_posterior.py b/bayesflow/experimental/amortizers/amortized_posterior.py deleted file mode 100644 index 5fca58b1..00000000 --- a/bayesflow/experimental/amortizers/amortized_posterior.py +++ /dev/null @@ -1,21 +0,0 @@ - -import keras - -from .amortizer import Amortizer - - -class AmortizedPosterior(Amortizer): - def configure_inferred_variables(self, data: dict): - return keras.ops.concatenate(list(data["parameters"].values()), axis=-1) - - def configure_observed_variables(self, data: dict): - return keras.ops.concatenate(list(data["observables"].values()), axis=-1) - - def configure_inference_conditions(self, data: dict, summary_outputs=None): - if summary_outputs is None: - return self.configure_observed_variables(data) - - return summary_outputs - - def configure_summary_conditions(self, data: dict): - return None diff --git a/bayesflow/experimental/amortizers/amortizer.py b/bayesflow/experimental/amortizers/amortizer.py deleted file mode 100644 index 710b924b..00000000 --- a/bayesflow/experimental/amortizers/amortizer.py +++ /dev/null @@ -1,174 +0,0 @@ - -import keras -from keras.saving import ( - deserialize_keras_object, - register_keras_serializable, - serialize_keras_object, -) - -from bayesflow.experimental.types import Tensor - -from .base_amortizer import BaseAmortizer - - -@register_keras_serializable(package="bayesflow.amortizers") -class Amortizer(BaseAmortizer): - def __init__(self, inferred_variables: list[str], observed_variables: list[str], inference_conditions: list[str] = None, summary_conditions: list[str] = None, **kwargs): - super().__init__(**kwargs) - self.inferred_variables = 
inferred_variables - self.observed_variables = observed_variables - self.inference_conditions = inference_conditions or [] - self.summary_conditions = summary_conditions or [] - - def configure_inferred_variables(self, data: dict[str, Tensor]) -> Tensor: - return keras.ops.concatenate([data[key] for key in self.inferred_variables]) - - def configure_observed_variables(self, data: dict[str, Tensor]) -> Tensor: - return keras.ops.concatenate([data[key] for key in self.observed_variables]) - - def configure_inference_conditions(self, data: dict[str, Tensor]) -> Tensor | None: - if not self.inference_conditions: - return None - - return keras.ops.concatenate([data[key] for key in self.inference_conditions]) - - def configure_summary_conditions(self, data: dict[str, Tensor]) -> Tensor | None: - if not self.summary_conditions: - return None - -<<<<<<< HEAD - if self.summary_network: - summary_conditions = self.configure_summary_conditions(x) - summary_outputs = self.summary_network(observed_variables, summary_conditions, **kwargs) - else: - summary_outputs = None - - inference_conditions = self.configure_inference_conditions(x, summary_outputs) - inference_outputs = self.inference_network(inferred_variables, inference_conditions, **kwargs) - - return { - "inference_outputs": inference_outputs, - "summary_outputs": summary_outputs, - } - - def compute_loss(self, x: dict = None, y: dict = None, y_pred: dict = None, **kwargs): - x = x or {} - y = y or {} - y_pred = y_pred or {} - - inferred_variables = self.configure_inferred_variables(x) - observed_variables = self.configure_observed_variables(x) - - if self.summary_network: - summary_conditions = self.configure_summary_conditions(x) - summary_loss = self.summary_network.compute_loss( - x=(observed_variables, summary_conditions), - y=y.get("summary_targets"), - y_pred=y_pred.get("summary_outputs") - ) - else: - summary_loss = keras.ops.zeros(()) - - inference_conditions = self.configure_inference_conditions(x, y_pred.get("summary_outputs")) - inference_loss = self.inference_network.compute_loss( - x=inferred_variables, - conditions=inference_conditions, - **kwargs - ) - - return inference_loss + summary_loss - - def compute_metrics(self, x: dict, y: dict, y_pred: dict, **kwargs): - base_metrics = super().compute_metrics(x, y, y_pred, **kwargs) -<<<<<<< HEAD - - inferred_variables = self.configure_inferred_variables(x) - observed_variables = self.configure_observed_variables(x) - - if self.summary_network: - summary_conditions = self.configure_summary_conditions(x) - summary_metrics = self.summary_network.compute_metrics( - x=(observed_variables, summary_conditions), - y=y.get("summary_targets"), - y_pred=y_pred.get("summary_outputs") - ) - else: - summary_metrics = {} - - inference_conditions = self.configure_inference_conditions(x, y_pred.get("summary_outputs")) - inference_metrics = self.inference_network.compute_metrics( - x=(inferred_variables, inference_conditions), - y=y.get("inference_targets"), - y_pred=y_pred.get("inference_outputs") - ) - - summary_metrics = {f"summary/{key}": value for key, value in summary_metrics.items()} - inference_metrics = {f"inference/{key}": value for key, value in inference_metrics.items()} - - return base_metrics | summary_metrics | inference_metrics -======= - #TODO - add back metrics - return base_metrics ->>>>>>> streamlined-backend - - def sample(self, data: dict, num_samples: int, sample_summaries=False, **kwargs): - - # Configure everything -> inference conditions / summary conditions - 
configured_observables = None - - # Decide whether summaries are present or not / whether sumamry network is present or not - # ... - - return self.inference_network.sample(conditions=configured_observables, **kwargs) - - def log_prob(self, *args, **kwargs): - return self.inference_network.log_prob(*args, **kwargs) - - def configure_inferred_variables(self, data: dict): - """ - Return the inferred variables, given the data. - Inferred variables are passed as input to the inference network. - - This method must be efficient and deterministic. - Best practice is to prepare the output in dataset.__getitem__, - which is run in a worker process, and then simply fetch a key from the data dictionary here. - """ - raise NotImplementedError - - def configure_observed_variables(self, data: dict): - """ - Return the observed variables, given the data. - Observed variables are passed as input to the summary and/or inference networks. - - This method must be efficient and deterministic. - Best practice is to prepare the output in dataset.__getitem__, - which is run in a worker process, and then simply fetch a key from the data dictionary here. - """ - raise NotImplementedError - - def configure_inference_conditions(self, data: dict, summary_outputs=None): - """ - Return the inference conditions, given the data. - Inference conditions are passed as conditional input to the inference network. - - If summary outputs are provided, they should be concatenated to the return value. - - This method must be efficient and deterministic. - Best practice is to prepare the output in dataset.__getitem__, - which is run in a worker process, and then simply fetch a key from the data dictionary here. - """ - raise NotImplementedError - - def configure_summary_conditions(self, data: dict): - """ - Return the summary conditions, given the data. - Summary conditions are passed as conditional input to the summary network. - - This method must be efficient and deterministic. - Best practice is to prepare the output in dataset.__getitem__, - which is run in a worker process, and then simply fetch a key from the data dictionary here. 
- """ - raise NotImplementedError -======= - return keras.ops.concatenate([data[key] for key in self.summary_conditions]) ->>>>>>> streamlined-backend diff --git a/bayesflow/experimental/amortizers/joint_amortizer.py b/bayesflow/experimental/amortizers/joint_amortizer.py deleted file mode 100644 index c7eecd69..00000000 --- a/bayesflow/experimental/amortizers/joint_amortizer.py +++ /dev/null @@ -1,31 +0,0 @@ - -import keras - -from .amortizer import Amortizer - - -class JointAmortizer(keras.Model): - def __init__(self, **amortizers: Amortizer): - super().__init__() - self.amortizers = amortizers - - def build(self, input_shape): - for amortizer in self.amortizers.values(): - amortizer.build(input_shape) - - def call(self, *args, **kwargs): - return {name: amortizer(*args, **kwargs) for name, amortizer in self.amortizers.items()} - - def compute_loss(self, *args, **kwargs): - losses = {name: amortizer.compute_loss(*args, **kwargs) for name, amortizer in self.amortizers.items()} - return keras.ops.mean(losses.values(), axis=0) - - def compute_metrics(self, *args, **kwargs): - metrics = {} - - for name, amortizer in self.amortizers.items(): - m = amortizer.compute_metrics(*args, **kwargs) - m = {f"{name}/{key}": value for key, value in m.items()} - metrics |= m - - return metrics diff --git a/bayesflow/experimental/amortizers/point_amortizer.py b/bayesflow/experimental/amortizers/point_amortizer.py deleted file mode 100644 index f9ea5484..00000000 --- a/bayesflow/experimental/amortizers/point_amortizer.py +++ /dev/null @@ -1,20 +0,0 @@ -import keras - -from .amortizer import Amortizer - - -class AmortizedPointEstimator(Amortizer): - def configure_inferred_variables(self, data: dict): - return keras.ops.concatenate(list(data["parameters"].values()), axis=1) - - def configure_observed_variables(self, data: dict): - return keras.ops.concatenate(list(data["observables"].values()), axis=1) - - def configure_inference_conditions(self, data: dict, summary_outputs=None): - if summary_outputs is None: - return self.configure_observed_variables(data) - - return summary_outputs - - def configure_summary_conditions(self, data: dict): - return None diff --git a/bayesflow/experimental/approximators/__init__.py b/bayesflow/experimental/approximators/__init__.py new file mode 100644 index 00000000..1a28ff47 --- /dev/null +++ b/bayesflow/experimental/approximators/__init__.py @@ -0,0 +1,3 @@ + +from .approximator import Approximator +from .joint_approximator import JointApproximator diff --git a/bayesflow/experimental/approximators/approximator.py b/bayesflow/experimental/approximators/approximator.py new file mode 100644 index 00000000..52fdea85 --- /dev/null +++ b/bayesflow/experimental/approximators/approximator.py @@ -0,0 +1,107 @@ + +import keras +from keras import ops +from keras.saving import ( + deserialize_keras_object, + register_keras_serializable, + serialize_keras_object, +) + +from bayesflow.experimental.types import Tensor + +from .base_approximator import BaseApproximator + + +@register_keras_serializable(package="bayesflow.approximators") +class Approximator(BaseApproximator): + def __init__( + self, + inference_variables: list[str], + inference_conditions: list[str] = None, + summary_variables: list[str] = None, + summary_conditions: list[str] = None, + **kwargs + ): + """ The main workhorse for learning amortized neural approximators for distributions arising + in inverse problems and Bayesian inference (e.g., posterior distributions, likelihoods, marginal + likelihoods). 
+ + The complete semantics of this class allow for flexible estimation of the following distribution: + + Q(inference_variables | H(summary_variables; summary_conditions), inference_conditions), + + #TODO - math notation + + where all quantities to the right of the "given" symbol | are optional and H refers to the optional + summary/embedding network used to compress high-dimensional data into lower-dimensional summary + vectors. Some examples are provided below. + + Parameters + ---------- + inference_variables: list[str] + A list of variable names indicating the quantities to be inferred / learned by the approximator, + e.g., model parameters when approximating the Bayesian posterior or observables when approximating + a likelihood density. + inference_conditions: list[str] + A list of variable names indicating quantities that will be used to condition (i.e., inform) the + distribution over inference variables directly, that is, without passing through the summary network. + summary_variables: list[str] + A list of variable names indicating quantities that will be used to condition (i.e., inform) the + distribution over inference variables after passing through the summary network (i.e., undergoing a + learnable transformation / dimensionality reduction). For instance, non-vector quantities (e.g., + sets or time-series) in posterior inference will typically qualify as summary variables. In addition, + these quantities may involve learnable distributions on their own. + summary_conditions: list[str] + A list of variable names indicating quantities that will be used to condition (i.e., inform) the + optional summary network, e.g., when the summary network accepts further conditions that do not + conform to the semantics of summary variables (i.e., they need not be embedded and their + distributions need not be learned). + + #TODO add citations + + Examples + -------- + #TODO + """ + + super().__init__(**kwargs) + self.inference_variables = inference_variables + self.inference_conditions = inference_conditions or [] + self.summary_variables = summary_variables or [] + self.summary_conditions = summary_conditions or [] + + def configure_full_conditions( + self, + summary_outputs: Tensor | None, + inference_conditions: Tensor | None, + ) -> Tensor: + """ + Combine the (optional) inference conditions with the (optional) outputs + of the (optional) summary network.
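(To make the combination rule above concrete, a small sketch; the shapes here are illustrative assumptions:)

```python
import keras

# Hypothetical shapes: summary outputs (64, 32), direct conditions (64, 4).
summary_outputs = keras.ops.ones((64, 32))
inference_conditions = keras.ops.ones((64, 4))

# Concatenation along the last axis yields (64, 36); if either argument
# were None, configure_full_conditions would pass the other one through.
full_conditions = keras.ops.concatenate((summary_outputs, inference_conditions), axis=-1)
print(keras.ops.shape(full_conditions))  # (64, 36)
```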
+ """ + + if summary_outputs is None: + return inference_conditions + if inference_conditions is None: + return summary_outputs + return keras.ops.concatenate( + (summary_outputs, inference_conditions), axis=-1 + ) + + def configure_inference_variables(self, data: dict[str, Tensor]) -> Tensor: + return ops.concatenate([data[key] for key in self.inference_variables]) + + def configure_inference_conditions(self, data: dict[str, Tensor]) -> Tensor | None: + if not self.inference_conditions: + return None + return ops.concatenate([data[key] for key in self.inference_conditions]) + + def configure_summary_variables(self, data: dict[str, Tensor]) -> Tensor | None: + if not self.summary_variables: + return None + return ops.concatenate([data[key] for key in self.summary_variables]) + + def configure_summary_conditions(self, data: dict[str, Tensor]) -> Tensor | None: + if not self.summary_conditions: + return None + return ops.concatenate([data[key] for key in self.summary_conditions]) diff --git a/bayesflow/experimental/amortizers/base_amortizer.py b/bayesflow/experimental/approximators/base_approximator.py similarity index 62% rename from bayesflow/experimental/amortizers/base_amortizer.py rename to bayesflow/experimental/approximators/base_approximator.py index ed4feacd..70711a90 100644 --- a/bayesflow/experimental/amortizers/base_amortizer.py +++ b/bayesflow/experimental/approximators/base_approximator.py @@ -10,37 +10,25 @@ from bayesflow.experimental.networks import InferenceNetwork, SummaryNetwork -class BaseAmortizer(keras.Model): - def __init__(self, inference_network: InferenceNetwork, summary_network: SummaryNetwork = None, **kwargs): +class BaseApproximator(keras.Model): + def __init__( + self, + inference_network: InferenceNetwork, + summary_network: SummaryNetwork = None, + **kwargs + ): + super().__init__(**kwargs) self.inference_network = inference_network self.summary_network = summary_network def sample(self, num_samples: int, **kwargs) -> dict[str, Tensor]: # TODO - return self.inference_network.sample(num_samples, **kwargs) + return {} def log_prob(self, samples: dict[str, Tensor], **kwargs) -> Tensor: # TODO - samples = self.configure_inferred_variables(samples) - return self.inference_network.log_prob(samples, **kwargs) - - @classmethod - def from_config(cls, config: dict, custom_objects=None) -> "BaseAmortizer": - inference_network = deserialize_keras_object(config.pop("inference_network"), custom_objects=custom_objects) - summary_network = deserialize_keras_object(config.pop("summary_network"), custom_objects=custom_objects) - - return cls(inference_network, summary_network, **config) - - def get_config(self): - base_config = super().get_config() - - config = { - "inference_network": serialize_keras_object(self.inference_network), - "summary_network": serialize_keras_object(self.summary_network), - } - - return base_config | config + return {} def call(self, *, training=False, **data): if not training: @@ -55,53 +43,73 @@ def call(self, *, training=False, **data): return None def compute_loss(self, **data): - inferred_variables = self.configure_inferred_variables(data) - observed_variables = self.configure_observed_variables(data) + + # Configure dict outputs into tensors + inference_variables = self.configure_inference_variables(data) inference_conditions = self.configure_inference_conditions(data) + summary_variables = self.configure_summary_variables(data) summary_conditions = self.configure_summary_conditions(data) + # Obtain summary outputs and summary loss (if 
present) if self.summary_network: - summary_loss = self.summary_network.compute_loss( - observed_variables=observed_variables, - summary_conditions=summary_conditions, - ) + summary_outputs = self.summary_network(summary_variables, summary_conditions) + summary_loss = self.summary_network.compute_loss(summary_outputs) else: + summary_outputs = None summary_loss = keras.ops.zeros(()) + # Combine summary outputs and inference conditions + full_conditions = self.configure_full_conditions(summary_outputs, inference_conditions) + + # Compute inference loss inference_loss = self.inference_network.compute_loss( - inferred_variables=inferred_variables, - inference_conditions=inference_conditions, + inference_variables=inference_variables, + inference_conditions=full_conditions, ) return inference_loss + summary_loss def compute_metrics(self, **data): + #TODO base_metrics = super().compute_metrics(**data) + return base_metrics + + # inference_variables = self.configure_inference_variables(data) + # inference_conditions = self.configure_inference_conditions(data) + # summary_variables = self.configure_summary_variables(data) + # summary_conditions = self.configure_summary_conditions(data) + # + # if self.summary_network: + # summary_metrics = self.summary_network.compute_metrics( + # summary_variables=summary_variables, + # summary_conditions=summary_conditions, + # ) + # else: + # summary_metrics = {} + # + # inference_metrics = self.inference_network.compute_metrics( + # inference_variables=inference_variables, + # conditions=conditions, + # ) + # + # summary_metrics = {f"summary/{key}": value for key, value in summary_metrics.items()} + # inference_metrics = {f"inference/{key}": value for key, value in inference_metrics.items()} + + # return base_metrics | inference_metrics | summary_metrics + + def configure_full_conditions( + self, + summary_outputs: Tensor | None, + inference_conditions: Tensor | None, + ) -> Tensor: + """ + Combine the (optional) inference conditions with the (optional) outputs + of the (optional) summary network. + """ - raise NotImplementedError - def configure_inferred_variables(self, data: dict) -> any: + raise NotImplementedError + def configure_inference_variables(self, data: dict) -> any: """ Return the inferred variables, given the data. Inferred variables are passed as input to the inference network. @@ -112,10 +120,12 @@ def configure_inferred_variables(self, data: dict) -> any: """ raise NotImplementedError - def configure_observed_variables(self, data: dict) -> any: + def configure_inference_conditions(self, data: dict) -> any: """ - Return the observed variables, given the data. - Observed variables are passed as input to the summary and/or inference networks.
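(The configure_* docstrings around this hunk recommend preparing tensors inside dataset.__getitem__, which runs in a worker process. A minimal sketch of that pattern, with assumed key names:)

```python
import keras

class PreconfiguredDataset(keras.utils.PyDataset):
    # Hypothetical dataset: the expensive concatenation happens here, in the
    # worker process, so the approximator's configure_* methods reduce to
    # cheap dictionary lookups.
    def __init__(self, data: dict, batch_size: int, **kwargs):
        super().__init__(**kwargs)
        self.data = data
        self.batch_size = batch_size

    def __len__(self) -> int:
        return len(next(iter(self.data.values()))) // self.batch_size

    def __getitem__(self, item: int) -> (dict, dict):
        sl = slice(item * self.batch_size, (item + 1) * self.batch_size)
        batch = {key: value[sl] for key, value in self.data.items()}
        # "theta" and "sigma" are assumed keys; concatenate once, here,
        # rather than on every configure_inference_variables call
        batch["inference_variables"] = keras.ops.concatenate(
            [batch["theta"], batch["sigma"]], axis=-1
        )
        return batch, {}
```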
+ Return the inference conditions, given the data. + Inference conditions are passed as conditional input to the inference network. + + If summary outputs are provided, they should be concatenated to the return value. This method must be efficient and deterministic. Best practice is to prepare the output in dataset.__getitem__, @@ -123,12 +133,10 @@ def configure_observed_variables(self, data: dict) -> any: raise NotImplementedError - def configure_inference_conditions(self, data: dict) -> any: + def configure_summary_variables(self, data: dict) -> any: """ - Return the inference conditions, given the data. - Inference conditions are passed as conditional input to the inference network. - - If summary outputs are provided, they should be concatenated to the return value. + Return the observed variables, given the data. + Observed variables are passed as input to the summary and/or inference networks. This method must be efficient and deterministic. Best practice is to prepare the output in dataset.__getitem__, @@ -146,3 +154,20 @@ def configure_summary_conditions(self, data: dict) -> any: which is run in a worker process, and then simply fetch a key from the data dictionary here. """ raise NotImplementedError + + @classmethod + def from_config(cls, config: dict, custom_objects=None) -> "BaseApproximator": + inference_network = deserialize_keras_object(config.pop("inference_network"), custom_objects=custom_objects) + summary_network = deserialize_keras_object(config.pop("summary_network"), custom_objects=custom_objects) + + return cls(inference_network, summary_network, **config) + + def get_config(self): + base_config = super().get_config() + + config = { + "inference_network": serialize_keras_object(self.inference_network), + "summary_network": serialize_keras_object(self.summary_network), + } + + return base_config | config diff --git a/bayesflow/experimental/approximators/joint_approximator.py b/bayesflow/experimental/approximators/joint_approximator.py new file mode 100644 index 00000000..034b155a --- /dev/null +++ b/bayesflow/experimental/approximators/joint_approximator.py @@ -0,0 +1,31 @@ + +import keras + +from .approximator import Approximator + + +class JointApproximator(keras.Model): + def __init__(self, **approximators: Approximator): + super().__init__() + self.approximators = approximators + + def build(self, input_shape): + for approximator in self.approximators.values(): + approximator.build(input_shape) + + def call(self, *args, **kwargs): + return {name: approximator(*args, **kwargs) for name, approximator in self.approximators.items()} + + def compute_loss(self, *args, **kwargs): + losses = {name: approximator.compute_loss(*args, **kwargs) for name, approximator in self.approximators.items()} + return keras.ops.mean(keras.ops.stack(list(losses.values())), axis=0) + + def compute_metrics(self, *args, **kwargs): + metrics = {} + + for name, approximator in self.approximators.items(): + m = approximator.compute_metrics(*args, **kwargs) + m = {f"{name}/{key}": value for key, value in m.items()} + metrics |= m + + return metrics diff --git a/bayesflow/experimental/backend_approximators/__init__.py b/bayesflow/experimental/backend_approximators/__init__.py new file mode 100644 index 00000000..3e50858b --- /dev/null +++ b/bayesflow/experimental/backend_approximators/__init__.py @@ -0,0 +1,2 @@ + +from .approximator import Approximator diff --git a/bayesflow/experimental/backend_approximators/approximator.py b/bayesflow/experimental/backend_approximators/approximator.py new file mode 100644 index
00000000..76112757 --- /dev/null +++ b/bayesflow/experimental/backend_approximators/approximator.py @@ -0,0 +1,14 @@ + +import keras + +match keras.backend.backend(): + case "jax": + from .jax_approximator import JAXApproximator as Approximator + case "numpy": + from .numpy_approximator import NumpyApproximator as Approximator + case "tensorflow": + from .tensorflow_approximator import TensorFlowApproximator as Approximator + case "torch": + from .torch_approximator import TorchApproximator as Approximator + case other: + raise NotImplementedError(f"BayesFlow does not currently support backend '{other}'.") diff --git a/bayesflow/experimental/backend_approximators/base_approximator.py b/bayesflow/experimental/backend_approximators/base_approximator.py new file mode 100644 index 00000000..e17aa9ba --- /dev/null +++ b/bayesflow/experimental/backend_approximators/base_approximator.py @@ -0,0 +1,16 @@ + +import keras + +from bayesflow.experimental.types import Tensor + + +class BaseApproximator(keras.Model): + def train_step(self, data): + raise NotImplementedError + + # noinspection PyMethodOverriding + def compute_metrics(self, data: dict[str, Tensor], mode: str = "training") -> dict[str, Tensor]: + raise NotImplementedError + + def compute_loss(self, *args, **kwargs): + raise NotImplementedError("Use compute_metrics()['loss'] instead.") diff --git a/bayesflow/experimental/backend_approximators/jax_approximator.py b/bayesflow/experimental/backend_approximators/jax_approximator.py new file mode 100644 index 00000000..376f597b --- /dev/null +++ b/bayesflow/experimental/backend_approximators/jax_approximator.py @@ -0,0 +1,8 @@ + +import jax + +from .base_approximator import BaseApproximator + + +class JAXApproximator(BaseApproximator): + pass diff --git a/bayesflow/experimental/backend_approximators/numpy_approximator.py b/bayesflow/experimental/backend_approximators/numpy_approximator.py new file mode 100644 index 00000000..ebedd52c --- /dev/null +++ b/bayesflow/experimental/backend_approximators/numpy_approximator.py @@ -0,0 +1,8 @@ + +import numpy as np + +from .base_approximator import BaseApproximator + + +class NumpyApproximator(BaseApproximator): + pass diff --git a/bayesflow/experimental/backend_approximators/tensorflow_approximator.py b/bayesflow/experimental/backend_approximators/tensorflow_approximator.py new file mode 100644 index 00000000..f6d10d41 --- /dev/null +++ b/bayesflow/experimental/backend_approximators/tensorflow_approximator.py @@ -0,0 +1,8 @@ + +import tensorflow as tf + +from .base_approximator import BaseApproximator + + +class TensorFlowApproximator(BaseApproximator): + pass diff --git a/bayesflow/experimental/backend_approximators/torch_approximator.py b/bayesflow/experimental/backend_approximators/torch_approximator.py new file mode 100644 index 00000000..78d1729d --- /dev/null +++ b/bayesflow/experimental/backend_approximators/torch_approximator.py @@ -0,0 +1,26 @@ + +import torch + + +from .base_approximator import BaseApproximator + + +class TorchApproximator(BaseApproximator): + def train_step(self, data): + with torch.enable_grad(): + metrics = self.compute_metrics(data, mode="training") + + loss = metrics.pop("loss") + + # clear leftover gradients on the torch module, then backpropagate + self.zero_grad() + loss.backward() + + trainable_weights = [v for v in self.trainable_weights] + gradients = [v.value.grad for v in trainable_weights] + + # self.optimizer is a keras optimizer; apply the gradients without tracking + with torch.no_grad(): + self.optimizer.apply(gradients, trainable_weights) + + return metrics diff --git a/bayesflow/experimental/datasets/offline_dataset.py b/bayesflow/experimental/datasets/offline_dataset.py index 1d1b7702..9315a9e6 100644 --- a/bayesflow/experimental/datasets/offline_dataset.py +++ b/bayesflow/experimental/datasets/offline_dataset.py @@
-1,5 +1,6 @@ import keras +import math from bayesflow.experimental.utils import nested_getitem @@ -9,13 +10,12 @@ class OfflineDataset(keras.utils.PyDataset): A dataset that is pre-simulated and stored in memory. """ # TODO: fix - def __init__(self, data: dict, batch_size: int, batches_per_epoch: int, **kwargs): + def __init__(self, data: dict, batch_size: int, **kwargs): super().__init__(**kwargs) self.batch_size = batch_size - self.batches_per_epoch = batches_per_epoch self.data = data - self.indices = keras.ops.arange(batch_size * batches_per_epoch) + self.indices = keras.ops.arange(len(data[next(iter(data.keys()))])) self.shuffle() @@ -27,7 +27,7 @@ def __getitem__(self, item: int) -> (dict, dict): return data, {} def __len__(self) -> int: - return self.batches_per_epoch + return math.ceil(len(self.indices) / self.batch_size) def on_epoch_end(self) -> None: self.shuffle() diff --git a/bayesflow/experimental/datasets/online_dataset.py b/bayesflow/experimental/datasets/online_dataset.py index a4dd830b..4a096b00 100644 --- a/bayesflow/experimental/datasets/online_dataset.py +++ b/bayesflow/experimental/datasets/online_dataset.py @@ -8,14 +8,14 @@ class OnlineDataset(keras.utils.PyDataset): """ A dataset that is generated on-the-fly. """ - def __init__(self, joint_distribution: JointDistribution, batch_size: int, **kwargs): + def __init__(self, distribution, batch_size: int, **kwargs): super().__init__(**kwargs) - self.joint_distribution = joint_distribution + self.distribution = distribution self.batch_size = batch_size def __getitem__(self, item: int) -> (dict, dict): """ Sample a batch of data from the joint distribution unconditionally """ - data = self.joint_distribution.sample((self.batch_size,)) + data = self.distribution.sample((self.batch_size,)) return data, {} @property diff --git a/bayesflow/experimental/networks/__init__.py b/bayesflow/experimental/networks/__init__.py index a02b50e3..2b5c85e9 100644 --- a/bayesflow/experimental/networks/__init__.py +++ b/bayesflow/experimental/networks/__init__.py @@ -3,6 +3,7 @@ from .deep_set import DeepSet from .flow_matching import FlowMatching from .inference_network import InferenceNetwork +from .mlp import MLP from .resnet import ResNet from .set_transformer import SetTransformer from .summary_network import SummaryNetwork \ No newline at end of file diff --git a/bayesflow/experimental/networks/coupling_flow/actnorm.py b/bayesflow/experimental/networks/coupling_flow/actnorm.py index aa7b995c..398f7ae1 100644 --- a/bayesflow/experimental/networks/coupling_flow/actnorm.py +++ b/bayesflow/experimental/networks/coupling_flow/actnorm.py @@ -3,9 +3,7 @@ >>>>>>> streamlined-backend from keras import ops -from keras.saving import ( - register_keras_serializable, -) +from keras.saving import register_keras_serializable from bayesflow.experimental.types import Tensor from .invertible_layer import InvertibleLayer diff --git a/bayesflow/experimental/networks/coupling_flow/coupling_flow.py b/bayesflow/experimental/networks/coupling_flow/coupling_flow.py index 51f94fec..86713af3 100644 --- a/bayesflow/experimental/networks/coupling_flow/coupling_flow.py +++ b/bayesflow/experimental/networks/coupling_flow/coupling_flow.py @@ -2,23 +2,12 @@ from typing import Tuple, Union import keras -<<<<<<< HEAD -from keras.saving import ( - register_keras_serializable, -) - -from bayesflow.experimental.types import Tensor -from .actnorm import ActNorm -from .couplings import DualCoupling -======= from keras.saving import register_keras_serializable from 
bayesflow.experimental.types import Tensor -from bayesflow.experimental.utils import keras_kwargs +from bayesflow.experimental.utils import find_permutation from .actnorm import ActNorm from .couplings import DualCoupling -from .permutations import OrthogonalPermutation, RandomPermutation, Swap ->>>>>>> streamlined-backend from ..inference_network import InferenceNetwork @@ -47,86 +36,41 @@ class CouplingFlow(InferenceNetwork): arXiv preprint arXiv:2006.06599. """ def __init__( -<<<<<<< HEAD - self, - depth: int = 6, - subnet: str = "resnet", - transform: str = "affine", - use_actnorm: bool = True, **kwargs - ): - super().__init__(**kwargs) - - self._layers = [] - for _ in range(depth): - if use_actnorm: - self._layers.append(ActNorm()) - self._layers.append(DualCoupling(subnet, transform)) - - def build(self, input_shape): - super().build(input_shape) - self.call(keras.KerasTensor(input_shape)) - - def call(self, xz: Tensor, inverse: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]: - if inverse: - return self._inverse(xz, **kwargs) - return self._forward(xz, **kwargs) - - def _forward(self, x: Tensor, jacobian: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]: - z = x - log_det = 0.0 - for layer in self._layers: - z, det = layer(z, inverse=False, **kwargs) - log_det += det - - if jacobian: - return z, log_det - return z - - def _inverse(self, z: Tensor, jacobian: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]: - x = z - log_det = 0.0 - for layer in reversed(self._layers): - x, det = layer(x, inverse=True, **kwargs) - log_det += det - - if jacobian: - return x, log_det - return x - - def compute_loss(self, x: Tensor = None, **kwargs): - z, log_det = self(x, inverse=False, jacobian=True, **kwargs) -======= self, depth: int = 6, subnet: str = "resnet", transform: str = "affine", - permutation: str = "random", + permutation: str | None = None, use_actnorm: bool = True, **kwargs ): - """TODO""" + # TODO - propagate optional keyword arguments to find_network and ResNet respectively + super().__init__(**kwargs) - super().__init__(**keras_kwargs(kwargs)) + self.depth = depth - self._layers = [] + self.invertible_layers = [] for i in range(depth): if use_actnorm: - self._layers.append(ActNorm()) - self._layers.append(DualCoupling(subnet, transform, **kwargs)) - if permutation.lower() == "random": - self._layers.append(RandomPermutation()) - elif permutation.lower() == "swap": - self._layers.append(Swap()) - elif permutation.lower() == "learnable": - self._layers.append(OrthogonalPermutation()) + self.invertible_layers.append(ActNorm(name=f"ActNorm{i}")) + + self.invertible_layers.append(DualCoupling(subnet, transform, name=f"DualCoupling{i}")) + + if (p := find_permutation(permutation, name=f"Permutation{i}")) is not None: + self.invertible_layers.append(p) # noinspection PyMethodOverriding def build(self, xz_shape, conditions_shape=None): super().build(xz_shape) + + xz = keras.KerasTensor(xz_shape) if conditions_shape is None: - self.call(keras.KerasTensor(xz_shape)) + conditions = None else: - self.call(keras.KerasTensor(xz_shape), conditions=keras.KerasTensor(conditions_shape)) + conditions = keras.KerasTensor(conditions_shape) + + # build nested layers with forward pass + self.call(xz, conditions=conditions) def call(self, xz: Tensor, conditions: Tensor = None, inverse: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]: if inverse: @@ -136,11 +80,8 @@ def call(self, xz: Tensor, conditions: Tensor = None, inverse: bool = False, **k 
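(For reference, a hedged sketch of how the new optional permutation argument resolves; the string names match the find_permutation dispatcher added later in this diff:)

```python
from bayesflow.experimental.networks import CouplingFlow

flow_plain = CouplingFlow(depth=4)                             # permutation=None: no permutation layers
flow_random = CouplingFlow(depth=4, permutation="random")      # fixed random permutation per step
flow_learned = CouplingFlow(depth=4, permutation="learnable")  # learned orthogonal "permutation"
```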
def _forward(self, x: Tensor, conditions: Tensor = None, jacobian: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]: z = x log_det = keras.ops.zeros(keras.ops.shape(x)[:-1]) - for layer in self._layers: - if isinstance(layer, DualCoupling): - z, det = layer(z, conditions=conditions, inverse=False, **kwargs) - else: - z, det = layer(z, inverse=False, **kwargs) + for layer in self.invertible_layers: + z, det = layer(z, conditions=conditions, inverse=False, **kwargs) log_det += det if jacobian: @@ -150,20 +91,16 @@ def _forward(self, x: Tensor, conditions: Tensor = None, jacobian: bool = False, def _inverse(self, z: Tensor, conditions: Tensor = None, jacobian: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]: x = z log_det = keras.ops.zeros(keras.ops.shape(z)[:-1]) - for layer in reversed(self._layers): - if isinstance(layer, DualCoupling): - x, det = layer(x, conditions=conditions, inverse=True, **kwargs) - else: - x, det = layer(x, inverse=True, **kwargs) + for layer in reversed(self.invertible_layers): + x, det = layer(x, conditions=conditions, inverse=True, **kwargs) log_det += det if jacobian: return x, log_det return x - def compute_loss(self, x: Tensor = None, conditions: Tensor = None, **kwargs): - z, log_det = self(x, conditions=conditions, inverse=False, jacobian=True, **kwargs) ->>>>>>> streamlined-backend + def compute_loss(self, inference_variables: Tensor, inference_conditions: Tensor = None, **kwargs) -> Tensor: + z, log_det = self(inference_variables, conditions=inference_conditions, inverse=False, jacobian=True, **kwargs) log_prob = self.base_distribution.log_prob(z) nll = -keras.ops.mean(log_prob + log_det) diff --git a/bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py b/bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py index 7292bcd2..f1c73afe 100644 --- a/bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py +++ b/bayesflow/experimental/networks/coupling_flow/couplings/dual_coupling.py @@ -29,8 +29,18 @@ def __init__(self, subnet: str = "resnet", transform: str = "affine", **kwargs): >>>>>>> streamlined-backend self.pivot = None - def build(self, input_shape): - self.pivot = input_shape[-1] // 2 + # noinspection PyMethodOverriding + def build(self, xz_shape, conditions_shape=None): + self.pivot = xz_shape[-1] // 2 + + xz = keras.KerasTensor(xz_shape) + if conditions_shape is None: + conditions = None + else: + conditions = keras.KerasTensor(conditions_shape) + + # build nested layers with forward pass + self.call(xz, conditions=conditions) <<<<<<< HEAD def call(self, xz: Tensor, conditions: any = None, inverse: bool = False) -> (Tensor, Tensor): diff --git a/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py b/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py index 7d6a08c9..82027d4f 100644 --- a/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py +++ b/bayesflow/experimental/networks/coupling_flow/couplings/single_coupling.py @@ -17,23 +17,33 @@ class SingleCoupling(InvertibleLayer): """ def __init__( self, - network: str = "resnet", + subnet: str = "resnet", transform: str = "affine", - output_layer_kernel_init: str = "zeros", **kwargs ): super().__init__(**keras_kwargs(kwargs)) - self.output_projector = keras.layers.Dense( - units=None, - kernel_initializer=output_layer_kernel_init, - ) - self.network = find_network(network, **kwargs.get("subnet_kwargs", {})) + + self.network = find_network(subnet, 
**kwargs.get("subnet_kwargs", {})) self.transform = find_transform(transform, **kwargs.get("transform_kwargs", {})) + output_projector_kwargs = kwargs.get("output_projector_kwargs", {}) + output_projector_kwargs.setdefault("kernel_initializer", "zeros") + self.output_projector = keras.layers.Dense(units=None, **output_projector_kwargs) + # noinspection PyMethodOverriding - def build(self, x1_shape, x2_shape): + def build(self, x1_shape, x2_shape, conditions_shape=None): self.output_projector.units = self.transform.params_per_dim * x2_shape[-1] + x1 = keras.KerasTensor(x1_shape) + x2 = keras.KerasTensor(x2_shape) + if conditions_shape is None: + conditions = None + else: + conditions = keras.KerasTensor(conditions_shape) + + # build nested layers with forward pass + self.call(x1, x2, conditions=conditions) + def call(self, x1: Tensor, x2: Tensor, conditions: Tensor = None, inverse: bool = False, **kwargs) -> ((Tensor, Tensor), Tensor): if inverse: return self._inverse(x1, x2, conditions=conditions, **kwargs) diff --git a/bayesflow/experimental/networks/inference_network.py b/bayesflow/experimental/networks/inference_network.py index 3a144a41..9b398f36 100644 --- a/bayesflow/experimental/networks/inference_network.py +++ b/bayesflow/experimental/networks/inference_network.py @@ -3,13 +3,8 @@ import keras -<<<<<<< HEAD -from bayesflow.experimental.distributions import find_distribution -from bayesflow.experimental.types import Tensor -======= from bayesflow.experimental.types import Tensor from bayesflow.experimental.utils import find_distribution ->>>>>>> streamlined-backend class InferenceNetwork(keras.Layer): @@ -32,26 +27,18 @@ def _forward(self, x: Tensor, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]: def _inverse(self, z: Tensor, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]: raise NotImplementedError -<<<<<<< HEAD - def sample(self, num_samples: int, **kwargs) -> Tensor: - samples = self.base_distribution.sample((num_samples,)) - return self(samples, inverse=True, jacobian=False, **kwargs) - - def log_prob(self, x: Tensor, **kwargs) -> Tensor: - samples, log_det = self(x, inverse=False, jacobian=True, **kwargs) -======= def sample(self, num_samples: int, conditions: Tensor = None, **kwargs) -> Tensor: samples = self.base_distribution.sample((num_samples,)) - return self(samples, conditions=conditions, inverse=True, jacobian=False, **kwargs) + samples = self(samples, conditions=conditions, inverse=True, jacobian=False, **kwargs) + return samples - def log_prob(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor: - samples, log_det = self(x, conditions=conditions, inverse=False, jacobian=True, **kwargs) ->>>>>>> streamlined-backend + def log_prob(self, samples: Tensor, conditions: Tensor = None, **kwargs) -> Tensor: + samples, log_det = self(samples, conditions=conditions, inverse=False, jacobian=True, **kwargs) log_prob = self.base_distribution.log_prob(samples) return log_prob + log_det - def compute_loss(self, inferred_variables: Tensor, inference_conditions: Tensor = None, **kwargs) -> Tensor: + def compute_loss(self, inference_variables: Tensor, inference_conditions: Tensor = None, **kwargs) -> Tensor: raise NotImplementedError - def compute_metrics(self, inferred_variables: Tensor, inference_conditions: Tensor = None, **kwargs) -> dict: + def compute_metrics(self, inference_variables: Tensor, inference_conditions: Tensor = None, **kwargs) -> dict: return {} diff --git a/bayesflow/experimental/networks/mlp/__init__.py 
b/bayesflow/experimental/networks/mlp/__init__.py new file mode 100644 index 00000000..08c36c7d --- /dev/null +++ b/bayesflow/experimental/networks/mlp/__init__.py @@ -0,0 +1,2 @@ + +from .mlp import MLP diff --git a/bayesflow/experimental/networks/resnet/hidden_block.py b/bayesflow/experimental/networks/mlp/hidden_block.py similarity index 90% rename from bayesflow/experimental/networks/resnet/hidden_block.py rename to bayesflow/experimental/networks/mlp/hidden_block.py index e02f78f6..27c9b549 100644 --- a/bayesflow/experimental/networks/resnet/hidden_block.py +++ b/bayesflow/experimental/networks/mlp/hidden_block.py @@ -6,7 +6,7 @@ from bayesflow.experimental.types import Tensor -@register_keras_serializable(package="bayesflow.networks.resnet") +@register_keras_serializable(package="bayesflow.networks") class ConfigurableHiddenBlock(keras.layers.Layer): def __init__( self, @@ -38,8 +38,8 @@ def call(self, inputs: Tensor, training=False): return self.activation_fn(x) def build(self, input_shape): - super().build(input_shape) - self(keras.KerasTensor(input_shape)) + # build nested layers with forward pass + self.call(keras.KerasTensor(input_shape)) def get_config(self): config = super().get_config() diff --git a/bayesflow/experimental/networks/mlp/mlp.py b/bayesflow/experimental/networks/mlp/mlp.py new file mode 100644 index 00000000..6c34535d --- /dev/null +++ b/bayesflow/experimental/networks/mlp/mlp.py @@ -0,0 +1,80 @@ + +import keras +from keras import layers +from keras.saving import register_keras_serializable + +from bayesflow.experimental.types import Tensor +from .hidden_block import ConfigurableHiddenBlock + + +@register_keras_serializable(package="bayesflow.networks") +class MLP(keras.layers.Layer): + """ + Implements a simple configurable MLP with optional residual connections and dropout. + + If used in conjunction with a coupling net, a diffusion model, or a flow matching model, it assumes + that the input and conditions are already concatenated (i.e., this is a single-input model). + """ + + def __init__( + self, + num_hidden: int = 2, + hidden_dim: int = 256, + activation: str = "mish", + kernel_initializer: str = "he_normal", + residual: bool = True, + dropout: float = 0.05, + spectral_normalization: bool = False, + **kwargs + ): + """ + Creates an instance of a flexible and simple MLP with optional residual connections and dropout. + + Parameters: + ----------- + hidden_dim : int, optional, default: 256 + The dimensionality of the hidden layers + num_hidden : int, optional, default: 2 + The number of hidden layers (minimum: 1) + activation : string, optional, default: 'mish' + The activation function of the dense layers + residual : bool, optional, default: True + Use residual connections in the internal layers. + spectral_normalization : bool, optional, default: False + Use spectral normalization for the network weights, which can make + the learned function smoother and hence more robust to perturbations. + dropout : float, optional, default: 0.05 + Dropout rate for the hidden layers in the internal layers.
+ """ + + super().__init__(**kwargs) + + self.res_blocks = keras.Sequential() + projector = layers.Dense( + units=hidden_dim, + kernel_initializer=kernel_initializer, + ) + if spectral_normalization: + projector = layers.SpectralNormalization(projector) + self.res_blocks.add(projector) + self.res_blocks.add(layers.Dropout(dropout)) + + for _ in range(num_hidden): + self.res_blocks.add( + ConfigurableHiddenBlock( + units=hidden_dim, + activation=activation, + kernel_initializer=kernel_initializer, + residual=residual, + dropout=dropout, + spectral_normalization=spectral_normalization + ) + ) + + def build(self, input_shape): + # build nested layers with forward pass + self.call(keras.KerasTensor(input_shape)) + + def call(self, inputs: Tensor, **kwargs): + return self.res_blocks(inputs, training=kwargs.get("training", False)) + diff --git a/bayesflow/experimental/networks/resnet/resnet.py b/bayesflow/experimental/networks/resnet/resnet.py index ae019505..e8ba7b3b 100644 --- a/bayesflow/experimental/networks/resnet/resnet.py +++ b/bayesflow/experimental/networks/resnet/resnet.py @@ -1,76 +1,29 @@ import keras -from keras import layers from keras.saving import register_keras_serializable from bayesflow.experimental.types import Tensor -from .hidden_block import ConfigurableHiddenBlock -@register_keras_serializable(package="bayesflow.networks.resnet") -class ResNet(keras.layers.Layer): - """ - Implements a simple configurable MLP with optional residual connections and dropout. - - If used in conjunction with a coupling net, a diffusion model, or a flow matching model, it assumes - that the input and conditions are already concatenated (i.e., this is a single-input model). - """ - - def __init__( - self, - num_hidden: int = 2, - hidden_dim: int = 256, - activation: str = "mish", - kernel_initializer: str = "he_normal", - residual: bool = True, - dropout: float = 0.05, - spectral_normalization: bool = False, - **kwargs - ): - """ - Creates an instance of a flexible and simple MLP with optional residual connections and dropout. - - Parameters: - ----------- - hidden_dim : int, optional, default: 256 - The dimensionality of the hidden layers - num_hidden : int, optional, default: 2 - The number of hidden layers (minimum: 1) - activation : string, optional, default: 'gelu' - The activation function of the dense layers - residual : bool, optional, default: True - Use residual connections in the internal layers. - spectral_normalization : bool, optional, default: False - Use spectral normalization for the network weights, which can make - the learned function smoother and hence more robust to perturbations. - dropout : float, optional, default: 0.05 - Dropout rate for the hidden layers in the internal layers. 
- """ - +@register_keras_serializable(package="bayesflow.networks") +class ResNet(keras.Layer): + """ Implements a super-simple ResNet """ + def __init__(self, depth: int = 6, width: int = 2, activation: str = "gelu", **kwargs): super().__init__(**kwargs) - self.res_blocks = keras.Sequential() - projector = layers.Dense( - units=hidden_dim, - kernel_initializer=kernel_initializer, - ) - if spectral_normalization: - projector = layers.SpectralNormalization(projector) - self.res_blocks.add(projector) - self.res_blocks.add(layers.Dropout(dropout)) + self.input_layer = keras.layers.Dense(width) + self.output_layer = keras.layers.Dense(width) + self.hidden_layers = [keras.layers.Dense(width, activation) for _ in range(depth)] + + def build(self, input_shape): + # build nested layers with forward pass + self.call(keras.KerasTensor(input_shape)) - for _ in range(num_hidden): - self.res_blocks.add( - ConfigurableHiddenBlock( - units=hidden_dim, - activation=activation, - kernel_initializer=kernel_initializer, - residual=residual, - dropout=dropout, - spectral_normalization=spectral_normalization - ) - ) + def call(self, x: Tensor, **kwargs) -> Tensor: + x = self.input_layer(x) + for layer in self.hidden_layers: + x = x + layer(x) - def call(self, inputs: Tensor, **kwargs): - return self.res_blocks(inputs, training=kwargs.get("training", False)) + x = x + self.output_layer(x) + return x diff --git a/bayesflow/experimental/networks/summary_network.py b/bayesflow/experimental/networks/summary_network.py index e51bfe83..90a86bd5 100644 --- a/bayesflow/experimental/networks/summary_network.py +++ b/bayesflow/experimental/networks/summary_network.py @@ -5,8 +5,8 @@ class SummaryNetwork(keras.Layer): - def compute_loss(self, observed_variables: Tensor, summary_conditions: Tensor = None, **kwargs) -> Tensor: + def compute_loss(self, summary_outputs: Tensor, **kwargs) -> Tensor: return keras.ops.zeros(()) - def compute_metrics(self, observed_variables: Tensor, summary_conditions: Tensor = None, **kwargs) -> dict: + def compute_metrics(self, summary_variables: Tensor, summary_conditions: Tensor = None, **kwargs) -> dict: return {} diff --git a/bayesflow/experimental/utils/__init__.py b/bayesflow/experimental/utils/__init__.py index 9b2dfa56..6f46f02b 100644 --- a/bayesflow/experimental/utils/__init__.py +++ b/bayesflow/experimental/utils/__init__.py @@ -1,23 +1,8 @@ from .dictutils import nested_getitem, keras_kwargs -from .finders import find_distribution, find_network, find_pooling - -from .computils import ( - expected_calibration_error, - simultaneous_ecdf_bands, - get_coverage_probs +from .dispatch import ( + find_distribution, + find_network, + find_permutation, + find_pooling, ) - -from .plotutils import ( - get_count_and_names, - preprocess, - postprocess, - check_posterior_prior_shapes, - initialize_figure, - collapse_axes, - configure_layout, - add_labels, - add_xlabels, - add_ylabels, - remove_unused_axes -) \ No newline at end of file diff --git a/bayesflow/experimental/utils/dispatch/__init__.py b/bayesflow/experimental/utils/dispatch/__init__.py new file mode 100644 index 00000000..864f6547 --- /dev/null +++ b/bayesflow/experimental/utils/dispatch/__init__.py @@ -0,0 +1,5 @@ + +from .find_distribution import find_distribution +from .find_network import find_network +from .find_permutation import find_permutation +from .find_pooling import find_pooling diff --git a/bayesflow/experimental/utils/dispatch/find_distribution.py b/bayesflow/experimental/utils/dispatch/find_distribution.py new file 
mode 100644 index 00000000..2c87a28a --- /dev/null +++ b/bayesflow/experimental/utils/dispatch/find_distribution.py @@ -0,0 +1,24 @@ + +from functools import singledispatch + + +@singledispatch +def find_distribution(arg, **kwargs): + raise TypeError(f"Cannot infer distribution from {arg!r}.") + + +@find_distribution.register +def _(name: str, **kwargs): + match name.lower(): + case "normal": + from bayesflow.experimental.distributions import DiagonalNormal + distribution = DiagonalNormal(**kwargs) + case other: + raise ValueError(f"Unsupported distribution name: '{other}'.") + + return distribution + + +@find_distribution.register +def _(constructor: type, **kwargs): + return constructor(**kwargs) diff --git a/bayesflow/experimental/utils/dispatch/find_network.py b/bayesflow/experimental/utils/dispatch/find_network.py new file mode 100644 index 00000000..12d58952 --- /dev/null +++ b/bayesflow/experimental/utils/dispatch/find_network.py @@ -0,0 +1,31 @@ + +import keras + +from functools import singledispatch + + +@singledispatch +def find_network(arg, **kwargs): + raise TypeError(f"Cannot infer network from {arg!r}.") + + +@find_network.register +def _(name: str, **kwargs): + match name.lower(): + case "resnet": + from bayesflow.experimental.networks import ResNet + network = ResNet(**kwargs) + case other: + raise ValueError(f"Unsupported network name: '{other}'.") + + return network + + +@find_network.register +def _(network: keras.Layer, **kwargs): + return network + + +@find_network.register +def _(constructor: type, **kwargs): + return constructor(**kwargs) diff --git a/bayesflow/experimental/utils/dispatch/find_permutation.py b/bayesflow/experimental/utils/dispatch/find_permutation.py new file mode 100644 index 00000000..c01864b9 --- /dev/null +++ b/bayesflow/experimental/utils/dispatch/find_permutation.py @@ -0,0 +1,34 @@ + +import keras +from functools import singledispatch + + +@singledispatch +def find_permutation(arg, **kwargs): + raise TypeError(f"Cannot infer permutation from {arg!r}.") + + +@find_permutation.register +def _(name: str, **kwargs): + match name.lower(): + case "random": + from bayesflow.experimental.networks.coupling_flow.permutations import RandomPermutation + return RandomPermutation(**kwargs) + case "swap": + from bayesflow.experimental.networks.coupling_flow.permutations import Swap + return Swap(**kwargs) + case "learnable" | "orthogonal": + from bayesflow.experimental.networks.coupling_flow.permutations import OrthogonalPermutation + return OrthogonalPermutation(**kwargs) + case other: + raise ValueError(f"Unsupported permutation name: '{other}'.") + + +@find_permutation.register +def _(permutation: keras.Layer, **kwargs): + return permutation + + +@find_permutation.register +def _(none: None, **kwargs): + return None diff --git a/bayesflow/experimental/utils/dispatch/find_pooling.py b/bayesflow/experimental/utils/dispatch/find_pooling.py new file mode 100644 index 00000000..398dbbb4 --- /dev/null +++ b/bayesflow/experimental/utils/dispatch/find_pooling.py @@ -0,0 +1,38 @@ + +import keras + +from functools import singledispatch + + +@singledispatch +def find_pooling(arg, **kwargs): + raise TypeError(f"Cannot infer pooling from {arg!r}.") + + +@find_pooling.register +def _(name: str, **kwargs): + match name.lower(): + case "mean" | "avg" | "average": + pooling = keras.layers.Lambda(lambda inp: keras.ops.mean(inp, axis=-2)) + case "max": + pooling = keras.layers.Lambda(lambda inp: keras.ops.max(inp, axis=-2)) + case "min": + pooling = keras.layers.Lambda(lambda inp: keras.ops.min(inp, axis=-2)) + case "learnable" | "pma" | "attention": + from
diff --git a/bayesflow/experimental/utils/dispatch/find_permutation.py b/bayesflow/experimental/utils/dispatch/find_permutation.py
new file mode 100644
index 00000000..c01864b9
--- /dev/null
+++ b/bayesflow/experimental/utils/dispatch/find_permutation.py
@@ -0,0 +1,34 @@
+
+import keras
+from functools import singledispatch
+
+
+@singledispatch
+def find_permutation(arg, **kwargs):
+    raise TypeError(f"Cannot infer permutation from {arg!r}.")
+
+
+@find_permutation.register
+def _(name: str, **kwargs):
+    match name.lower():
+        case "random":
+            from bayesflow.experimental.networks.coupling_flow.permutations import RandomPermutation
+            return RandomPermutation(**kwargs)
+        case "swap":
+            from bayesflow.experimental.networks.coupling_flow.permutations import Swap
+            return Swap(**kwargs)
+        case "learnable" | "orthogonal":
+            from bayesflow.experimental.networks.coupling_flow.permutations import OrthogonalPermutation
+            return OrthogonalPermutation(**kwargs)
+        case other:
+            # raise instead of silently returning None for unknown names
+            raise ValueError(f"Unsupported permutation name: '{other}'.")
+
+
+@find_permutation.register
+def _(permutation: keras.Layer, **kwargs):
+    return permutation
+
+
+@find_permutation.register
+def _(none: None, **kwargs):
+    return None
diff --git a/bayesflow/experimental/utils/dispatch/find_pooling.py b/bayesflow/experimental/utils/dispatch/find_pooling.py
new file mode 100644
index 00000000..398dbbb4
--- /dev/null
+++ b/bayesflow/experimental/utils/dispatch/find_pooling.py
@@ -0,0 +1,38 @@
+
+import keras
+
+from functools import singledispatch
+
+
+@singledispatch
+def find_pooling(arg, **kwargs):
+    raise TypeError(f"Cannot infer pooling from {arg!r}.")
+
+
+@find_pooling.register
+def _(name: str, **kwargs):
+    match name.lower():
+        case "mean" | "avg" | "average":
+            pooling = keras.layers.Lambda(lambda inp: keras.ops.mean(inp, axis=-2))
+        case "max":
+            pooling = keras.layers.Lambda(lambda inp: keras.ops.max(inp, axis=-2))
+        case "min":
+            pooling = keras.layers.Lambda(lambda inp: keras.ops.min(inp, axis=-2))
+        case "learnable" | "pma" | "attention":
+            from bayesflow.experimental.networks.set_transformer.pma import PoolingByMultiheadAttention
+            pooling = PoolingByMultiheadAttention(**kwargs)
+        case other:
+            raise ValueError(f"Unsupported pooling name: '{other}'.")
+
+    return pooling
+
+
+@find_pooling.register
+def _(constructor: type, **kwargs):
+    return constructor(**kwargs)
+
+
+@find_pooling.register
+def _(pooling: keras.Layer, **kwargs):
+    return pooling
diff --git a/bayesflow/experimental/utils/finders.py b/bayesflow/experimental/utils/finders.py
deleted file mode 100644
index ace72f4f..00000000
--- a/bayesflow/experimental/utils/finders.py
+++ /dev/null
@@ -1,65 +0,0 @@
-
-from functools import partial
-
-import keras
-
-
-def find_distribution(distribution: str | type, **kwargs):
-    # TODO -> return type
-    match distribution:
-        case str() as name:
-            match name.lower():
-                case "normal":
-                    from bayesflow.experimental.distributions import DiagonalNormal
-                    distribution = DiagonalNormal(**kwargs)
-                case other:
-                    raise ValueError(f"Unsupported distribution name: '{other}'.")
-        case type() as constructor:
-            distribution = constructor(**kwargs)
-        case other:
-            raise TypeError(f"Cannot infer distribution from {other!r}.")
-
-    return distribution
-
-
-def find_network(network: str | keras.Layer | type, **kwargs) -> keras.Layer:
-    match network:
-        case str() as name:
-            match name.lower():
-                case "resnet":
-                    from bayesflow.experimental.networks import ResNet
-                    network = ResNet(**kwargs)
-                case other:
-                    raise ValueError(f"Unsupported network name: '{other}'.")
-        case keras.Layer() as network:
-            pass
-        case type() as constructor:
-            network = constructor(**kwargs)
-        case other:
-            raise TypeError(f"Cannot infer network from {other!r}.")
-
-    return network
-
-
-def find_pooling(pooling: str | keras.Layer | type, **kwargs) -> keras.Layer:
-    match pooling:
-        case str() as name:
-            match name.lower():
-                case "mean" | "avg":
-                    pooling = keras.layers.Lambda(lambda inp: keras.ops.mean(inp, axis=-2))
-                case "max":
-                    pooling = keras.layers.Lambda(lambda inp: keras.ops.max(inp, axis=-2))
-                case "min":
-                    pooling = keras.layers.Lambda(lambda inp: keras.ops.min(inp, axis=-2))
-                case "learnable" | "pma":
-                    from bayesflow.experimental.networks.set_transformer.pma import PoolingByMultiheadAttention
-                    pooling = PoolingByMultiheadAttention(**kwargs)
-                case other:
-                    raise ValueError(f"Unsupported pooling type: '{other}'.")
-        case keras.Layer() as pooling:
-            pass
-        case type() as constructor:
-            pooling = constructor(**kwargs)
-        case other:
-            raise TypeError(f"Cannot infer pooling type from {other!r}.")
-    return pooling
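The string-dispatched pooling variants reduce the set axis (-2), matching the deleted finders but with the extra "average" and "attention" aliases. A quick shape check, assuming the exports above:

    import keras
    from bayesflow.experimental.utils import find_pooling

    pooling = find_pooling("mean")
    x = keras.ops.ones((2, 5, 3))   # (batch, set, features)
    y = pooling(x)                  # set axis reduced
    print(keras.ops.shape(y))       # (2, 3)
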
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..e55f97a9
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,53 @@
+
+import keras
+import pytest
+
+
+@pytest.fixture()
+def approximator(inference_network, summary_network):
+    # bayesflow.experimental.amortizers is deleted in this PR, so build the
+    # replacement Approximator here instead
+    from bayesflow.experimental import Approximator
+
+    return Approximator(
+        inference_network=inference_network,
+        summary_network=summary_network,
+        inference_variables=[],
+        inference_conditions=[],
+        summary_variables=[],
+        summary_conditions=[],
+    )
+
+
+@pytest.fixture()
+def coupling_flow():
+    from bayesflow.experimental.networks import CouplingFlow
+    return CouplingFlow()
+
+
+@pytest.fixture()
+def flow_matching():
+    from bayesflow.experimental.networks import FlowMatching
+    return FlowMatching()
+
+
+@pytest.fixture(params=["coupling_flow"])
+def inference_network(request):
+    return request.getfixturevalue(request.param)
+
+
+@pytest.fixture(params=["inference_network", "summary_network"])
+def network(request):
+    return request.getfixturevalue(request.param)
+
+
+@pytest.fixture()
+def resnet():
+    from bayesflow.experimental.networks import ResNet
+    return ResNet()
+
+
+@pytest.fixture(params=[None])
+def summary_network(request):
+    if request.param is None:
+        return None
+
+    return request.getfixturevalue(request.param)
diff --git a/tests/test_amortizers/__init__.py b/tests/test_approximators/__init__.py
similarity index 100%
rename from tests/test_amortizers/__init__.py
rename to tests/test_approximators/__init__.py
diff --git a/tests/test_amortizers/conftest.py b/tests/test_approximators/conftest.py
similarity index 77%
rename from tests/test_amortizers/conftest.py
rename to tests/test_approximators/conftest.py
index 09b0e0d4..d5dd6f70 100644
--- a/tests/test_amortizers/conftest.py
+++ b/tests/test_approximators/conftest.py
@@ -19,10 +19,16 @@ def inference_network():
     return network


-@pytest.fixture(params=[bf.AmortizedPosterior, bf.AmortizedLikelihood])
-def amortizer(request, inference_network, summary_network):
-    Amortizer = request.param
-    return Amortizer(inference_network, summary_network)
+@pytest.fixture()
+def approximator(inference_network, summary_network):
+    return bf.Approximator(
+        inference_network=inference_network,
+        summary_network=summary_network,
+        inference_variables=[],
+        inference_conditions=[],
+        summary_variables=[],
+        summary_conditions=[],
+    )


 @pytest.fixture()
diff --git a/tests/test_amortizers/test_fit.py b/tests/test_approximators/test_fit.py
similarity index 100%
rename from tests/test_amortizers/test_fit.py
rename to tests/test_approximators/test_fit.py
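The approximator fixtures reflect the new keyword-only construction: an Approximator is wired from an inference network, an optional summary network, and explicit lists naming which simulated variables play which role. An illustrative stand-alone construction, mirroring the fixture above (the variable names are placeholders):

    import bayesflow.experimental as bf

    approximator = bf.Approximator(
        inference_network=bf.networks.CouplingFlow(),
        summary_network=None,
        inference_variables=["theta"],
        inference_conditions=["x"],
        summary_variables=[],
        summary_conditions=[],
    )
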
diff --git a/tests/test_datasets/conftest.py b/tests/test_datasets/conftest.py
index f882b5a1..8be74c18 100644
--- a/tests/test_datasets/conftest.py
+++ b/tests/test_datasets/conftest.py
@@ -1,33 +1,90 @@
 import keras
-import keras.random
 import pytest

-import bayesflow.experimental as bf
+
+@pytest.fixture()
+def batch_size():
+    return 16

-# TODO: do this last when the implementation of multiprocessing dataloading pipelines is done
+@pytest.fixture(params=["online_dataset", "offline_dataset"])
+def dataset(request, online_dataset, offline_dataset):
+    return request.getfixturevalue(request.param)


 @pytest.fixture()
-def joint_distribution():
-    class JointDistribution:
-        def sample(self, batch_shape):
-            return dict(x=keras.random.normal(batch_shape + (2,)))
+def model():
+    class Model(keras.Model):
+        def call(self, *args, **kwargs):
+            pass
+
+        def compute_loss(self, **kwargs):
+            return keras.ops.zeros(())

-        def log_prob(self, x):
-            raise NotImplementedError
+    model = Model()
+    model.compile(optimizer=None)

-    return JointDistribution()
+    return model


 @pytest.fixture()
-def online_dataset(joint_distribution):
-    return bf.datasets.OnlineDataset(joint_distribution, batch_size=1)
+def offline_dataset(simulator, batch_size, workers, use_multiprocessing):
+    from bayesflow.experimental import OfflineDataset
+
+    # TODO: there is a bug in keras where if len(dataset) == 1 batch
+    #  fit will error because no logs are generated
+    #  the single batch is then skipped entirely
+    data = simulator.sample((batch_size * 2,))
+    return OfflineDataset(data, batch_size=batch_size, workers=workers, use_multiprocessing=use_multiprocessing)
+
+
+@pytest.fixture()
+def online_dataset(simulator, batch_size, workers, use_multiprocessing):
+    from bayesflow.experimental import OnlineDataset
+
+    return OnlineDataset(simulator, batch_size=batch_size, workers=workers, use_multiprocessing=use_multiprocessing)
+
+
+# needs to be global for pickle to work
+
+from bayesflow.experimental.simulation.decorators.distribution_decorator import DistributionDecorator as make_distribution
+
+
+class Simulator:
+    def sample(self, batch_shape):
+        return dict(x=keras.random.normal(batch_shape + (2,)))
+
+
+@make_distribution(is_batched=True)
+def batched_decorated_simulator(batch_shape):
+    return dict(x=keras.random.normal(batch_shape + (2,)))
+
+
+@make_distribution(is_batched=False)
+def unbatched_decorated_simulator():
+    return dict(x=keras.random.normal((2,)))
+
+
+@pytest.fixture(params=["class", "batched_decorator", "unbatched_decorator"])
+def simulator(request):
+    if request.param == "class":
+        simulator = Simulator()
+    elif request.param == "batched_decorator":
+        simulator = batched_decorated_simulator
+    elif request.param == "unbatched_decorator":
+        simulator = unbatched_decorated_simulator
+    else:
+        raise NotImplementedError
+
+    return simulator
+
+
+@pytest.fixture(params=[True, False])
+def use_multiprocessing(request):
+    return request.param


-# @pytest.fixture()
-# def offline_dataset(tmp_path, joint_distribution):
-#     samples = joint_distribution.sample((32,))
-#     ...  # ?
-#     return bf.datasets.OfflineDataset(joint_distribution, batch_size=1, batches_per_epoch=None)
+@pytest.fixture(params=[1, 2])
+def workers(request):
+    return request.param
diff --git a/tests/test_datasets/test_datasets.py b/tests/test_datasets/test_datasets.py
new file mode 100644
index 00000000..f1fe16bc
--- /dev/null
+++ b/tests/test_datasets/test_datasets.py
@@ -0,0 +1,35 @@
+
+import keras
+import pickle
+import pytest
+
+
+@pytest.mark.skip(reason="WIP")
+def test_dataset_is_picklable(dataset):
+    pickled = pickle.loads(pickle.dumps(dataset))
+
+    assert type(pickled) is type(dataset)
+
+    samples = dataset[0]  # tuple of (x, y)
+    samples = samples[0]  # dict of {param_name: param_value}
+    samples = next(iter(samples.values()))  # first param value
+
+    pickled_samples = pickled[0]
+    pickled_samples = pickled_samples[0]
+    pickled_samples = next(iter(pickled_samples.values()))
+
+    assert keras.ops.shape(samples) == keras.ops.shape(pickled_samples)
+
+
+@pytest.mark.skip(reason="WIP")
+def test_dataset_works_in_fit(model, dataset):
+    model.fit(dataset, epochs=1, steps_per_epoch=1)
+
+
+@pytest.mark.skip(reason="WIP")
+def test_dataset_returns_batch(dataset, batch_size):
+    samples = dataset[0]  # tuple of (x, y)
+    samples = samples[0]  # dict of {param_name: param_value}
+    samples = next(iter(samples.values()))  # first param value
+
+    assert keras.ops.shape(samples)[0] == batch_size
diff --git a/tests/test_datasets/test_online_dataset.py b/tests/test_datasets/test_online_dataset.py
deleted file mode 100644
index b6dbda7f..00000000
--- a/tests/test_datasets/test_online_dataset.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# TODO
-
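As the comments in test_datasets.py spell out, indexing a dataset yields a keras-style (x, y) tuple whose x is a dict of named tensors with a leading batch axis. A sketch of that contract, reusing the module-level Simulator class from the conftest above:

    import keras
    from bayesflow.experimental import OnlineDataset

    class Simulator:
        def sample(self, batch_shape):
            return dict(x=keras.random.normal(batch_shape + (2,)))

    dataset = OnlineDataset(Simulator(), batch_size=16, workers=1, use_multiprocessing=False)
    x, y = dataset[0]                       # (inputs, targets) tuple
    batch = next(iter(x.values()))          # first named variable
    assert keras.ops.shape(batch)[0] == 16  # leading axis is the batch size
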
diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py
index bab450e4..ede8c1ef 100644
--- a/tests/test_networks/test_inference_networks.py
+++ b/tests/test_networks/test_inference_networks.py
@@ -5,27 +5,13 @@
 import numpy as np
 import pytest

-<<<<<<< HEAD
-from tests.utils import assert_layers_equal
-
-
-@pytest.mark.parametrize("automatic", [True, False])
-def test_build(automatic, inference_network, random_samples):
-    assert inference_network.built is False
-
-    if automatic:
-        inference_network(random_samples)
-    else:
-        inference_network.build(keras.ops.shape(random_samples))
-=======
-from tests.utils import allclose, assert_models_equal
+from tests.utils import allclose, assert_layers_equal


 def test_build(inference_network, random_samples, random_conditions):
     assert inference_network.built is False

     inference_network(random_samples, conditions=random_conditions)
->>>>>>> streamlined-backend

     assert inference_network.built is True
@@ -33,93 +19,20 @@
     assert inference_network.variables, "Model has no variables."


-<<<<<<< HEAD
-def test_variable_batch_size(inference_network, random_samples):
-    # build with one batch size
-    inference_network.build(keras.ops.shape(random_samples))
-
-    # run with another batch size
-    for _ in range(10):
-        batch_size = np.random.randint(1, 10)
-        new_input = keras.ops.zeros((batch_size,) + keras.ops.shape(random_samples)[1:])
-        inference_network(new_input)
-        inference_network(new_input, inverse=True)
-
-
-def test_output_structure(inference_network, random_input):
-    output = inference_network(random_input)
-
-    assert isinstance(output, tuple)
-    assert len(output) == 2
-
-    forward_output, forward_log_det = output
-
-    assert keras.ops.is_tensor(forward_output)
-    assert keras.ops.is_tensor(forward_log_det)
-
-
-def test_output_shape(inference_network, random_input):
-    forward_output, forward_log_det = inference_network(random_input)
-
-    assert keras.ops.shape(forward_output) == keras.ops.shape(random_input)
-    assert keras.ops.shape(forward_log_det) == (keras.ops.shape(random_input)[0],)
-
-    inverse_output, inverse_log_det = inference_network(random_input, inverse=True)
-
-    assert keras.ops.shape(inverse_output) == keras.ops.shape(random_input)
-    assert keras.ops.shape(inverse_log_det) == (keras.ops.shape(random_input)[0],)
-
-
-def test_cycle_consistency(inference_network, random_samples):
-    # cycle-consistency means the forward and inverse methods are inverses of each other
-    forward_output, forward_log_det = inference_network(random_samples, jacobian=True)
-    inverse_output, inverse_log_det = inference_network(forward_output, inverse=True, jacobian=True)
-
-    assert keras.ops.all(keras.ops.isclose(random_samples, inverse_output))
-    assert keras.ops.all(keras.ops.isclose(forward_log_det, -inverse_log_det))
-
-
-@pytest.mark.torch
-def test_jacobian_numerically(inference_network, random_input):
-    import torch
-
-    forward_output, forward_log_det = inference_network(random_input, jacobian=True)
-    numerical_forward_jacobian, _ = torch.autograd.functional.jacobian(inference_network, random_input, vectorize=True)
-
-    # TODO: torch is somehow permuted wrt keras
-    numerical_forward_log_det = [keras.ops.log(keras.ops.abs(keras.ops.det(numerical_forward_jacobian[i, :, i, :]))) for i in range(keras.ops.shape(random_input)[0])]
-    numerical_forward_log_det = keras.ops.stack(numerical_forward_log_det, axis=0)
-
-    assert keras.ops.all(keras.ops.isclose(forward_log_det, numerical_forward_log_det))
-
-    inverse_output, inverse_log_det = inference_network(random_input, inverse=True, jacobian=True)
-
-    numerical_inverse_jacobian, _ = torch.autograd.functional.jacobian(functools.partial(inference_network, inverse=True), random_input, vectorize=True)
-
-    # TODO: torch is somehow permuted wrt keras
-    numerical_inverse_log_det = [keras.ops.log(keras.ops.abs(keras.ops.det(numerical_inverse_jacobian[i, :, i, :]))) for i in range(keras.ops.shape(random_input)[0])]
-    numerical_inverse_log_det = keras.ops.stack(numerical_inverse_log_det, axis=0)
-
-    assert keras.ops.all(keras.ops.isclose(inverse_log_det, numerical_inverse_log_det))
-
-
-def test_serialize_deserialize(tmp_path, inference_network, random_samples):
-    inference_network.build(keras.ops.shape(random_samples))
-=======
 def test_variable_batch_size(inference_network, random_samples, random_conditions):
     # build with one batch size
     inference_network(random_samples, conditions=random_conditions)

     # run with another batch size
     batch_sizes = np.random.choice(10, replace=False, size=3)
-    for batch_size in batch_sizes:
-        new_input = keras.ops.zeros((batch_size,) + keras.ops.shape(random_samples)[1:])
+    for bs in batch_sizes:
+        new_input = keras.ops.zeros((bs,) + keras.ops.shape(random_samples)[1:])
         if random_conditions is None:
             new_conditions = None
         else:
-            new_conditions = keras.ops.zeros((batch_size,), + keras.ops.shape(random_conditions)[1:])
+            new_conditions = keras.ops.zeros((bs,) + keras.ops.shape(random_conditions)[1:])

-        inference_network(new_input)
+        inference_network(new_input, conditions=new_conditions)
         inference_network(new_input, conditions=new_conditions, inverse=True)
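Both the deleted HEAD branch of this file and the kept streamlined tests exercise the same core property: forward and inverse passes are mutual inverses, and their log-determinants cancel. A condensed sketch of that invariant, assuming the jacobian/inverse flags used by the tests above:

    import keras
    from bayesflow.experimental.networks import CouplingFlow

    net = CouplingFlow()
    x = keras.random.normal((4, 2))
    z, log_det = net(x, jacobian=True)
    x_rec, inv_log_det = net(z, inverse=True, jacobian=True)
    assert keras.ops.all(keras.ops.isclose(x, x_rec))
    assert keras.ops.all(keras.ops.isclose(log_det, -inv_log_det))
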
@@ -194,13 +107,8 @@ def f(x):
 def test_serialize_deserialize(tmp_path, inference_network, random_samples, random_conditions):
     # to save, the model must be built
     inference_network(random_samples, conditions=random_conditions)
->>>>>>> streamlined-backend

     keras.saving.save_model(inference_network, tmp_path / "model.keras")
     loaded = keras.saving.load_model(tmp_path / "model.keras")

-<<<<<<< HEAD
     assert_layers_equal(inference_network, loaded)
-=======
-    assert_models_equal(inference_network, loaded)
->>>>>>> streamlined-backend
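test_serialize_deserialize relies on the keras rule that a model must be built (for example by one forward pass) before keras.saving.save_model will accept it. The round trip in miniature (the path and shapes are illustrative):

    import keras
    from bayesflow.experimental.networks import CouplingFlow

    net = CouplingFlow()
    net(keras.random.normal((4, 2)))              # build by calling once
    keras.saving.save_model(net, "model.keras")
    loaded = keras.saving.load_model("model.keras")
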
diff --git a/tests/test_two_moons/conftest.py b/tests/test_two_moons/conftest.py
index 40bf110d..c5badfc9 100644
--- a/tests/test_two_moons/conftest.py
+++ b/tests/test_two_moons/conftest.py
@@ -1,71 +1,52 @@
 import math

+import keras
 import pytest
-from keras import ops as K
-from keras import random as R

 import bayesflow.experimental as bf


 @pytest.fixture()
-def context():
-    class ContextPrior:
-        def sample(self, batch_shape):
-            r = R.normal(shape=batch_shape + (1,), mean=0.1, stddev=0.01)
-            alpha = R.uniform(shape=batch_shape + (1,), minval=-0.5 * math.pi, maxval=0.5 * math.pi)
-
-            return dict(r=r, alpha=alpha)
+def batch_size():
+    return 32


 @pytest.fixture()
-def prior():
-    class Prior:
+def simulator():
+    class Simulator:
         def sample(self, batch_shape):
-            theta = R.uniform(shape=batch_shape + (2,), minval=-1.0, maxval=1.0)
-
-            return dict(theta=theta)
-
-    return Prior()
-
-
-@pytest.fixture()
-def likelihood():
-    class Likelihood:
-        def sample(self, batch_shape, r, alpha, theta):
-            x1 = -K.abs(theta[0] + theta[1]) / K.sqrt(2.0) + r * K.cos(alpha) + 0.25
-            x2 = (-theta[0] + theta[1]) / K.sqrt(2.0) + r * K.sin(alpha)
-            return dict(x=K.stack([x1, x2], axis=-1))
-
-    return Likelihood()
-
-
-@pytest.fixture()
-def joint_distribution(context, prior, likelihood):
-    return bf.simulation.JointDistribution(context, prior, likelihood)
+            r = keras.random.normal(shape=batch_shape + (1,), mean=0.1, stddev=0.01)
+            alpha = keras.random.uniform(shape=batch_shape + (1,), minval=-0.5 * math.pi, maxval=0.5 * math.pi)
+            theta = keras.random.uniform(shape=batch_shape + (2,), minval=-1.0, maxval=1.0)
+
+            # index the parameter axis, not the batch axis
+            x1 = -keras.ops.abs(theta[..., :1] + theta[..., 1:]) / keras.ops.sqrt(2.0) + r * keras.ops.cos(alpha) + 0.25
+            x2 = (-theta[..., :1] + theta[..., 1:]) / keras.ops.sqrt(2.0) + r * keras.ops.sin(alpha)
+
+            x = keras.ops.concatenate([x1, x2], axis=-1)
+
+            return dict(r=r, alpha=alpha, theta=theta, x=x)
+
+    return Simulator()


 @pytest.fixture()
-def dataset(joint_distribution):
-    # TODO: do not use hard-coded batch size
-    return bf.datasets.OnlineDataset(joint_distribution, workers=4, use_multiprocessing=True, max_queue_size=16, batch_size=16)
+def dataset(simulator, batch_size):
+    return bf.datasets.OnlineDataset(simulator, workers=4, use_multiprocessing=True, max_queue_size=16, batch_size=batch_size)


 @pytest.fixture()
 def inference_network():
-    return bf.networks.CouplingFlow.all_in_one(
-        subnet_builder="default",
-        target_dim=2,
-        num_layers=2,
-        transform="affine",
-        base_distribution="normal",
-    )
-
-
-@pytest.fixture()
-def summary_network():
-    return None
+    return bf.networks.CouplingFlow()


 @pytest.fixture()
-def amortizer(inference_network, summary_network):
-    return bf.AmortizedPosterior(inference_network, summary_network)
+def approximator(inference_network):
+    return bf.Approximator(
+        inference_network=inference_network,
+        inference_variables=["theta"],
+        inference_conditions=["x", "r", "alpha"],
+        summary_network=None,
+        summary_variables=[],
+        summary_conditions=[],
+    )
diff --git a/tests/test_two_moons/test_fit.py b/tests/test_two_moons/test_fit.py
index 9c0d4397..d0687127 100644
--- a/tests/test_two_moons/test_fit.py
+++ b/tests/test_two_moons/test_fit.py
@@ -14,7 +14,7 @@ def test_compile(approximator):
-def test_fit(amortizer, dataset):
+def test_fit(approximator, dataset):
     # TODO: verify the model learns something by comparing a metric before and after training
-    amortizer.compile(optimizer="AdamW")
-    amortizer.fit(dataset, epochs=10, steps_per_epoch=10, batch_size=32)
+    approximator.compile(optimizer="AdamW")
+    approximator.fit(dataset, epochs=10, steps_per_epoch=10)


 @pytest.mark.skip(reason="not implemented")
diff --git a/tests/utils/assertions.py b/tests/utils/assertions.py
index ba6313eb..b121bac8 100644
--- a/tests/utils/assertions.py
+++ b/tests/utils/assertions.py
@@ -2,12 +2,6 @@
 import keras


-<<<<<<< HEAD
-def assert_layers_equal(layer1: keras.Layer, layer2: keras.Layer):
-    assert layer1.variables, "Layer has no variables."
-    for v1, v2 in zip(layer1.variables, layer2.variables):
-        assert keras.ops.all(keras.ops.isclose(v1, v2))
-=======
 def assert_models_equal(model1: keras.Model, model2: keras.Model):
     assert isinstance(model1, keras.Model)
     assert isinstance(model2, keras.Model)
@@ -20,15 +14,16 @@ def assert_models_equal(model1: keras.Model, model2: keras.Model):


 def assert_layers_equal(layer1: keras.Layer, layer2: keras.Layer):
+    assert layer1.name == layer2.name
     assert len(layer1.variables) == len(layer2.variables), f"Layers {layer1.name} and {layer2.name} have a different number of variables ({len(layer1.variables)}, {len(layer2.variables)})."
     assert len(layer1.variables) > 0, f"Layers {layer1.name} and {layer2.name} have no variables."
     for v1, v2 in zip(layer1.variables, layer2.variables):
+        assert v1.name == v2.name
-        if v1.name == "seed_generator_state" and v1.name == v2.name:
+        if v1.name == "seed_generator_state":
             # keras issue: https://github.com/keras-team/keras/issues/19796
             continue
-        v1 = keras.ops.convert_to_numpy(v1)
-        v2 = keras.ops.convert_to_numpy(v2)
-        assert keras.ops.all(keras.ops.isclose(v1, v2)), f"Variables for {layer1.name} and {layer2.name} are not equal: {v1} != {v2}"
->>>>>>> streamlined-backend
+        x1 = keras.ops.convert_to_numpy(v1)
+        x2 = keras.ops.convert_to_numpy(v2)
+        assert keras.ops.all(keras.ops.isclose(x1, x2)), f"Variable '{v1.name}' for Layer '{layer1.name}' is not equal: {x1} != {x2}"
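Taken together, the two-moons fixtures compose into a short end-to-end smoke test. An illustrative, self-contained version of that flow, mirroring the fixtures above and using only names this diff introduces:

    import math
    import keras
    import bayesflow.experimental as bf

    class Simulator:
        def sample(self, batch_shape):
            r = keras.random.normal(shape=batch_shape + (1,), mean=0.1, stddev=0.01)
            alpha = keras.random.uniform(shape=batch_shape + (1,), minval=-0.5 * math.pi, maxval=0.5 * math.pi)
            theta = keras.random.uniform(shape=batch_shape + (2,), minval=-1.0, maxval=1.0)
            x1 = -keras.ops.abs(theta[..., :1] + theta[..., 1:]) / keras.ops.sqrt(2.0) + r * keras.ops.cos(alpha) + 0.25
            x2 = (-theta[..., :1] + theta[..., 1:]) / keras.ops.sqrt(2.0) + r * keras.ops.sin(alpha)
            return dict(r=r, alpha=alpha, theta=theta, x=keras.ops.concatenate([x1, x2], axis=-1))

    dataset = bf.datasets.OnlineDataset(Simulator(), batch_size=32, workers=1, use_multiprocessing=False)
    approximator = bf.Approximator(
        inference_network=bf.networks.CouplingFlow(),
        inference_variables=["theta"],
        inference_conditions=["x", "r", "alpha"],
        summary_network=None,
        summary_variables=[],
        summary_conditions=[],
    )
    approximator.compile(optimizer="AdamW")
    approximator.fit(dataset, epochs=1, steps_per_epoch=10)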