Commit 21759bb (1 parent: ea0d049)
Showing 4 changed files with 46 additions and 0 deletions.
@@ -1 +1,2 @@
from .maximum_mean_discrepancy import MaximumMeanDiscrepancy
from .root_mean_squared_error import RootMeanSquaredError
@@ -1 +1,2 @@
from .maximum_mean_discrepancy import maximum_mean_discrepancy
from .root_mean_squared_error import root_mean_squared_error
@@ -0,0 +1,33 @@
import keras
from keras import ops

from bayesflow.types import Tensor


def root_mean_squared_error(x1: Tensor, x2: Tensor, normalize: bool = False, **kwargs) -> Tensor:
    """Computes the (normalized) root mean squared error between samples x1 and x2.

    :param x1: Tensor of shape (n, ...)
    :param x2: Tensor of shape (n, ...)
    :param normalize: Normalize the RMSE?
    :param kwargs: Currently ignored
    :return: Tensor of shape (n,)
        The RMSE between x1 and x2 over all remaining dimensions.
    """

    if keras.ops.shape(x1) != keras.ops.shape(x2):
        raise ValueError(
            f"Expected x1 and x2 to have the same dimensions, "
            f"but got {keras.ops.shape(x1)} != {keras.ops.shape(x2)}."
        )

    # use flattened versions
    x1 = keras.ops.reshape(x1, (keras.ops.shape(x1)[0], -1))
    x2 = keras.ops.reshape(x2, (keras.ops.shape(x2)[0], -1))

    # TODO: how to normalize the RMSE?
    return ops.sqrt(ops.mean(ops.square(x1 - x2), axis=1))
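For reference, a minimal usage sketch of the functional metric above. It assumes the function is importable as bayesflow.metrics.functional.root_mean_squared_error (a path inferred from the relative imports in this diff, not shown explicitly here) and that a Keras backend is configured:

import numpy as np
from bayesflow.metrics.functional import root_mean_squared_error  # assumed import path

# two batches of 8 samples, each sample an array of shape (5, 3)
x1 = np.random.normal(size=(8, 5, 3)).astype("float32")
x2 = np.random.normal(size=(8, 5, 3)).astype("float32")

rmse = root_mean_squared_error(x1, x2)
print(rmse.shape)  # (8,) - one RMSE per sample, computed over the flattened trailing dimensions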
@@ -0,0 +1,11 @@
from functools import partial
import keras

from .functional import maximum_mean_discrepancy


class MaximumMeanDiscrepancy(keras.metrics.MeanMetricWrapper):
    def __init__(self, name="maximum_mean_discrepancy", dtype=None, **kwargs):
        fn = partial(maximum_mean_discrepancy, **kwargs)
        super().__init__(fn, name=name, dtype=dtype)
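The __init__.py in the first hunk also imports a RootMeanSquaredError class whose file is not shown in this excerpt. Following the MaximumMeanDiscrepancy wrapper above, it would presumably be the same thin keras.metrics.MeanMetricWrapper around the functional metric; the sketch below illustrates that pattern and is not the commit's actual file:

from functools import partial
import keras

from .functional import root_mean_squared_error


class RootMeanSquaredError(keras.metrics.MeanMetricWrapper):
    # Sketch only: mirrors the MaximumMeanDiscrepancy wrapper above.
    def __init__(self, name="root_mean_squared_error", dtype=None, **kwargs):
        # Extra keyword arguments (e.g. normalize=True) are bound to the functional metric.
        fn = partial(root_mean_squared_error, **kwargs)
        super().__init__(fn, name=name, dtype=dtype)

Wrapped this way, the metric behaves like any other stateful Keras metric: it is updated batch by batch with update_state(x1, x2) and read out with result().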