diff --git a/src/ott/solvers/was_solver.py b/src/ott/solvers/was_solver.py index 573038033..36f94436f 100644 --- a/src/ott/solvers/was_solver.py +++ b/src/ott/solvers/was_solver.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union import jax @@ -49,10 +50,12 @@ def __init__( self.epsilon = epsilon if epsilon is not None else default_epsilon self.rank = rank self.linear_ot_solver = linear_ot_solver + used_kwargs = {} if self.linear_ot_solver is None: # Detect if user requests low-rank solver. In that case the # default_epsilon makes little sense, since it was designed for GW. if self.is_low_rank: + used_kwargs = dict(inspect.signature(sinkhorn_lr.LRSinkhorn).parameters) if epsilon is None: # Use default entropic regularization in LRSinkhorn if None was passed self.linear_ot_solver = sinkhorn_lr.LRSinkhorn( @@ -64,6 +67,7 @@ def __init__( rank=self.rank, epsilon=self.epsilon, **kwargs ) else: + used_kwargs = dict(inspect.signature(sinkhorn.Sinkhorn).parameters) # When using Entropic GW, epsilon is not handled inside Sinkhorn, # but rather added back to the Geometry object re-instantiated # when linearizing the problem. Therefore, no need to pass it to solver. @@ -73,6 +77,10 @@ def __init__( self.max_iterations = max_iterations self.threshold = threshold self.store_inner_errors = store_inner_errors + # assert that all kwargs are valid + if not set(kwargs.keys()).issubset(used_kwargs.keys()): + unrecognized_kwargs = set(kwargs.keys()) - set(used_kwargs.keys()) + raise TypeError(f"Invalid keyword arguments: {unrecognized_kwargs}.") self._kwargs = kwargs @property diff --git a/tests/solvers/quadratic/fgw_test.py b/tests/solvers/quadratic/fgw_test.py index 9517bdce8..c922adf5b 100644 --- a/tests/solvers/quadratic/fgw_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -362,3 +362,28 @@ def test_fgw_fused_penalty(self, rng: jax.Array, fused_penalty: float): np.testing.assert_allclose( out.reg_gw_cost, out_fp.reg_gw_cost, rtol=rtol, atol=atol ) + + @pytest.mark.parametrize(("fused", "lr"), [(True, False), (False, True), + (True, True), (False, False)]) + def test_solver_unrecognized_args_fails(self, fused: bool, lr: bool): + fused_penalty = 1.0 if fused else 0.0 + epsilon = 5.0 + geom_x = pointcloud.PointCloud(self.x) + geom_y = pointcloud.PointCloud(self.y) + geom_xy = pointcloud.PointCloud(self.x_2, self.y_2) if fused else None + + prob = quadratic_problem.QuadraticProblem( + geom_xx=geom_x, + geom_yy=geom_y, + geom_xy=geom_xy, + fused_penalty=fused_penalty, + ) + if lr: + prob = prob.to_low_rank() + + solver_cls = ( + gromov_wasserstein_lr.LRGromovWasserstein + if lr else gromov_wasserstein.GromovWasserstein + ) + with pytest.raises(TypeError): + solver_cls(epsilon=epsilon, dummy=42)(prob)