From 28d19e325b2a09dab1656d4a645d936c47a371d5 Mon Sep 17 00:00:00 2001 From: Hendrik Kleikamp Date: Thu, 27 Apr 2023 21:11:15 +0200 Subject: [PATCH 01/14] fixed scaling in gradient computation --- geodesic_shooting/geodesic_shooting.py | 4 +++- geodesic_shooting/utils/time_integration.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/geodesic_shooting/geodesic_shooting.py b/geodesic_shooting/geodesic_shooting.py index 2691205..78e575e 100644 --- a/geodesic_shooting/geodesic_shooting.py +++ b/geodesic_shooting/geodesic_shooting.py @@ -193,6 +193,7 @@ def gradient_descent(func, x0, grad_norm_tol=1e-5, rel_func_update_tol=1e-6, max def line_search(x, func_x, grad_x, d): alpha = alpha0 d_dot_grad = d.dot(grad_x) + assert d_dot_grad < 0., 'The direction d is not a descent direction!' func_x_update = func(x + alpha * d, compute_grad=False) k = 0 while (not func_x_update <= func_x + c1 * alpha * d_dot_grad) and k < maxiter_armijo: @@ -387,7 +388,7 @@ def rhs_function(x): # perform forward in time integration of initial vector field for t in range(0, self.time_steps-1): - # perform the explicit Euler integration step + # perform the time integration step using the time integrator vector_fields[t+1] = ti.step(vector_fields[t]) return vector_fields @@ -461,6 +462,7 @@ def rhs_function(x, v): # perform backward in time integration of the gradient of the energy function for t in range(self.time_steps-2, -1, -1): + # perform the backward time integration step using the time integrator delta_v, v_old = ti.step_backwards([delta_v, v_old], {'v': vector_fields[t]}) # return adjoint variable `delta_v` that corresponds to the gradient diff --git a/geodesic_shooting/utils/time_integration.py b/geodesic_shooting/utils/time_integration.py index f3c435c..7e08e26 100644 --- a/geodesic_shooting/utils/time_integration.py +++ b/geodesic_shooting/utils/time_integration.py @@ -45,5 +45,5 @@ def _rhs(self, x, additional_args={}, sign=1): temp = x + sign * self.dt * k3 k4 = self.f(temp, **additional_args) if isinstance(x, list): - return [self.dt/6. * (l1 + 2.*l2 + 2.*l3 + l4) for l1, l2, l3, l4 in zip(k1, k2, k3, k4)] + return [(l1 + 2.*l2 + 2.*l3 + l4) / 6. for l1, l2, l3, l4 in zip(k1, k2, k3, k4)] return (k1 + 2.*k2 + 2.*k3 + k4) / 6. From 6bbb2aa56e93d820a1f93604be989f7acb7ca204 Mon Sep 17 00:00:00 2001 From: Hendrik Kleikamp Date: Fri, 28 Apr 2023 10:21:21 +0200 Subject: [PATCH 02/14] added plot of singular values to summary --- geodesic_shooting/utils/summary.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/geodesic_shooting/utils/summary.py b/geodesic_shooting/utils/summary.py index 73bf55c..d9cf946 100644 --- a/geodesic_shooting/utils/summary.py +++ b/geodesic_shooting/utils/summary.py @@ -57,7 +57,6 @@ def plot_registration_results(results, interval=1, frequency=1, scale=None, figs plt.show() _, singular_values = pod(results['vector_fields'], return_singular_values='all') - print(singular_values) plt.semilogy(singular_values) plt.title("Singular values of time-evolution of the vector field") plt.show() @@ -200,6 +199,11 @@ def save_plots_registration_results(results, filepath='results/', postfix='', in diff = results['target'] - results['transformed_input'] diff.save(filepath + 'difference.png', title='Difference of target and result' + postfix, figsize=figsize, show_restriction_boundary=show_restriction_boundary, restriction=results['restriction']) + _, singular_values = pod(results['vector_fields'], return_singular_values='all') + plt.semilogy(singular_values) + plt.title('Singular values of time-evolution of the vector field') + plt.savefig(filepath + 'singular_values_time_evolution_vector_field.png') + plt.close() results['initial_vector_field'].save(filepath + 'initial_vector_field.png', plot_type='default', plot_args={'title': 'Initial vector field' + postfix, 'interval': interval, 'color_length': True, 'scale': None, 'figsize': figsize}) From b4bcaef3df6752bad84fc538cf1f79fdc3d9752a Mon Sep 17 00:00:00 2001 From: Hendrik Kleikamp Date: Fri, 28 Apr 2023 10:30:54 +0200 Subject: [PATCH 03/14] new test for regularizer and its inverse --- tests/test_regularizer.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/test_regularizer.py b/tests/test_regularizer.py index 77ac884..a8a4903 100644 --- a/tests/test_regularizer.py +++ b/tests/test_regularizer.py @@ -1,12 +1,13 @@ import numpy as np import pytest +from geodesic_shooting.utils.create_example_images import make_circle from geodesic_shooting.utils.regularizer import BiharmonicRegularizer from geodesic_shooting.core import VectorField def test_regularizer_self_adjoint(): - regularizer = BiharmonicRegularizer(alpha=1, exponent=2) + regularizer = BiharmonicRegularizer(alpha=10., exponent=1, gamma=2.) v = VectorField((6, 4)) v[2, 2, 0] = 1. w = VectorField(v.spatial_shape) @@ -17,6 +18,17 @@ def test_regularizer_self_adjoint(): assert np.isclose(wLv, vLw) +def test_regularizer_inverse(): + regularizer = BiharmonicRegularizer(alpha=10., exponent=1, gamma=2.) + image = make_circle((64, 64), np.array([32, 32]), 10) + vector_field = image.grad + + assert (regularizer.cauchy_navier_inverse(regularizer.cauchy_navier(vector_field)) + - regularizer.cauchy_navier(regularizer.cauchy_navier_inverse(vector_field))).norm < 1e-7 + assert (regularizer.cauchy_navier_inverse(regularizer.cauchy_navier(vector_field)) - vector_field).norm < 1e-7 + assert (regularizer.cauchy_navier(regularizer.cauchy_navier_inverse(vector_field)) - vector_field).norm < 1e-7 + + @pytest.mark.parametrize("alpha", [1., 0.1, 0.01]) @pytest.mark.parametrize("exponent", [1]) @pytest.mark.parametrize("gamma", [10., 1., 0.1]) From 5c4d9ba7d8150a327dae6f5f52f889c9c6520937 Mon Sep 17 00:00:00 2001 From: Hendrik Kleikamp Date: Fri, 28 Apr 2023 11:53:41 +0200 Subject: [PATCH 04/14] extended test of regularizer --- tests/test_regularizer.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_regularizer.py b/tests/test_regularizer.py index a8a4903..b5e1e2b 100644 --- a/tests/test_regularizer.py +++ b/tests/test_regularizer.py @@ -29,6 +29,18 @@ def test_regularizer_inverse(): assert (regularizer.cauchy_navier(regularizer.cauchy_navier_inverse(vector_field)) - vector_field).norm < 1e-7 +def test_regularizer_in_norm(): + regularizer = BiharmonicRegularizer(alpha=10., exponent=1, gamma=2.) + image = make_circle((64, 64), np.array([32, 32]), 10) + vector_field = image.grad + + assert (regularizer.helmholtz(regularizer.helmholtz(vector_field)) + - regularizer.cauchy_navier(vector_field)).norm < 1e-4 + + assert np.abs(regularizer.helmholtz(vector_field).norm + - vector_field.get_norm(product_operator=regularizer.cauchy_navier)) < 1e-14 + + @pytest.mark.parametrize("alpha", [1., 0.1, 0.01]) @pytest.mark.parametrize("exponent", [1]) @pytest.mark.parametrize("gamma", [10., 1., 0.1]) From df0f95a0498048ed998b4c51305b6ef5524d061c Mon Sep 17 00:00:00 2001 From: Hendrik Kleikamp Date: Fri, 28 Apr 2023 11:54:17 +0200 Subject: [PATCH 05/14] assert that order of norm is correct if product operator is used --- geodesic_shooting/core/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/geodesic_shooting/core/base.py b/geodesic_shooting/core/base.py index 563e5bb..2853f3a 100644 --- a/geodesic_shooting/core/base.py +++ b/geodesic_shooting/core/base.py @@ -114,6 +114,7 @@ def get_norm(self, product_operator=None, order=None, restriction=np.s_[...]): """ vol = 1. / tuple_product(self.spatial_shape) if product_operator: + assert order is None or order == 2 apply_product_operator = product_operator(self).to_numpy()[restriction].flatten() return np.sqrt(apply_product_operator.dot(self.to_numpy()[restriction].flatten())) * np.sqrt(vol) else: From 826dea9a39f9d35e113d7008139dfc57caf111a9 Mon Sep 17 00:00:00 2001 From: Hendrik Kleikamp Date: Fri, 28 Apr 2023 11:55:09 +0200 Subject: [PATCH 06/14] adjusted norm computation for energy during geodesic shotting --- geodesic_shooting/geodesic_shooting.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/geodesic_shooting/geodesic_shooting.py b/geodesic_shooting/geodesic_shooting.py index 78e575e..97124f6 100644 --- a/geodesic_shooting/geodesic_shooting.py +++ b/geodesic_shooting/geodesic_shooting.py @@ -153,9 +153,10 @@ def energy_and_gradient(v0, compute_grad=True, return_all_energies=False): # compute the current energy consisting of intensity difference # and regularization - energy_regularizer = self.regularizer.helmholtz(v0).get_norm(restriction=restriction)**2 + energy_regularizer = v0.get_norm(product_operator=self.regularizer.cauchy_navier, + restriction=restriction)**2 energy_intensity_unscaled = compute_energy(forward_pushed_input) - energy_intensity = 1 / sigma**2 * energy_intensity_unscaled + energy_intensity = energy_intensity_unscaled / sigma**2 energy = energy_regularizer + energy_intensity if compute_grad: From bcf2abeb808422ff902241706223ddf2d2fb784c Mon Sep 17 00:00:00 2001 From: Hendrik Kleikamp Date: Fri, 28 Apr 2023 11:57:49 +0200 Subject: [PATCH 07/14] added some more plots in circle scaling example --- .../circle_scaling_example.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/examples/geodesic_shooting/circle_scaling_example.py b/examples/geodesic_shooting/circle_scaling_example.py index 8b2d20d..996f7e0 100644 --- a/examples/geodesic_shooting/circle_scaling_example.py +++ b/examples/geodesic_shooting/circle_scaling_example.py @@ -1,6 +1,7 @@ import numpy as np import geodesic_shooting +from geodesic_shooting.core import VectorField, Diffeomorphism, TimeDependentScalarFunction from geodesic_shooting.utils.create_example_images import make_circle from geodesic_shooting.utils.summary import plot_registration_results, save_plots_registration_results @@ -16,8 +17,21 @@ optimizer_options={'disp': True, 'maxiter': 20}) plot_registration_results(result) - gs.regularizer.cauchy_navier(result['initial_vector_field']).plot(title='Cauchy Navier operator applied ' - 'to initial vector field') + cn_initial = gs.regularizer.cauchy_navier(result['initial_vector_field']) + cn_initial.plot(title='Cauchy Navier operator applied to initial vector field', color_length=True) + cn_initial.get_magnitude().plot(title='Magnitude of Cauchy Navier operator applied to initial vector field') + + magnitude_evo_cn_vector_fields = [] + evolution_transformed_template = [] + for vf, diffeo in zip(result['vector_fields'].to_numpy(), result['vector_fields'].integrate(get_time_dependent_diffeomorphism=True).to_numpy()): + magnitude_evo_cn_vector_fields.append(gs.regularizer.cauchy_navier(VectorField(data=vf)).get_magnitude()) + evolution_transformed_template.append(template.push_forward(Diffeomorphism(data=diffeo))) + + magnitude_evo_cn_vector_fields = TimeDependentScalarFunction(data=magnitude_evo_cn_vector_fields) + ani1 = magnitude_evo_cn_vector_fields.animate(title='Evolution of magnitude of Cauchy Navier operator applied to vector fields') + evolution_transformed_template = TimeDependentScalarFunction(data=evolution_transformed_template) + ani2 = evolution_transformed_template.animate(title='Evolution of transformed template') import matplotlib.pyplot as plt plt.show() + save_plots_registration_results(result, filepath='results_circle_scaling/') From 86b7d04d102f097ec8414054f22735b9f399f243 Mon Sep 17 00:00:00 2001 From: Hendrik Kleikamp Date: Fri, 28 Apr 2023 15:31:34 +0200 Subject: [PATCH 08/14] adjusted regularization parameters in examples and tests --- examples/geodesic_shooting/moving_square_example.py | 5 +++-- examples/geodesic_shooting/square_to_circle_example.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/geodesic_shooting/moving_square_example.py b/examples/geodesic_shooting/moving_square_example.py index 5746bf9..672a71a 100644 --- a/examples/geodesic_shooting/moving_square_example.py +++ b/examples/geodesic_shooting/moving_square_example.py @@ -11,8 +11,9 @@ template = make_square((64, 64), np.array([32, 32]), 30) # perform the registration - gs = geodesic_shooting.GeodesicShooting(alpha=0.05, exponent=1) - result = gs.register(template, target, sigma=0.05, return_all=True) + gs = geodesic_shooting.GeodesicShooting(alpha=10., exponent=1, gamma=2.) + result = gs.register(template, target, sigma=1, return_all=True, optimization_method='GD', + optimizer_options={'disp': True, 'maxiter': 20}) plot_registration_results(result, frequency=5) save_plots_registration_results(result, filepath='results_moving_square/') diff --git a/examples/geodesic_shooting/square_to_circle_example.py b/examples/geodesic_shooting/square_to_circle_example.py index b335f1d..4c5ab13 100644 --- a/examples/geodesic_shooting/square_to_circle_example.py +++ b/examples/geodesic_shooting/square_to_circle_example.py @@ -15,7 +15,8 @@ # perform the registration gs = geodesic_shooting.GeodesicShooting(alpha=0.01, exponent=1) - result = gs.register(template, target, sigma=0.01, return_all=True, restriction=restriction) + result = gs.register(template, target, sigma=0.01, return_all=True, restriction=restriction, + optimization_method='GD', optimizer_options={'maxiter': 20}) result['initial_vector_field'].save_tikz('initial_vector_field_square_to_circle.tex', title="Initial vector field square to circle", From 9eb9595f51da5f0b856eb9b796cb459520d729aa Mon Sep 17 00:00:00 2001 From: Hendrik Kleikamp Date: Fri, 28 Apr 2023 20:22:22 +0200 Subject: [PATCH 09/14] fixed linting errors --- examples/geodesic_shooting/circle_scaling_example.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/geodesic_shooting/circle_scaling_example.py b/examples/geodesic_shooting/circle_scaling_example.py index 996f7e0..4751b30 100644 --- a/examples/geodesic_shooting/circle_scaling_example.py +++ b/examples/geodesic_shooting/circle_scaling_example.py @@ -23,12 +23,14 @@ magnitude_evo_cn_vector_fields = [] evolution_transformed_template = [] - for vf, diffeo in zip(result['vector_fields'].to_numpy(), result['vector_fields'].integrate(get_time_dependent_diffeomorphism=True).to_numpy()): + for vf, diffeo in zip(result['vector_fields'].to_numpy(), + result['vector_fields'].integrate(get_time_dependent_diffeomorphism=True).to_numpy()): magnitude_evo_cn_vector_fields.append(gs.regularizer.cauchy_navier(VectorField(data=vf)).get_magnitude()) evolution_transformed_template.append(template.push_forward(Diffeomorphism(data=diffeo))) magnitude_evo_cn_vector_fields = TimeDependentScalarFunction(data=magnitude_evo_cn_vector_fields) - ani1 = magnitude_evo_cn_vector_fields.animate(title='Evolution of magnitude of Cauchy Navier operator applied to vector fields') + ani1 = magnitude_evo_cn_vector_fields.animate(title='Evolution of magnitude of Cauchy Navier operator ' + 'applied to vector fields') evolution_transformed_template = TimeDependentScalarFunction(data=evolution_transformed_template) ani2 = evolution_transformed_template.animate(title='Evolution of transformed template') import matplotlib.pyplot as plt From d88d752102a7b2c349036ba7a886e95c8a6f5159 Mon Sep 17 00:00:00 2001 From: Hendrik Kleikamp Date: Tue, 2 May 2023 09:18:20 +0200 Subject: [PATCH 10/14] fixed translation test and example --- tests/test_translation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_translation.py b/tests/test_translation.py index 279dfc6..ade1c30 100644 --- a/tests/test_translation.py +++ b/tests/test_translation.py @@ -1,7 +1,7 @@ import numpy as np import geodesic_shooting -from geodesic_shooting.utils.create_example_images import make_circle, make_square +from geodesic_shooting.utils.create_example_images import make_circle def test_translation(): From 40d837f0569aaa0c5180abaef8a54aa3f339e72c Mon Sep 17 00:00:00 2001 From: Hendrik Kleikamp Date: Thu, 4 May 2023 12:04:08 +0200 Subject: [PATCH 11/14] improved summary by extending output --- .../square_to_circle_example.py | 2 +- geodesic_shooting/core/vector_fields.py | 2 +- geodesic_shooting/utils/summary.py | 237 +++++++++++++----- 3 files changed, 170 insertions(+), 71 deletions(-) diff --git a/examples/geodesic_shooting/square_to_circle_example.py b/examples/geodesic_shooting/square_to_circle_example.py index 4c5ab13..9595a28 100644 --- a/examples/geodesic_shooting/square_to_circle_example.py +++ b/examples/geodesic_shooting/square_to_circle_example.py @@ -23,4 +23,4 @@ interval=2, scale=100) plot_registration_results(result, frequency=5) - save_plots_registration_results(result, filepath='results_square_to_circle/') + save_plots_registration_results(result, filepath='results_square_to_circle/', save_animations=True) diff --git a/geodesic_shooting/core/vector_fields.py b/geodesic_shooting/core/vector_fields.py index fd0c8d0..afff866 100644 --- a/geodesic_shooting/core/vector_fields.py +++ b/geodesic_shooting/core/vector_fields.py @@ -677,7 +677,7 @@ def integrate_backward(self, sampler_options={'order': 1, 'mode': 'edge'}, get_t return TimeDependentDiffeomorphism(data=diffeomorphisms) return diffeomorphisms[-1] - def animate(self, title="", interval=1, color_length=False, colorbar=True, scale=None, show_axis=True, + def animate(self, title="", interval=1, color_length=True, colorbar=True, scale=None, show_axis=True, figsize=(10, 10)): """Animates the `TimeDependentVectorField` using the `plot`-function of `VectorField`. diff --git a/geodesic_shooting/utils/summary.py b/geodesic_shooting/utils/summary.py index d9cf946..73f9ada 100644 --- a/geodesic_shooting/utils/summary.py +++ b/geodesic_shooting/utils/summary.py @@ -26,28 +26,52 @@ def plot_registration_results(results, interval=1, frequency=1, scale=None, figs diffeomorphism = results['flow'] dim = results['input'].dim + diffeomorphism.set_inverse(results['vector_fields'].integrate_backward()) + inverse_transformed_registration_result = results['transformed_input'].push_forward(diffeomorphism.inverse) + diff_inv_reg_res = results['input'] - inverse_transformed_registration_result + inverse_transformed_target = results['target'].push_forward(diffeomorphism.inverse) + diff_inv_tar = results['input'] - inverse_transformed_target + diff = results['target'] - results['transformed_input'] + + rest = results['restriction'] + + _, singular_values = pod(results['vector_fields'], return_singular_values='all') + + # Print some results: + print("Relative norm of difference between target and transformed input: " + f"{diff.get_norm(restriction=rest) / results['target'].get_norm(restriction=rest):.5e}") + tab = "\t" + print(f"Singular values of time-evolution of vector fields: \n\t{tab.join(f'{s:.5e}' for s in singular_values)}") + print("Relative norm of difference between inverse transformed registration result and input: " + f"{diff_inv_reg_res.get_norm(restriction=rest) / results['input'].get_norm(restriction=rest):.5e}") + print("Relative norm of difference between inverse transformed target and input: " + f"{diff_inv_tar.get_norm(restriction=rest) / results['input'].get_norm(restriction=rest):.5e}") + print("Energies:") + print(f"\tEnergy regularizer: {results['energy_regularizer']:.5e}") + print(f"\tEnergy intensity (unscaled): {results['energy_intensity_unscaled']:.5e}") + print(f"\tEnergy intensity (scaled): {results['energy_intensity']:.5e}") + print(f"\tFull energy: {results['energy']:.5e}") + + # Show a couple of plots: if dim == 3: fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, subplot_kw={'projection': '3d'}) else: fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4) - ax1, vals1 = results['input'].plot("Input", axis=ax1, figsize=figsize, - show_restriction_boundary=show_restriction_boundary, - restriction=results['restriction']) + ax1, vals1 = results['input'].plot("Input", axis=ax1, + show_restriction_boundary=show_restriction_boundary, restriction=rest) if not isinstance(vals1, list): fig.colorbar(vals1, ax=ax1, fraction=0.046, pad=0.04) - ax2, vals2 = results['target'].plot("Target", axis=ax2, figsize=figsize, - show_restriction_boundary=show_restriction_boundary, - restriction=results['restriction']) + ax2, vals2 = results['target'].plot("Target", axis=ax2, + show_restriction_boundary=show_restriction_boundary, restriction=rest) if not isinstance(vals2, list): fig.colorbar(vals2, ax=ax2, fraction=0.046, pad=0.04) - ax3, vals3 = results['transformed_input'].plot("Result", axis=ax3, figsize=figsize, + ax3, vals3 = results['transformed_input'].plot("Result", axis=ax3, show_restriction_boundary=show_restriction_boundary, - restriction=results['restriction']) + restriction=rest) if not isinstance(vals3, list): fig.colorbar(vals3, ax=ax3, fraction=0.046, pad=0.04) - diff = results['target'] - results['transformed_input'] - ax4, vals4 = diff.plot("Difference of target and result", axis=ax4, figsize=figsize, - show_restriction_boundary=show_restriction_boundary, restriction=results['restriction']) + ax4, vals4 = diff.plot("Difference of target and result", axis=ax4, + show_restriction_boundary=show_restriction_boundary, restriction=rest) if not isinstance(vals4, list): fig.colorbar(vals4, ax=ax4, fraction=0.046, pad=0.04) plt.show() @@ -56,7 +80,7 @@ def plot_registration_results(results, interval=1, frequency=1, scale=None, figs figsize=figsize) plt.show() - _, singular_values = pod(results['vector_fields'], return_singular_values='all') + plt.figure(figsize=figsize) plt.semilogy(singular_values) plt.title("Singular values of time-evolution of the vector field") plt.show() @@ -77,11 +101,10 @@ def plot_registration_results(results, interval=1, frequency=1, scale=None, figs else: fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) ax1, vals1 = results['initial_vector_field'].plot("Initial vector field", axis=ax1, interval=interval, - scale=scale, color_length=True, figsize=figsize) + scale=scale, color_length=True) if not isinstance(vals1, list): fig.colorbar(vals1, ax=ax1, fraction=0.046, pad=0.04) - ax2, vals2 = results['initial_vector_field'].get_magnitude().plot("Magnitude of initial vector field", - axis=ax2, figsize=figsize) + ax2, vals2 = results['initial_vector_field'].get_magnitude().plot("Magnitude of initial vector field", axis=ax2) if not isinstance(vals2, list): fig.colorbar(vals2, ax=ax2, fraction=0.046, pad=0.04) if dim > 1: @@ -135,38 +158,32 @@ def plot_registration_results(results, interval=1, frequency=1, scale=None, figs "Animation of the transformation of the input", interval=interval, figsize=figsize, show_restriction_boundary=show_restriction_boundary, - restriction=results['restriction']) + restriction=rest) plt.show() - diffeomorphism.set_inverse(results['vector_fields'].integrate_backward()) diffeomorphism.inverse.plot("Inverse diffeomorphism", interval=interval, figsize=figsize) plt.show() - inverse_transformed_registration_result = results['transformed_input'].push_forward(diffeomorphism.inverse) inverse_transformed_registration_result.plot("Inverse transformed registration result", figsize=figsize, show_restriction_boundary=show_restriction_boundary, - restriction=results['restriction']) + restriction=rest) plt.show() - diff = results['input'] - inverse_transformed_registration_result - diff.plot("Difference between input and inverse transformed registration result", figsize=figsize, - show_restriction_boundary=show_restriction_boundary, restriction=results['restriction']) + diff_inv_reg_res.plot("Difference between input and inverse transformed registration result", figsize=figsize, + show_restriction_boundary=show_restriction_boundary, restriction=rest) plt.show() - inverse_transformed_target = results['target'].push_forward(diffeomorphism.inverse) inverse_transformed_target.plot("Inverse transformed target", figsize=figsize, - show_restriction_boundary=show_restriction_boundary, - restriction=results['restriction']) + show_restriction_boundary=show_restriction_boundary, restriction=rest) plt.show() - diff = results['input'] - inverse_transformed_target - diff.plot("Difference between input and inverse transformed target", figsize=figsize, - show_restriction_boundary=show_restriction_boundary, restriction=results['restriction']) + diff_inv_tar.plot("Difference between input and inverse transformed target", figsize=figsize, + show_restriction_boundary=show_restriction_boundary, restriction=rest) plt.show() -def save_plots_registration_results(results, filepath='results/', postfix='', interval=1, figsize=(20, 20), - show_restriction_boundary=True, save_animations=False): +def save_plots_registration_results(results, filepath='results/', postfix='', interval=1, scale=None, dpi=100, + figsize=(20, 20), show_restriction_boundary=True, save_animations=False): """Saves some plots of the results from registration via geodesic shooting. Parameters @@ -179,6 +196,10 @@ def save_plots_registration_results(results, filepath='results/', postfix='', in String to add to the title of all plots. interval Interval in which to sample. + dpi + The resolution in dots per inch. + See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html + for more details. figsize Width and height of the figures in inches. show_restriction_boundary @@ -186,71 +207,149 @@ def save_plots_registration_results(results, filepath='results/', postfix='', in save_animations Determines whether to also save animations. """ + diffeomorphism = results['flow'] + diffeomorphism.set_inverse(results['vector_fields'].integrate_backward()) + inverse_transformed_registration_result = results['transformed_input'].push_forward(diffeomorphism.inverse) + diff_inv_reg_res = results['input'] - inverse_transformed_registration_result + inverse_transformed_target = results['target'].push_forward(diffeomorphism.inverse) + diff_inv_tar = results['input'] - inverse_transformed_target + diff = results['target'] - results['transformed_input'] + + rest = results['restriction'] + + _, singular_values = pod(results['vector_fields'], return_singular_values='all') + if not os.path.exists(filepath): os.makedirs(filepath) - results['input'].save(filepath + 'input.png', title='Input' + postfix, figsize=figsize, - show_restriction_boundary=show_restriction_boundary, restriction=results['restriction']) - results['target'].save(filepath + 'target.png', title='Target' + postfix, figsize=figsize, - show_restriction_boundary=show_restriction_boundary, restriction=results['restriction']) - results['transformed_input'].save(filepath + 'transformed_input.png', title='Result' + postfix, figsize=figsize, - show_restriction_boundary=show_restriction_boundary, - restriction=results['restriction']) + # Write some results to a text file: + with open(filepath + 'results_file.txt', 'w') as f: + f.write("Relative norm of difference between target and transformed input: " + f"{diff.get_norm(restriction=rest) / results['target'].get_norm(restriction=rest):.5e}\n") + tab = "\t" + f.write("Singular values of time-evolution of vector fields: \n\t" + f"{tab.join(f'{s:.5e}' for s in singular_values)}\n") + f.write("Relative norm of difference between inverse transformed registration result and input: " + f"{diff_inv_reg_res.get_norm(restriction=rest) / results['input'].get_norm(restriction=rest):.5e}\n") + f.write("Relative norm of difference between inverse transformed target and input: " + f"{diff_inv_tar.get_norm(restriction=rest) / results['input'].get_norm(restriction=rest):.5e}\n") + f.write("Energies:\n") + f.write(f"\tEnergy regularizer: {results['energy_regularizer']:.5e}\n") + f.write(f"\tEnergy intensity (unscaled): {results['energy_intensity_unscaled']:.5e}\n") + f.write(f"\tEnergy intensity (scaled): {results['energy_intensity']:.5e}\n") + f.write(f"\tFull energy: {results['energy']:.5e}") + + # Save a couple of plots: + if results['input'].dim == 3: + fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, subplot_kw={'projection': '3d'}) + else: + fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4) + ax1, vals1 = results['input'].plot("Input", axis=ax1, + show_restriction_boundary=show_restriction_boundary, restriction=rest) + if not isinstance(vals1, list): + fig.colorbar(vals1, ax=ax1, fraction=0.046, pad=0.04) + ax2, vals2 = results['target'].plot("Target", axis=ax2, + show_restriction_boundary=show_restriction_boundary, restriction=rest) + if not isinstance(vals2, list): + fig.colorbar(vals2, ax=ax2, fraction=0.046, pad=0.04) + ax3, vals3 = results['transformed_input'].plot("Result", axis=ax3, + show_restriction_boundary=show_restriction_boundary, + restriction=rest) + if not isinstance(vals3, list): + fig.colorbar(vals3, ax=ax3, fraction=0.046, pad=0.04) diff = results['target'] - results['transformed_input'] - diff.save(filepath + 'difference.png', title='Difference of target and result' + postfix, figsize=figsize, - show_restriction_boundary=show_restriction_boundary, restriction=results['restriction']) - _, singular_values = pod(results['vector_fields'], return_singular_values='all') + ax4, vals4 = diff.plot("Difference of target and result", axis=ax4, + show_restriction_boundary=show_restriction_boundary, restriction=rest) + if not isinstance(vals4, list): + fig.colorbar(vals4, ax=ax4, fraction=0.046, pad=0.04) + fig.savefig(filepath + 'overview_images_results.png', dpi=dpi) + plt.close(fig) + + results['input'].save(filepath + 'input.png', title='Input' + postfix, figsize=figsize, dpi=dpi, + show_restriction_boundary=show_restriction_boundary, restriction=rest) + results['target'].save(filepath + 'target.png', title='Target' + postfix, figsize=figsize, dpi=dpi, + show_restriction_boundary=show_restriction_boundary, restriction=rest) + results['transformed_input'].save(filepath + 'transformed_input.png', title='Result' + postfix, figsize=figsize, + dpi=dpi, show_restriction_boundary=show_restriction_boundary, restriction=rest) + diff.save(filepath + 'difference.png', title='Difference of target and result' + postfix, figsize=figsize, dpi=dpi, + show_restriction_boundary=show_restriction_boundary, restriction=rest) + + plt.figure(figsize=figsize) plt.semilogy(singular_values) plt.title('Singular values of time-evolution of the vector field') - plt.savefig(filepath + 'singular_values_time_evolution_vector_field.png') + plt.savefig(filepath + 'singular_values_time_evolution_vector_field.png', dpi=dpi, bbox_inches='tight') plt.close() - results['initial_vector_field'].save(filepath + 'initial_vector_field.png', plot_type='default', + + results['initial_vector_field'].save(filepath + 'initial_vector_field.png', dpi=dpi, plot_type='default', plot_args={'title': 'Initial vector field' + postfix, 'interval': interval, - 'color_length': True, 'scale': None, 'figsize': figsize}) + 'color_length': True, 'scale': scale, 'figsize': figsize}) results['initial_vector_field'].save_vtk(filepath + 'initial_vector_field_vtk') results['initial_vector_field'].save(filepath + 'initial_vector_field_streamlines.png', plot_type='streamlines', plot_args={'title': 'Initial vector field' + postfix, 'interval': interval, - 'color_length': True, 'scale': None, 'figsize': figsize, + 'color_length': True, 'scale': scale, 'figsize': figsize, 'density': 2}) results['initial_vector_field'].get_magnitude().save(filepath + 'initial_vector_field_magnitude.png', title='Magnitude of initial vector field' + postfix, - figsize=figsize) + figsize=figsize, dpi=dpi) results['initial_vector_field'].get_angle().save(filepath + 'initial_vector_field_angle.png', - title='Angle of initial vector field' + postfix, figsize=figsize) + title='Angle of initial vector field' + postfix, figsize=figsize, + dpi=dpi) for d in range(results['initial_vector_field'].dim): comp = results['initial_vector_field'].get_component_as_function(d) comp.save(filepath + f'initial_vector_field_component_{d}.png', - title='Initial vector field component ' + str(d) + postfix, figsize=figsize) - diffeomorphism = results['flow'] + title='Initial vector field component ' + str(d) + postfix, figsize=figsize, dpi=dpi) + + fig, (ax1, ax2, ax3) = plt.subplots(1, 3) + ax1, vals1 = results['initial_vector_field'].plot("Initial vector field", axis=ax1, interval=interval, scale=scale, + color_length=True) + if not isinstance(vals1, list): + fig.colorbar(vals1, ax=ax1, fraction=0.046, pad=0.04) + ax2, vals2 = results['vector_fields'][-1].plot("Final vector field", axis=ax2, interval=interval, scale=scale, + color_length=True) + if not isinstance(vals2, list): + fig.colorbar(vals2, ax=ax2, fraction=0.046, pad=0.04) + ax3, vals3 = (results['initial_vector_field'] - results['vector_fields'][-1]).plot("Difference", axis=ax3, + interval=interval, + scale=scale, + color_length=True) + if not isinstance(vals3, list): + fig.colorbar(vals3, ax=ax3, fraction=0.046, pad=0.04) + fig.savefig(filepath + 'overview_vector_field_results.png', dpi=dpi) + plt.close(fig) + diffeomorphism.save(filepath + 'diffeomorphism.png', title='Diffeomorphism' + postfix, figsize=figsize, - interval=interval) - diffeomorphism.set_inverse(results['vector_fields'].integrate_backward()) + interval=interval, dpi=dpi) diffeomorphism.inverse.save(filepath + 'inverse_diffeomorphism.png', title='Inverse diffeomorphism' + postfix, - figsize=figsize, interval=interval) + figsize=figsize, interval=interval, dpi=dpi) - inverse_transformed_registration_result = results['transformed_input'].push_forward(diffeomorphism.inverse) inverse_transformed_registration_result.save(filepath + 'inverse_transformed_registration_result.png', title='Inverse transformed registration result' + postfix, - figsize=figsize, show_restriction_boundary=show_restriction_boundary, - restriction=results['restriction']) - diff = results['input'] - inverse_transformed_registration_result - diff.save(filepath + 'diff_input_inverse_transformed_registration_result.png', - title='Difference between input and inverse transformed registration result' + postfix, - figsize=figsize, show_restriction_boundary=show_restriction_boundary, restriction=results['restriction']) + figsize=figsize, dpi=dpi, + show_restriction_boundary=show_restriction_boundary, restriction=rest) + diff_inv_reg_res.save(filepath + 'diff_input_inverse_transformed_registration_result.png', + title='Difference between input and inverse transformed registration result' + postfix, + figsize=figsize, dpi=dpi, show_restriction_boundary=show_restriction_boundary, + restriction=rest) inverse_transformed_target = results['target'].push_forward(diffeomorphism.inverse) inverse_transformed_target.save(filepath + 'inverse_transformed_target.png', - title='Inverse transformed target' + postfix, figsize=figsize, - show_restriction_boundary=show_restriction_boundary, - restriction=results['restriction']) - diff = results['input'] - inverse_transformed_target - diff.save(filepath + 'diff_input_inverse_transformed_target.png', - title='Difference between input and inverse transformed target' + postfix, figsize=figsize, - show_restriction_boundary=show_restriction_boundary, restriction=results['restriction']) - - time_dependent_diffeomorphism = results['vector_fields'].integrate(get_time_dependent_diffeomorphism=True) - assert time_dependent_diffeomorphism[-1] == diffeomorphism + title='Inverse transformed target' + postfix, figsize=figsize, dpi=dpi, + show_restriction_boundary=show_restriction_boundary, restriction=rest) + diff_inv_tar.save(filepath + 'diff_input_inverse_transformed_target.png', + title='Difference between input and inverse transformed target' + postfix, figsize=figsize, + dpi=dpi, show_restriction_boundary=show_restriction_boundary, restriction=rest) if save_animations: + ani = results['vector_fields'].animate("Time-evolution of the vector field", interval=interval, scale=scale, + figsize=figsize) + try: + ani.save(filepath + 'animation_time_evolution_vector_field.gif', writer='imagemagick', + fps=max(1, len(results['vector_fields']) // 10)) + except Exception as e: + print(f"Could not save animation! Error: {e}") + plt.close() + + time_dependent_diffeomorphism = results['vector_fields'].integrate(get_time_dependent_diffeomorphism=True) + assert time_dependent_diffeomorphism[-1] == diffeomorphism ani = time_dependent_diffeomorphism.animate('Animation of the time-evolution of the diffeomorphism' + postfix, figsize=figsize, interval=interval) try: @@ -265,7 +364,7 @@ def save_plots_registration_results(results, filepath='results/', postfix='', in + postfix, figsize=figsize, interval=interval, show_restriction_boundary=show_restriction_boundary, - restriction=results['restriction']) + restriction=rest) try: ani.save(filepath + 'animation_transformation_input.gif', writer='imagemagick', fps=max(1, len(time_dependent_diffeomorphism) // 10)) From 688815a0db6779e44ddfd977a7a8b274e999bf88 Mon Sep 17 00:00:00 2001 From: Hendrik Kleikamp Date: Thu, 4 May 2023 16:00:58 +0200 Subject: [PATCH 12/14] fixed linting error and adjusted tolerance in test --- tests/test_regularizer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_regularizer.py b/tests/test_regularizer.py index b5e1e2b..19344e9 100644 --- a/tests/test_regularizer.py +++ b/tests/test_regularizer.py @@ -23,10 +23,11 @@ def test_regularizer_inverse(): image = make_circle((64, 64), np.array([32, 32]), 10) vector_field = image.grad + tolerance = 1e-6 assert (regularizer.cauchy_navier_inverse(regularizer.cauchy_navier(vector_field)) - - regularizer.cauchy_navier(regularizer.cauchy_navier_inverse(vector_field))).norm < 1e-7 - assert (regularizer.cauchy_navier_inverse(regularizer.cauchy_navier(vector_field)) - vector_field).norm < 1e-7 - assert (regularizer.cauchy_navier(regularizer.cauchy_navier_inverse(vector_field)) - vector_field).norm < 1e-7 + - regularizer.cauchy_navier(regularizer.cauchy_navier_inverse(vector_field))).norm < tolerance + assert (regularizer.cauchy_navier_inverse(regularizer.cauchy_navier(vector_field)) - vector_field).norm < tolerance + assert (regularizer.cauchy_navier(regularizer.cauchy_navier_inverse(vector_field)) - vector_field).norm < tolerance def test_regularizer_in_norm(): From d63f0d23af080bd2f359bc2ff7953b4bab1ab3b7 Mon Sep 17 00:00:00 2001 From: Hendrik Kleikamp Date: Thu, 4 May 2023 16:36:30 +0200 Subject: [PATCH 13/14] fixed saving of plots with multiple subfigures in summary --- geodesic_shooting/utils/summary.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/geodesic_shooting/utils/summary.py b/geodesic_shooting/utils/summary.py index 73f9ada..e18ba8d 100644 --- a/geodesic_shooting/utils/summary.py +++ b/geodesic_shooting/utils/summary.py @@ -57,6 +57,7 @@ def plot_registration_results(results, interval=1, frequency=1, scale=None, figs fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, subplot_kw={'projection': '3d'}) else: fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4) + fig.set_size_inches(40, 10) ax1, vals1 = results['input'].plot("Input", axis=ax1, show_restriction_boundary=show_restriction_boundary, restriction=rest) if not isinstance(vals1, list): @@ -74,6 +75,7 @@ def plot_registration_results(results, interval=1, frequency=1, scale=None, figs show_restriction_boundary=show_restriction_boundary, restriction=rest) if not isinstance(vals4, list): fig.colorbar(vals4, ax=ax4, fraction=0.046, pad=0.04) + plt.tight_layout() plt.show() _ = results['vector_fields'].animate("Time-evolution of the vector field", interval=interval, scale=scale, @@ -100,6 +102,7 @@ def plot_registration_results(results, interval=1, frequency=1, scale=None, figs fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize, subplot_kw={'projection': '3d'}) else: fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) + fig.set_size_inches(30, 10) ax1, vals1 = results['initial_vector_field'].plot("Initial vector field", axis=ax1, interval=interval, scale=scale, color_length=True) if not isinstance(vals1, list): @@ -112,6 +115,7 @@ def plot_registration_results(results, interval=1, frequency=1, scale=None, figs axis=ax3, figsize=figsize) if not isinstance(vals3, list): fig.colorbar(vals3, ax=ax3, fraction=0.046, pad=0.04) + plt.tight_layout() plt.show() if results['initial_vector_field'].dim == 2: @@ -126,6 +130,7 @@ def plot_registration_results(results, interval=1, frequency=1, scale=None, figs plt.show() fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) + fig.set_size_inches(30, 10) ax1, vals1 = results['initial_vector_field'].plot("Initial vector field", axis=ax1, interval=interval, scale=scale, color_length=True, figsize=figsize) if not isinstance(vals1, list): @@ -141,6 +146,7 @@ def plot_registration_results(results, interval=1, frequency=1, scale=None, figs figsize=figsize) if not isinstance(vals3, list): fig.colorbar(vals3, ax=ax3, fraction=0.046, pad=0.04) + plt.tight_layout() plt.show() time_dependent_diffeomorphism = results['vector_fields'].integrate(get_time_dependent_diffeomorphism=True) @@ -244,6 +250,7 @@ def save_plots_registration_results(results, filepath='results/', postfix='', in fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, subplot_kw={'projection': '3d'}) else: fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4) + fig.set_size_inches(40, 10) ax1, vals1 = results['input'].plot("Input", axis=ax1, show_restriction_boundary=show_restriction_boundary, restriction=rest) if not isinstance(vals1, list): @@ -262,7 +269,8 @@ def save_plots_registration_results(results, filepath='results/', postfix='', in show_restriction_boundary=show_restriction_boundary, restriction=rest) if not isinstance(vals4, list): fig.colorbar(vals4, ax=ax4, fraction=0.046, pad=0.04) - fig.savefig(filepath + 'overview_images_results.png', dpi=dpi) + plt.tight_layout() + fig.savefig(filepath + 'overview_images_results.png', dpi=dpi, bbox_inches='tight') plt.close(fig) results['input'].save(filepath + 'input.png', title='Input' + postfix, figsize=figsize, dpi=dpi, @@ -300,6 +308,23 @@ def save_plots_registration_results(results, filepath='results/', postfix='', in title='Initial vector field component ' + str(d) + postfix, figsize=figsize, dpi=dpi) fig, (ax1, ax2, ax3) = plt.subplots(1, 3) + fig.set_size_inches(30, 10) + ax1, vals1 = results['initial_vector_field'].plot("Initial vector field", axis=ax1, interval=interval, + scale=scale, color_length=True) + if not isinstance(vals1, list): + fig.colorbar(vals1, ax=ax1, fraction=0.046, pad=0.04) + ax2, vals2 = results['initial_vector_field'].get_magnitude().plot("Magnitude of initial vector field", axis=ax2) + if not isinstance(vals2, list): + fig.colorbar(vals2, ax=ax2, fraction=0.046, pad=0.04) + ax3, vals3 = results['initial_vector_field'].get_angle().plot("Angle of initial vector field", axis=ax3) + if not isinstance(vals3, list): + fig.colorbar(vals3, ax=ax3, fraction=0.046, pad=0.04) + plt.tight_layout() + fig.savefig(filepath + 'overview_initial_vector_field_results.png', dpi=dpi, bbox_inches='tight') + plt.close(fig) + + fig, (ax1, ax2, ax3) = plt.subplots(1, 3) + fig.set_size_inches(30, 10) ax1, vals1 = results['initial_vector_field'].plot("Initial vector field", axis=ax1, interval=interval, scale=scale, color_length=True) if not isinstance(vals1, list): @@ -314,7 +339,8 @@ def save_plots_registration_results(results, filepath='results/', postfix='', in color_length=True) if not isinstance(vals3, list): fig.colorbar(vals3, ax=ax3, fraction=0.046, pad=0.04) - fig.savefig(filepath + 'overview_vector_field_results.png', dpi=dpi) + plt.tight_layout() + fig.savefig(filepath + 'overview_vector_field_results.png', dpi=dpi, bbox_inches='tight') plt.close(fig) diffeomorphism.save(filepath + 'diffeomorphism.png', title='Diffeomorphism' + postfix, figsize=figsize, From 8cde6c8c86f8f33f99890fb849fe3ec918458d9b Mon Sep 17 00:00:00 2001 From: Hendrik Kleikamp Date: Mon, 24 Jul 2023 17:04:35 +0200 Subject: [PATCH 14/14] fixed and improved landmark matching Implemented Newton's method on transversality condition and fixed some bugs. Added test for commutativity in computation of differential. --- examples/landmark_shooting/2d_example.py | 23 ++-- geodesic_shooting/landmark_shooting.py | 135 +++++++++++++++-------- tests/test_landmark_matching.py | 14 +++ 3 files changed, 114 insertions(+), 58 deletions(-) diff --git a/examples/landmark_shooting/2d_example.py b/examples/landmark_shooting/2d_example.py index 6b7324b..a5765c1 100644 --- a/examples/landmark_shooting/2d_example.py +++ b/examples/landmark_shooting/2d_example.py @@ -12,33 +12,38 @@ if __name__ == "__main__": # define landmark positions - input_landmarks = np.array([[5., 3.], [4., 2.], [1., 0.], [2., 3.]]) - target_landmarks = np.array([[6., 2.], [5., 1.], [1., -1.], [2.5, 2.]]) + input_landmarks = np.array([[5./7., 5./7.], [4./7., 4./7.], [1./7., 2./7.], [2./7., 5./7.]]) + target_landmarks = np.array([[6./7., 4./7.], [5./7., 3./7.], [1./7., 1./7.], [2.5/7., 4./7.]]) # perform the registration using landmark shooting algorithm - gs = geodesic_shooting.LandmarkShooting(kwargs_kernel={'sigma': 2.}) - result = gs.register(input_landmarks, target_landmarks, sigma=0.1, return_all=True, landmarks_labeled=True) + gs = geodesic_shooting.LandmarkShooting(kwargs_kernel={'sigma': 0.25}) + result = gs.register(input_landmarks, target_landmarks, optimization_method='newton', + sigma=0.1, return_all=True, landmarks_labeled=True) final_momenta = result['initial_momenta'] registered_landmarks = result['registered_landmarks'] + vf = gs.get_vector_field(final_momenta, result["input_landmarks"]) + vf.plot("Vector field at initial time", color_length=True) + vf.get_magnitude().plot("Magnitude of vector field at initial time") + # plot results plot_landmark_matchings(input_landmarks, target_landmarks, registered_landmarks) plot_initial_momenta_and_landmarks(final_momenta.flatten(), registered_landmarks.flatten(), - min_x=0., max_x=7., min_y=-2., max_y=4.) + min_x=0., max_x=1., min_y=0., max_y=1.) time_evolution_momenta = result['time_evolution_momenta'] time_evolution_positions = result['time_evolution_positions'] plot_landmark_trajectories(time_evolution_momenta, time_evolution_positions, - min_x=0., max_x=7., min_y=-2., max_y=4.) + min_x=0., max_x=1., min_y=0., max_y=1.) ani = animate_landmark_trajectories(time_evolution_momenta, time_evolution_positions, - min_x=0., max_x=7., min_y=-2., max_y=4.) + min_x=0., max_x=1., min_y=0., max_y=1.) nx = 70 ny = 60 - mins = np.array([0., -2.]) - maxs = np.array([7., 5.]) + mins = np.array([0., 0.]) + maxs = np.array([1., 1.]) spatial_shape = (nx, ny) flow = gs.compute_time_evolution_of_diffeomorphisms(final_momenta, input_landmarks, mins=mins, maxs=maxs, spatial_shape=spatial_shape) diff --git a/geodesic_shooting/landmark_shooting.py b/geodesic_shooting/landmark_shooting.py index 0ef85f0..0f61b28 100644 --- a/geodesic_shooting/landmark_shooting.py +++ b/geodesic_shooting/landmark_shooting.py @@ -20,7 +20,7 @@ class LandmarkShooting: Allassonnière, Trouvé, Younes, 2005 """ def __init__(self, kernel=GaussianKernel, kwargs_kernel={}, dim=2, num_landmarks=1, - time_integrator=RK4, time_steps=30, sampler_options={'order': 1, 'mode': 'edge'}, + time_integrator=RK4, time_steps=100, sampler_options={'order': 1, 'mode': 'edge'}, log_level='INFO'): """Constructor. @@ -65,7 +65,7 @@ def __str__(self): def register(self, input_landmarks, target_landmarks, landmarks_labeled=True, kernel_dist=GaussianKernel, kwargs_kernel_dist={}, - sigma=1., optimization_method='L-BFGS-B', optimizer_options={'disp': True}, + sigma=1., optimization_method='L-BFGS-B', optimizer_options={'disp': True, 'maxiter': 1000}, initial_momenta=None, return_all=False): """Performs actual registration according to geodesic shooting algorithm for landmarks using a Hamiltonian setting. @@ -133,7 +133,7 @@ def compute_matching_function(positions): return np.linalg.norm(positions - target_landmarks.flatten())**2 def compute_gradient_matching_function(positions): - return 2. * (positions - target_landmarks.flatten()) / sigma**2 + return 2. * (positions - target_landmarks.flatten()) else: kernel_dist = kernel_dist(**kwargs_kernel_dist, scalar=True) @@ -182,17 +182,86 @@ def energy_and_gradient(initial_momenta): d_positions_1, _ = self.integrate_forward_variational_Hamiltonian(momenta_time_dependent, positions_time_dependent) - positions = positions_time_dependent[-1] grad_g = compute_gradient_matching_function(positions) - grad = self.K(initial_positions) @ initial_momenta + d_positions_1.T @ grad_g / sigma**2 + grad = self.K(initial_positions) @ initial_momenta + grad_g @ d_positions_1 / sigma**2 return energy, grad.flatten() - # use scipy optimizer for minimizing energy function - with self.logger.block("Perform landmark matching via geodesic shooting ..."): - res = optimize.minimize(energy_and_gradient, initial_momenta.flatten(), - method=optimization_method, jac=True, options=optimizer_options, - callback=save_current_state) + if optimization_method == 'newton' and landmarks_labeled: + # use Newton's method for minimizing energy function + def newton(x0, update_norm_tol=1e-5, rel_func_update_tol=1e-6, maxiter=50, disp=True, callback=None): + assert update_norm_tol >= 0 and rel_func_update_tol >= 0 + assert isinstance(maxiter, int) and maxiter > 0 + + def compute_update_direction(x): + momenta_time_dependent, positions_time_dependent = self.integrate_forward_Hamiltonian(x, + initial_positions) + momenta = momenta_time_dependent[-1] + positions = positions_time_dependent[-1] + d_positions_1, d_momenta_1 = self.integrate_forward_variational_Hamiltonian(momenta_time_dependent, + positions_time_dependent) + mat = d_momenta_1 + 2 * np.eye(self.size) @ d_positions_1 / sigma ** 2 + _, grad = energy_and_gradient(x) + update = np.linalg.solve(mat, momenta + (positions - target_landmarks.flatten()) / sigma ** 2) + return update + + message = '' + with self.logger.block('Starting optimization using Newton Algorithm ...'): + x = x0.flatten() + func_x, _ = energy_and_gradient(x) + old_func_x = func_x + rel_func_update = rel_func_update_tol + 1 + update = compute_update_direction(x) + norm_update = np.linalg.norm(update) + i = 0 + if disp: + self.logger.info(f'iter: {i:5d}\tf= {func_x:.5e}\t|update|= {norm_update:.5e}\t' + f'rel.func.upd.= {rel_func_update:.5e}') + try: + while True: + if callback is not None: + callback(np.copy(x)) + if norm_update <= update_norm_tol: + message = 'norm of update below tolerance' + break + elif rel_func_update <= rel_func_update_tol: + message = 'relative function value update below tolerance' + break + elif i >= maxiter: + message = 'maximum number of iterations reached' + break + + update = compute_update_direction(x) + x = x - update + + func_x, _ = energy_and_gradient(x) + if not np.isclose(old_func_x, 0.): + rel_func_update = abs((func_x - old_func_x) / old_func_x) + else: + rel_func_update = 0. + old_func_x = func_x + norm_update = np.linalg.norm(update) + i += 1 + if disp: + self.logger.info(f'iter: {i:5d}\tf= {func_x:.5e}\t|update|= {norm_update:.5e}\t' + f'rel.func.upd.= {rel_func_update:.5e}') + except KeyboardInterrupt: + message = 'optimization stopped due to keyboard interrupt' + self.logger.warning('Optimization interrupted ...') + + self.logger.info('Finished optimization ...') + result = {'x': x, 'nit': i, 'message': message} + return result + + res = newton(initial_momenta, callback=save_current_state, **optimizer_options) + elif optimization_method == 'newton' and not landmarks_labeled: + raise NotImplementedError + else: + # use scipy optimizer for minimizing energy function + with self.logger.block("Perform landmark matching via geodesic shooting ..."): + res = optimize.minimize(energy_and_gradient, initial_momenta.flatten(), + method=optimization_method, jac=True, options=optimizer_options, + callback=save_current_state) opt['initial_momenta'] = res['x'].reshape(input_landmarks.shape) momenta_time_dependent, positions_time_dependent = self.integrate_forward_Hamiltonian(res['x'], @@ -353,8 +422,8 @@ def rhs_d_positions_function(d_position, position, d_momentum, momentum): ti_d_positions = self.time_integrator(rhs_d_positions_function, self.dt) - def rhs_d_momenta_function(d_momentum, position): - return - (d_momentum @ self.DK(position) @ position + position @ self.DK(position) @ d_momentum) + def rhs_d_momenta_function(d_momentum, momentum, position): + return - 0.5 * (d_momentum @ self.DK(position) @ momentum + momentum @ self.DK(position) @ d_momentum) ti_d_momenta = self.time_integrator(rhs_d_momenta_function, self.dt) @@ -362,7 +431,8 @@ def rhs_d_momenta_function(d_momentum, position): d_positions[t+1] = ti_d_positions.step(d_positions[t], additional_args={'position': positions[t], 'd_momentum': d_momenta[t], 'momentum': momenta[t]}) - d_momenta[t+1] = ti_d_momenta.step(d_momenta[t], additional_args={'position': positions[t]}) + d_momenta[t+1] = ti_d_momenta.step(d_momenta[t], additional_args={'momentum': momenta[t], + 'position': positions[t]}) return d_positions[-1], d_momenta[-1] @@ -396,8 +466,8 @@ def get_vector_field(self, momenta, positions, kernel=self.kernel) for pos in np.ndindex(spatial_shape): - spatial_pos = mins + (maxs - mins) / np.array(spatial_shape) * np.array(pos) - vector_field[pos] = vf_func(spatial_pos) * np.array(spatial_shape) + spatial_pos = mins + (maxs - mins) * np.array(pos) / np.array(spatial_shape) + vector_field[pos] = vf_func(spatial_pos) return vector_field @@ -435,40 +505,7 @@ def compute_time_evolution_of_diffeomorphisms(self, initial_momenta, initial_pos for t, (m, p) in enumerate(zip(momenta, positions)): vector_fields[t] = self.get_vector_field(m, p, mins, maxs, spatial_shape) - flow = self.integrate_forward_flow(vector_fields) - - return flow - - def integrate_forward_flow(self, vector_fields): - """Computes forward integration according to given vector fields. - - Parameters - ---------- - vector_fields - `TimeDependentVectorField` containing the sequence of vector fields to integrate - in time. - - Returns - ------- - `VectorField` containing the flow at the final time. - """ - assert isinstance(vector_fields, TimeDependentVectorField) - assert vector_fields.time_steps == self.time_steps - spatial_shape = vector_fields[0].spatial_shape - # make identity grid - identity_grid = grid.coordinate_grid(spatial_shape) - - # initial flow is the identity mapping - flow = identity_grid.copy() - - def rhs_function(x, v): - return - sampler.sample(v, x, sampler_options=self.sampler_options) - - ti = self.time_integrator(rhs_function, self.dt) - - # perform forward integration - for v in vector_fields: - flow = ti.step(flow, additional_args={'v': v}) + flow = vector_fields.integrate_backward(sampler_options=self.sampler_options) return flow diff --git a/tests/test_landmark_matching.py b/tests/test_landmark_matching.py index a034863..877ac95 100644 --- a/tests/test_landmark_matching.py +++ b/tests/test_landmark_matching.py @@ -3,6 +3,20 @@ import geodesic_shooting +def test_differential_computation(): + input_landmarks = np.array([[5., 3.], [4., 2.], [1., 0.], [2., 3.]]) + target_landmarks = np.array([[6., 2.], [5., 1.], [1., -1.], [2.5, 2.]]) + + gs = geodesic_shooting.LandmarkShooting() + + w = np.zeros(gs.dim) + for i in range(gs.dim): + ei = np.zeros(gs.dim) + ei[i] = 1. + w[i] = ((gs.DK(position) @ ei) @ momentum).dot(momentum) + assert np.allclose(w, (momentum.T @ gs.DK(position) @ momentum)) + + def test_landmark_matching(): def compute_average_distance(target_landmarks, registered_landmarks): dist = 0.