diff --git a/black_it/loss_functions/msm.py b/black_it/loss_functions/msm.py index 710e0631..75995cc6 100644 --- a/black_it/loss_functions/msm.py +++ b/black_it/loss_functions/msm.py @@ -179,10 +179,10 @@ def compute_loss_1d( g = real_mom_1d - sim_mom_1d - if self._covariance_mat == "identity": + if self._covariance_mat == _CovarianceMatrixType.IDENTITY.value: loss_1d = g.dot(g) return loss_1d - if self._covariance_mat == "inverse_variance": + if self._covariance_mat == _CovarianceMatrixType.INVERSE_VARIANCE.value: W = np.diag( 1.0 / np.mean((real_mom_1d[None, :] - ensemble_sim_mom_1d) ** 2, axis=0) )