Skip to content

Commit

Permalink
initial version of RMSE metric
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Sep 23, 2024
1 parent ea0d049 commit 21759bb
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 0 deletions.
1 change: 1 addition & 0 deletions bayesflow/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .maximum_mean_discrepancy import MaximumMeanDiscrepancy
from .root_mean_squard_error import RootMeanSquaredError
1 change: 1 addition & 0 deletions bayesflow/metrics/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .maximum_mean_discrepancy import maximum_mean_discrepancy
from .root_mean_squared_error import root_mean_squared_error
33 changes: 33 additions & 0 deletions bayesflow/metrics/functional/root_mean_squared_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import keras
from keras import ops

from bayesflow.types import Tensor


def root_mean_squared_error(x1: Tensor, x2: Tensor, normalize: bool = False, **kwargs) -> Tensor:
"""Computes the (normalized) root mean squared error between samples x1 and x2.
:param x1: Tensor of shape (n, ...)
:param x2: Tensor of shape (n, ...)
:param normalize: Normalize the RMSE?
:param kwargs: Currently ignored
:return: Tensor of shape (n,)
The RMSE between x1 and x2 over all remaining dimensions.
"""

if keras.ops.shape(x1) != keras.ops.shape(x2):
raise ValueError(
f"Expected x1 and x2 to have the same dimensions, "
f"but got {keras.ops.shape(x1)} != {keras.ops.shape(x2)}."
)

# use flattened versions
x1 = keras.ops.reshape(x1, (keras.ops.shape(x1)[0], -1))
x2 = keras.ops.reshape(x2, (keras.ops.shape(x2)[0], -1))

# TODO: how to normalize the RMSE?
return ops.sqrt(ops.mean(ops.square(x1 - x2), axis=1))
11 changes: 11 additions & 0 deletions bayesflow/metrics/root_mean_squard_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from functools import partial
import keras


from .functional import maximum_mean_discrepancy


class MaximumMeanDiscrepancy(keras.metrics.MeanMetricWrapper):
def __init__(self, name="maximum_mean_discrepancy", dtype=None, **kwargs):
fn = partial(maximum_mean_discrepancy, **kwargs)
super().__init__(fn, name=name, dtype=dtype)

0 comments on commit 21759bb

Please sign in to comment.