Skip to content

Commit

Permalink
backend-specific approximators template
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Jun 6, 2024
1 parent 4fbda46 commit 3da7be8
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 0 deletions.
2 changes: 2 additions & 0 deletions bayesflow/experimental/backend_approximators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

from .approximator import Approximator
14 changes: 14 additions & 0 deletions bayesflow/experimental/backend_approximators/approximator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

import keras

match keras.backend.backend():
case "jax":
from .jax_approximator import JAXApproximator as Approximator
case "numpy":
from .numpy_approximator import NumpyApproximator as Approximator
case "tensorflow":
from .tensorflow_approximator import TensorFlowApproximator as Approximator
case "torch":
from .torch_approximator import TorchApproximator as Approximator
case other:
raise NotImplementedError(f"BayesFlow does not currently support backend '{other}'.")
16 changes: 16 additions & 0 deletions bayesflow/experimental/backend_approximators/base_approximator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

import keras

from bayesflow.experimental.types import Tensor


class BaseApproximator(keras.Model):
def train_step(self, data):
raise NotImplementedError

# noinspection PyMethodOverriding
def compute_metrics(self, data: dict[str, Tensor], mode: str = "training") -> Tensor:
raise NotImplementedError

def compute_loss(self, *args, **kwargs):
raise NotImplementedError(f"Use compute_metrics()['loss'] instead.")
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

import jax

from .base_approximator import BaseApproximator


class JAXApproximator(BaseApproximator):
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

import numpy as np

from .base_approximator import BaseApproximator


class NumpyApproximator(BaseApproximator):
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

import tensorflow as tf

from .base_approximator import BaseApproximator


class TensorFlowApproximator(BaseApproximator):
pass
19 changes: 19 additions & 0 deletions bayesflow/experimental/backend_approximators/torch_approximator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

import torch


from .base_approximator import BaseApproximator


class TorchApproximator(BaseApproximator):
def train_step(self, data):
with torch.enable_grad():
metrics = self.compute_metrics(data, mode="training")

loss = metrics.pop("loss")

self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

return metrics

0 comments on commit 3da7be8

Please sign in to comment.