From 201a806f1d50f4187a0b6e2cc494c39c91e97984 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 23 Sep 2024 15:44:59 +0200 Subject: [PATCH 1/6] add checks for was solver --- src/ott/solvers/was_solver.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/ott/solvers/was_solver.py b/src/ott/solvers/was_solver.py index 573038033..36f94436f 100644 --- a/src/ott/solvers/was_solver.py +++ b/src/ott/solvers/was_solver.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union import jax @@ -49,10 +50,12 @@ def __init__( self.epsilon = epsilon if epsilon is not None else default_epsilon self.rank = rank self.linear_ot_solver = linear_ot_solver + used_kwargs = {} if self.linear_ot_solver is None: # Detect if user requests low-rank solver. In that case the # default_epsilon makes little sense, since it was designed for GW. if self.is_low_rank: + used_kwargs = dict(inspect.signature(sinkhorn_lr.LRSinkhorn).parameters) if epsilon is None: # Use default entropic regularization in LRSinkhorn if None was passed self.linear_ot_solver = sinkhorn_lr.LRSinkhorn( @@ -64,6 +67,7 @@ def __init__( rank=self.rank, epsilon=self.epsilon, **kwargs ) else: + used_kwargs = dict(inspect.signature(sinkhorn.Sinkhorn).parameters) # When using Entropic GW, epsilon is not handled inside Sinkhorn, # but rather added back to the Geometry object re-instantiated # when linearizing the problem. Therefore, no need to pass it to solver. @@ -73,6 +77,10 @@ def __init__( self.max_iterations = max_iterations self.threshold = threshold self.store_inner_errors = store_inner_errors + # assert that all kwargs are valid + if not set(kwargs.keys()).issubset(used_kwargs.keys()): + unrecognized_kwargs = set(kwargs.keys()) - set(used_kwargs.keys()) + raise TypeError(f"Invalid keyword arguments: {unrecognized_kwargs}.") self._kwargs = kwargs @property From 73e6f54daf27694302c509dfe107fd1d331a97f7 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 23 Sep 2024 16:33:49 +0200 Subject: [PATCH 2/6] add tests for GWSolvers --- tests/solvers/quadratic/fgw_test.py | 107 +++++++++++++++++++++++----- 1 file changed, 89 insertions(+), 18 deletions(-) diff --git a/tests/solvers/quadratic/fgw_test.py b/tests/solvers/quadratic/fgw_test.py index 9517bdce8..bed6dd327 100644 --- a/tests/solvers/quadratic/fgw_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -28,7 +28,6 @@ class TestFusedGromovWasserstein: - # TODO(michalk8): refactor me in the future @pytest.fixture(autouse=True) def initialize(self, rng: jax.Array): @@ -60,7 +59,12 @@ def test_gradient_marginals_fgw_solver(self, jit: bool): def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool): prob = quadratic_problem.QuadraticProblem( - geom_x, geom_y, geom_xy, fused_penalty=self.fused_penalty, a=a, b=b + geom_x, + geom_y, + geom_xy, + fused_penalty=self.fused_penalty, + a=a, + b=b, ) implicit_diff = implicit_lib.ImplicitDiff() if implicit else None @@ -96,16 +100,22 @@ def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool): np.testing.assert_allclose(g_a, gi_a, rtol=1e-2, atol=1e-2) np.testing.assert_allclose(g_b, gi_b, rtol=1e-2, atol=1e-2) - @pytest.mark.parametrize(("lse_mode", "is_cost"), [(True, False), - (False, True)], - ids=["lse-pc", "kernel-cost-mat"]) + @pytest.mark.parametrize( + ("lse_mode", "is_cost"), + [(True, False), (False, True)], + ids=["lse-pc", "kernel-cost-mat"], + ) def test_gradient_fgw_solver_geometry(self, lse_mode: bool, is_cost: bool): """Test gradient w.r.t. the geometries.""" def reg_gw( - x: jnp.ndarray, y: jnp.ndarray, + x: jnp.ndarray, + y: jnp.ndarray, xy: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], - fused_penalty: float, a: jnp.ndarray, b: jnp.ndarray, implicit: bool + fused_penalty: float, + a: jnp.ndarray, + b: jnp.ndarray, + implicit: bool, ): if is_cost: geom_x = geometry.Geometry(cost_matrix=x) @@ -121,7 +131,9 @@ def reg_gw( implicit_diff = implicit_lib.ImplicitDiff() if implicit else None linear_solver = sinkhorn.Sinkhorn( - lse_mode=lse_mode, implicit_diff=implicit_diff, max_iterations=1000 + lse_mode=lse_mode, + implicit_diff=implicit_diff, + max_iterations=1000, ) solver = gromov_wasserstein.GromovWasserstein( linear_ot_solver=linear_solver, epsilon=1.0, max_iterations=10 @@ -168,7 +180,7 @@ def loss_thre(threshold: float) -> float: geom_xy, a=self.a, b=self.b, - fused_penalty=self.fused_penalty_2 + fused_penalty=self.fused_penalty_2, ) solver = gromov_wasserstein.GromovWasserstein( threshold=threshold, epsilon=1e-1 @@ -184,8 +196,13 @@ def test_gradient_fgw_solver_penalty(self): lse_mode = True def reg_gw( - cx: jnp.ndarray, cy: jnp.ndarray, cxy: jnp.ndarray, - fused_penalty: float, a: jnp.ndarray, b: jnp.ndarray, implicit: bool + cx: jnp.ndarray, + cy: jnp.ndarray, + cxy: jnp.ndarray, + fused_penalty: float, + a: jnp.ndarray, + b: jnp.ndarray, + implicit: bool, ) -> float: geom_x = geometry.Geometry(cost_matrix=cx) geom_y = geometry.Geometry(cost_matrix=cy) @@ -196,7 +213,9 @@ def reg_gw( implicit_diff = implicit_lib.ImplicitDiff() if implicit else None linear_solver = sinkhorn.Sinkhorn( - lse_mode=lse_mode, implicit_diff=implicit_diff, max_iterations=200 + lse_mode=lse_mode, + implicit_diff=implicit_diff, + max_iterations=200, ) solver = gromov_wasserstein.GromovWasserstein( epsilon=1.0, max_iterations=10, linear_ot_solver=linear_solver @@ -207,8 +226,13 @@ def reg_gw( for i, implicit in enumerate([True, False]): reg_fgw_grad = jax.grad(reg_gw, argnums=(3,)) grad_matrices[i] = reg_fgw_grad( - self.cx, self.cy, self.cxy, self.fused_penalty, self.a, self.b, - implicit + self.cx, + self.cy, + self.cxy, + self.fused_penalty, + self.a, + self.b, + implicit, ) assert not jnp.any(jnp.isnan(grad_matrices[i][0])) @@ -272,7 +296,7 @@ def test_fgw_lr_generic_cost_matrix( epsilon=10.0, min_iterations=0, inner_iterations=10, - max_iterations=2000 + max_iterations=2000, ) out = solver(prob) @@ -314,7 +338,7 @@ def test_fgw_scale_cost(self, scale_cost: Literal["mean", "max_cost"]): geom_y, geom_xy, fused_penalty=fused_penalty, - scale_cost=scale_cost + scale_cost=scale_cost, ) solver = gromov_wasserstein.GromovWasserstein(epsilon=epsilon) @@ -344,14 +368,14 @@ def test_fgw_fused_penalty(self, rng: jax.Array, fused_penalty: float): geom_yy, geom_xy=geom_xy, fused_penalty=fused_penalty, - store_inner_errors=True + store_inner_errors=True, ) out_fp = quadratic.solve( geom_xx, geom_yy, geom_xy=geom_xy_fp, fused_penalty=1.0, - store_inner_errors=True + store_inner_errors=True, ) np.testing.assert_allclose(out.costs, out_fp.costs, rtol=rtol, atol=atol) @@ -362,3 +386,50 @@ def test_fgw_fused_penalty(self, rng: jax.Array, fused_penalty: float): np.testing.assert_allclose( out.reg_gw_cost, out_fp.reg_gw_cost, rtol=rtol, atol=atol ) + + @pytest.mark.parametrize( + ( + "fused", + "lr", + ), + [ + ( + True, + False, + ), + ( + False, + True, + ), + ( + True, + True, + ), + ( + False, + False, + ), + ], + ) + def test_solver_unrecognized_args_fails(self, fused: bool, lr: bool): + fused_penalty = 1.0 if fused else 0.0 + epsilon = 5.0 + geom_x = pointcloud.PointCloud(self.x) + geom_y = pointcloud.PointCloud(self.y) + geom_xy = pointcloud.PointCloud(self.x_2, self.y_2) if fused else None + + prob = quadratic_problem.QuadraticProblem( + geom_xx=geom_x, + geom_yy=geom_y, + geom_xy=geom_xy, + fused_penalty=fused_penalty, + ) + if lr: + prob = prob.to_low_rank() + + solver_cls = ( + gromov_wasserstein_lr.LRGromovWasserstein + if lr else gromov_wasserstein.GromovWasserstein + ) + with pytest.raises(TypeError): + solver_cls(epsilon=epsilon, dummy=42)(prob) From bd6672ccd0a1aee26e6c80fb4e56f1adddb28c92 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 23 Sep 2024 23:05:55 +0200 Subject: [PATCH 3/6] maybe try to fix precommits? --- .github/workflows/lint.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 6bde6ec3b..c5c87d5e6 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -33,8 +33,11 @@ jobs: if: ${{ matrix.lint-kind == 'code' }} with: path: ~/.cache/pre-commit - key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('**/.pre-commit-config.yaml') }} - + key: pre-commit-${{ runner.os }}-python-${{ env.pythonLocation }}-${{ hashFiles('**/.pre-commit-config.yaml') }} + restore-keys: | + pre-commit-${{ runner.os }}-python-${{ env.pythonLocation }}- + pre-commit-${{ runner.os }}- + pre-commit- - name: Install dependencies run: | python -m pip install --upgrade pip From 62b75829111a44bef892803b3a967a80391af45f Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 25 Sep 2024 17:25:58 +0200 Subject: [PATCH 4/6] Revert "maybe try to fix precommits?" This reverts commit bd6672ccd0a1aee26e6c80fb4e56f1adddb28c92. --- .github/workflows/lint.yml | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index c5c87d5e6..6bde6ec3b 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -33,11 +33,8 @@ jobs: if: ${{ matrix.lint-kind == 'code' }} with: path: ~/.cache/pre-commit - key: pre-commit-${{ runner.os }}-python-${{ env.pythonLocation }}-${{ hashFiles('**/.pre-commit-config.yaml') }} - restore-keys: | - pre-commit-${{ runner.os }}-python-${{ env.pythonLocation }}- - pre-commit-${{ runner.os }}- - pre-commit- + key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('**/.pre-commit-config.yaml') }} + - name: Install dependencies run: | python -m pip install --upgrade pip From a3b30c8f1059ac30bc2dc4bcd4a1c3db31d5e445 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 25 Sep 2024 17:27:29 +0200 Subject: [PATCH 5/6] Revert "add tests for GWSolvers" This reverts commit 73e6f54daf27694302c509dfe107fd1d331a97f7. --- tests/solvers/quadratic/fgw_test.py | 107 +++++----------------------- 1 file changed, 18 insertions(+), 89 deletions(-) diff --git a/tests/solvers/quadratic/fgw_test.py b/tests/solvers/quadratic/fgw_test.py index bed6dd327..9517bdce8 100644 --- a/tests/solvers/quadratic/fgw_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -28,6 +28,7 @@ class TestFusedGromovWasserstein: + # TODO(michalk8): refactor me in the future @pytest.fixture(autouse=True) def initialize(self, rng: jax.Array): @@ -59,12 +60,7 @@ def test_gradient_marginals_fgw_solver(self, jit: bool): def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool): prob = quadratic_problem.QuadraticProblem( - geom_x, - geom_y, - geom_xy, - fused_penalty=self.fused_penalty, - a=a, - b=b, + geom_x, geom_y, geom_xy, fused_penalty=self.fused_penalty, a=a, b=b ) implicit_diff = implicit_lib.ImplicitDiff() if implicit else None @@ -100,22 +96,16 @@ def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool): np.testing.assert_allclose(g_a, gi_a, rtol=1e-2, atol=1e-2) np.testing.assert_allclose(g_b, gi_b, rtol=1e-2, atol=1e-2) - @pytest.mark.parametrize( - ("lse_mode", "is_cost"), - [(True, False), (False, True)], - ids=["lse-pc", "kernel-cost-mat"], - ) + @pytest.mark.parametrize(("lse_mode", "is_cost"), [(True, False), + (False, True)], + ids=["lse-pc", "kernel-cost-mat"]) def test_gradient_fgw_solver_geometry(self, lse_mode: bool, is_cost: bool): """Test gradient w.r.t. the geometries.""" def reg_gw( - x: jnp.ndarray, - y: jnp.ndarray, + x: jnp.ndarray, y: jnp.ndarray, xy: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], - fused_penalty: float, - a: jnp.ndarray, - b: jnp.ndarray, - implicit: bool, + fused_penalty: float, a: jnp.ndarray, b: jnp.ndarray, implicit: bool ): if is_cost: geom_x = geometry.Geometry(cost_matrix=x) @@ -131,9 +121,7 @@ def reg_gw( implicit_diff = implicit_lib.ImplicitDiff() if implicit else None linear_solver = sinkhorn.Sinkhorn( - lse_mode=lse_mode, - implicit_diff=implicit_diff, - max_iterations=1000, + lse_mode=lse_mode, implicit_diff=implicit_diff, max_iterations=1000 ) solver = gromov_wasserstein.GromovWasserstein( linear_ot_solver=linear_solver, epsilon=1.0, max_iterations=10 @@ -180,7 +168,7 @@ def loss_thre(threshold: float) -> float: geom_xy, a=self.a, b=self.b, - fused_penalty=self.fused_penalty_2, + fused_penalty=self.fused_penalty_2 ) solver = gromov_wasserstein.GromovWasserstein( threshold=threshold, epsilon=1e-1 @@ -196,13 +184,8 @@ def test_gradient_fgw_solver_penalty(self): lse_mode = True def reg_gw( - cx: jnp.ndarray, - cy: jnp.ndarray, - cxy: jnp.ndarray, - fused_penalty: float, - a: jnp.ndarray, - b: jnp.ndarray, - implicit: bool, + cx: jnp.ndarray, cy: jnp.ndarray, cxy: jnp.ndarray, + fused_penalty: float, a: jnp.ndarray, b: jnp.ndarray, implicit: bool ) -> float: geom_x = geometry.Geometry(cost_matrix=cx) geom_y = geometry.Geometry(cost_matrix=cy) @@ -213,9 +196,7 @@ def reg_gw( implicit_diff = implicit_lib.ImplicitDiff() if implicit else None linear_solver = sinkhorn.Sinkhorn( - lse_mode=lse_mode, - implicit_diff=implicit_diff, - max_iterations=200, + lse_mode=lse_mode, implicit_diff=implicit_diff, max_iterations=200 ) solver = gromov_wasserstein.GromovWasserstein( epsilon=1.0, max_iterations=10, linear_ot_solver=linear_solver @@ -226,13 +207,8 @@ def reg_gw( for i, implicit in enumerate([True, False]): reg_fgw_grad = jax.grad(reg_gw, argnums=(3,)) grad_matrices[i] = reg_fgw_grad( - self.cx, - self.cy, - self.cxy, - self.fused_penalty, - self.a, - self.b, - implicit, + self.cx, self.cy, self.cxy, self.fused_penalty, self.a, self.b, + implicit ) assert not jnp.any(jnp.isnan(grad_matrices[i][0])) @@ -296,7 +272,7 @@ def test_fgw_lr_generic_cost_matrix( epsilon=10.0, min_iterations=0, inner_iterations=10, - max_iterations=2000, + max_iterations=2000 ) out = solver(prob) @@ -338,7 +314,7 @@ def test_fgw_scale_cost(self, scale_cost: Literal["mean", "max_cost"]): geom_y, geom_xy, fused_penalty=fused_penalty, - scale_cost=scale_cost, + scale_cost=scale_cost ) solver = gromov_wasserstein.GromovWasserstein(epsilon=epsilon) @@ -368,14 +344,14 @@ def test_fgw_fused_penalty(self, rng: jax.Array, fused_penalty: float): geom_yy, geom_xy=geom_xy, fused_penalty=fused_penalty, - store_inner_errors=True, + store_inner_errors=True ) out_fp = quadratic.solve( geom_xx, geom_yy, geom_xy=geom_xy_fp, fused_penalty=1.0, - store_inner_errors=True, + store_inner_errors=True ) np.testing.assert_allclose(out.costs, out_fp.costs, rtol=rtol, atol=atol) @@ -386,50 +362,3 @@ def test_fgw_fused_penalty(self, rng: jax.Array, fused_penalty: float): np.testing.assert_allclose( out.reg_gw_cost, out_fp.reg_gw_cost, rtol=rtol, atol=atol ) - - @pytest.mark.parametrize( - ( - "fused", - "lr", - ), - [ - ( - True, - False, - ), - ( - False, - True, - ), - ( - True, - True, - ), - ( - False, - False, - ), - ], - ) - def test_solver_unrecognized_args_fails(self, fused: bool, lr: bool): - fused_penalty = 1.0 if fused else 0.0 - epsilon = 5.0 - geom_x = pointcloud.PointCloud(self.x) - geom_y = pointcloud.PointCloud(self.y) - geom_xy = pointcloud.PointCloud(self.x_2, self.y_2) if fused else None - - prob = quadratic_problem.QuadraticProblem( - geom_xx=geom_x, - geom_yy=geom_y, - geom_xy=geom_xy, - fused_penalty=fused_penalty, - ) - if lr: - prob = prob.to_low_rank() - - solver_cls = ( - gromov_wasserstein_lr.LRGromovWasserstein - if lr else gromov_wasserstein.GromovWasserstein - ) - with pytest.raises(TypeError): - solver_cls(epsilon=epsilon, dummy=42)(prob) From 88bde475c5c8d4a3b216c7ad08c22840fe29daa2 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 25 Sep 2024 17:34:56 +0200 Subject: [PATCH 6/6] formatting fix --- tests/solvers/quadratic/fgw_test.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/solvers/quadratic/fgw_test.py b/tests/solvers/quadratic/fgw_test.py index 9517bdce8..c922adf5b 100644 --- a/tests/solvers/quadratic/fgw_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -362,3 +362,28 @@ def test_fgw_fused_penalty(self, rng: jax.Array, fused_penalty: float): np.testing.assert_allclose( out.reg_gw_cost, out_fp.reg_gw_cost, rtol=rtol, atol=atol ) + + @pytest.mark.parametrize(("fused", "lr"), [(True, False), (False, True), + (True, True), (False, False)]) + def test_solver_unrecognized_args_fails(self, fused: bool, lr: bool): + fused_penalty = 1.0 if fused else 0.0 + epsilon = 5.0 + geom_x = pointcloud.PointCloud(self.x) + geom_y = pointcloud.PointCloud(self.y) + geom_xy = pointcloud.PointCloud(self.x_2, self.y_2) if fused else None + + prob = quadratic_problem.QuadraticProblem( + geom_xx=geom_x, + geom_yy=geom_y, + geom_xy=geom_xy, + fused_penalty=fused_penalty, + ) + if lr: + prob = prob.to_low_rank() + + solver_cls = ( + gromov_wasserstein_lr.LRGromovWasserstein + if lr else gromov_wasserstein.GromovWasserstein + ) + with pytest.raises(TypeError): + solver_cls(epsilon=epsilon, dummy=42)(prob)