From 3da7be800f93032a9f96b229d9e595941f4311f4 Mon Sep 17 00:00:00 2001 From: lars Date: Thu, 6 Jun 2024 13:43:25 +0200 Subject: [PATCH] backend-specific approximators template --- .../backend_approximators/__init__.py | 2 ++ .../backend_approximators/approximator.py | 14 ++++++++++++++ .../base_approximator.py | 16 ++++++++++++++++ .../backend_approximators/jax_approximator.py | 8 ++++++++ .../numpy_approximator.py | 8 ++++++++ .../tensorflow_approximator.py | 8 ++++++++ .../torch_approximator.py | 19 +++++++++++++++++++ 7 files changed, 75 insertions(+) create mode 100644 bayesflow/experimental/backend_approximators/__init__.py create mode 100644 bayesflow/experimental/backend_approximators/approximator.py create mode 100644 bayesflow/experimental/backend_approximators/base_approximator.py create mode 100644 bayesflow/experimental/backend_approximators/jax_approximator.py create mode 100644 bayesflow/experimental/backend_approximators/numpy_approximator.py create mode 100644 bayesflow/experimental/backend_approximators/tensorflow_approximator.py create mode 100644 bayesflow/experimental/backend_approximators/torch_approximator.py diff --git a/bayesflow/experimental/backend_approximators/__init__.py b/bayesflow/experimental/backend_approximators/__init__.py new file mode 100644 index 00000000..3e50858b --- /dev/null +++ b/bayesflow/experimental/backend_approximators/__init__.py @@ -0,0 +1,2 @@ + +from .approximator import Approximator diff --git a/bayesflow/experimental/backend_approximators/approximator.py b/bayesflow/experimental/backend_approximators/approximator.py new file mode 100644 index 00000000..76112757 --- /dev/null +++ b/bayesflow/experimental/backend_approximators/approximator.py @@ -0,0 +1,14 @@ + +import keras + +match keras.backend.backend(): + case "jax": + from .jax_approximator import JAXApproximator as Approximator + case "numpy": + from .numpy_approximator import NumpyApproximator as Approximator + case "tensorflow": + from .tensorflow_approximator import TensorFlowApproximator as Approximator + case "torch": + from .torch_approximator import TorchApproximator as Approximator + case other: + raise NotImplementedError(f"BayesFlow does not currently support backend '{other}'.") diff --git a/bayesflow/experimental/backend_approximators/base_approximator.py b/bayesflow/experimental/backend_approximators/base_approximator.py new file mode 100644 index 00000000..e17aa9ba --- /dev/null +++ b/bayesflow/experimental/backend_approximators/base_approximator.py @@ -0,0 +1,16 @@ + +import keras + +from bayesflow.experimental.types import Tensor + + +class BaseApproximator(keras.Model): + def train_step(self, data): + raise NotImplementedError + + # noinspection PyMethodOverriding + def compute_metrics(self, data: dict[str, Tensor], mode: str = "training") -> Tensor: + raise NotImplementedError + + def compute_loss(self, *args, **kwargs): + raise NotImplementedError(f"Use compute_metrics()['loss'] instead.") diff --git a/bayesflow/experimental/backend_approximators/jax_approximator.py b/bayesflow/experimental/backend_approximators/jax_approximator.py new file mode 100644 index 00000000..376f597b --- /dev/null +++ b/bayesflow/experimental/backend_approximators/jax_approximator.py @@ -0,0 +1,8 @@ + +import jax + +from .base_approximator import BaseApproximator + + +class JAXApproximator(BaseApproximator): + pass diff --git a/bayesflow/experimental/backend_approximators/numpy_approximator.py b/bayesflow/experimental/backend_approximators/numpy_approximator.py new file mode 100644 index 00000000..ebedd52c --- /dev/null +++ b/bayesflow/experimental/backend_approximators/numpy_approximator.py @@ -0,0 +1,8 @@ + +import numpy as np + +from .base_approximator import BaseApproximator + + +class NumpyApproximator(BaseApproximator): + pass diff --git a/bayesflow/experimental/backend_approximators/tensorflow_approximator.py b/bayesflow/experimental/backend_approximators/tensorflow_approximator.py new file mode 100644 index 00000000..f6d10d41 --- /dev/null +++ b/bayesflow/experimental/backend_approximators/tensorflow_approximator.py @@ -0,0 +1,8 @@ + +import tensorflow as tf + +from .base_approximator import BaseApproximator + + +class TensorFlowApproximator(BaseApproximator): + pass diff --git a/bayesflow/experimental/backend_approximators/torch_approximator.py b/bayesflow/experimental/backend_approximators/torch_approximator.py new file mode 100644 index 00000000..78d1729d --- /dev/null +++ b/bayesflow/experimental/backend_approximators/torch_approximator.py @@ -0,0 +1,19 @@ + +import torch + + +from .base_approximator import BaseApproximator + + +class TorchApproximator(BaseApproximator): + def train_step(self, data): + with torch.enable_grad(): + metrics = self.compute_metrics(data, mode="training") + + loss = metrics.pop("loss") + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + return metrics