Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gradient computation #81

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions examples/geodesic_shooting/circle_scaling_example.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np

import geodesic_shooting
from geodesic_shooting.core import VectorField, Diffeomorphism, TimeDependentScalarFunction
from geodesic_shooting.utils.create_example_images import make_circle
from geodesic_shooting.utils.summary import plot_registration_results, save_plots_registration_results

Expand All @@ -16,8 +17,23 @@
optimizer_options={'disp': True, 'maxiter': 20})

plot_registration_results(result)
gs.regularizer.cauchy_navier(result['initial_vector_field']).plot(title='Cauchy Navier operator applied '
'to initial vector field')
cn_initial = gs.regularizer.cauchy_navier(result['initial_vector_field'])
cn_initial.plot(title='Cauchy Navier operator applied to initial vector field', color_length=True)
cn_initial.get_magnitude().plot(title='Magnitude of Cauchy Navier operator applied to initial vector field')

magnitude_evo_cn_vector_fields = []
evolution_transformed_template = []
for vf, diffeo in zip(result['vector_fields'].to_numpy(),
result['vector_fields'].integrate(get_time_dependent_diffeomorphism=True).to_numpy()):
magnitude_evo_cn_vector_fields.append(gs.regularizer.cauchy_navier(VectorField(data=vf)).get_magnitude())
evolution_transformed_template.append(template.push_forward(Diffeomorphism(data=diffeo)))

magnitude_evo_cn_vector_fields = TimeDependentScalarFunction(data=magnitude_evo_cn_vector_fields)
ani1 = magnitude_evo_cn_vector_fields.animate(title='Evolution of magnitude of Cauchy Navier operator '
'applied to vector fields')
evolution_transformed_template = TimeDependentScalarFunction(data=evolution_transformed_template)
ani2 = evolution_transformed_template.animate(title='Evolution of transformed template')
import matplotlib.pyplot as plt
plt.show()

save_plots_registration_results(result, filepath='results_circle_scaling/')
5 changes: 3 additions & 2 deletions examples/geodesic_shooting/moving_square_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
template = make_square((64, 64), np.array([32, 32]), 30)

# perform the registration
gs = geodesic_shooting.GeodesicShooting(alpha=0.05, exponent=1)
result = gs.register(template, target, sigma=0.05, return_all=True)
gs = geodesic_shooting.GeodesicShooting(alpha=10., exponent=1, gamma=2.)
result = gs.register(template, target, sigma=1, return_all=True, optimization_method='GD',
optimizer_options={'disp': True, 'maxiter': 20})

plot_registration_results(result, frequency=5)
save_plots_registration_results(result, filepath='results_moving_square/')
5 changes: 3 additions & 2 deletions examples/geodesic_shooting/square_to_circle_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@

# perform the registration
gs = geodesic_shooting.GeodesicShooting(alpha=0.01, exponent=1)
result = gs.register(template, target, sigma=0.01, return_all=True, restriction=restriction)
result = gs.register(template, target, sigma=0.01, return_all=True, restriction=restriction,
optimization_method='GD', optimizer_options={'maxiter': 20})

result['initial_vector_field'].save_tikz('initial_vector_field_square_to_circle.tex',
title="Initial vector field square to circle",
interval=2, scale=100)

plot_registration_results(result, frequency=5)
save_plots_registration_results(result, filepath='results_square_to_circle/')
save_plots_registration_results(result, filepath='results_square_to_circle/', save_animations=True)
23 changes: 14 additions & 9 deletions examples/landmark_shooting/2d_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,38 @@

if __name__ == "__main__":
# define landmark positions
input_landmarks = np.array([[5., 3.], [4., 2.], [1., 0.], [2., 3.]])
target_landmarks = np.array([[6., 2.], [5., 1.], [1., -1.], [2.5, 2.]])
input_landmarks = np.array([[5./7., 5./7.], [4./7., 4./7.], [1./7., 2./7.], [2./7., 5./7.]])
target_landmarks = np.array([[6./7., 4./7.], [5./7., 3./7.], [1./7., 1./7.], [2.5/7., 4./7.]])

# perform the registration using landmark shooting algorithm
gs = geodesic_shooting.LandmarkShooting(kwargs_kernel={'sigma': 2.})
result = gs.register(input_landmarks, target_landmarks, sigma=0.1, return_all=True, landmarks_labeled=True)
gs = geodesic_shooting.LandmarkShooting(kwargs_kernel={'sigma': 0.25})
result = gs.register(input_landmarks, target_landmarks, optimization_method='newton',
sigma=0.1, return_all=True, landmarks_labeled=True)
final_momenta = result['initial_momenta']
registered_landmarks = result['registered_landmarks']

vf = gs.get_vector_field(final_momenta, result["input_landmarks"])
vf.plot("Vector field at initial time", color_length=True)
vf.get_magnitude().plot("Magnitude of vector field at initial time")

# plot results
plot_landmark_matchings(input_landmarks, target_landmarks, registered_landmarks)

plot_initial_momenta_and_landmarks(final_momenta.flatten(), registered_landmarks.flatten(),
min_x=0., max_x=7., min_y=-2., max_y=4.)
min_x=0., max_x=1., min_y=0., max_y=1.)

time_evolution_momenta = result['time_evolution_momenta']
time_evolution_positions = result['time_evolution_positions']
plot_landmark_trajectories(time_evolution_momenta, time_evolution_positions,
min_x=0., max_x=7., min_y=-2., max_y=4.)
min_x=0., max_x=1., min_y=0., max_y=1.)

ani = animate_landmark_trajectories(time_evolution_momenta, time_evolution_positions,
min_x=0., max_x=7., min_y=-2., max_y=4.)
min_x=0., max_x=1., min_y=0., max_y=1.)

nx = 70
ny = 60
mins = np.array([0., -2.])
maxs = np.array([7., 5.])
mins = np.array([0., 0.])
maxs = np.array([1., 1.])
spatial_shape = (nx, ny)
flow = gs.compute_time_evolution_of_diffeomorphisms(final_momenta, input_landmarks,
mins=mins, maxs=maxs, spatial_shape=spatial_shape)
Expand Down
1 change: 1 addition & 0 deletions geodesic_shooting/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def get_norm(self, product_operator=None, order=None, restriction=np.s_[...]):
"""
vol = 1. / tuple_product(self.spatial_shape)
if product_operator:
assert order is None or order == 2
apply_product_operator = product_operator(self).to_numpy()[restriction].flatten()
return np.sqrt(apply_product_operator.dot(self.to_numpy()[restriction].flatten())) * np.sqrt(vol)
else:
Expand Down
2 changes: 1 addition & 1 deletion geodesic_shooting/core/vector_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ def integrate_backward(self, sampler_options={'order': 1, 'mode': 'edge'}, get_t
return TimeDependentDiffeomorphism(data=diffeomorphisms)
return diffeomorphisms[-1]

def animate(self, title="", interval=1, color_length=False, colorbar=True, scale=None, show_axis=True,
def animate(self, title="", interval=1, color_length=True, colorbar=True, scale=None, show_axis=True,
figsize=(10, 10)):
"""Animates the `TimeDependentVectorField` using the `plot`-function of `VectorField`.

Expand Down
9 changes: 6 additions & 3 deletions geodesic_shooting/geodesic_shooting.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,10 @@ def energy_and_gradient(v0, compute_grad=True, return_all_energies=False):

# compute the current energy consisting of intensity difference
# and regularization
energy_regularizer = self.regularizer.helmholtz(v0).get_norm(restriction=restriction)**2
energy_regularizer = v0.get_norm(product_operator=self.regularizer.cauchy_navier,
restriction=restriction)**2
energy_intensity_unscaled = compute_energy(forward_pushed_input)
energy_intensity = 1 / sigma**2 * energy_intensity_unscaled
energy_intensity = energy_intensity_unscaled / sigma**2
energy = energy_regularizer + energy_intensity

if compute_grad:
Expand Down Expand Up @@ -193,6 +194,7 @@ def gradient_descent(func, x0, grad_norm_tol=1e-5, rel_func_update_tol=1e-6, max
def line_search(x, func_x, grad_x, d):
alpha = alpha0
d_dot_grad = d.dot(grad_x)
assert d_dot_grad < 0., 'The direction d is not a descent direction!'
func_x_update = func(x + alpha * d, compute_grad=False)
k = 0
while (not func_x_update <= func_x + c1 * alpha * d_dot_grad) and k < maxiter_armijo:
Expand Down Expand Up @@ -387,7 +389,7 @@ def rhs_function(x):

# perform forward in time integration of initial vector field
for t in range(0, self.time_steps-1):
# perform the explicit Euler integration step
# perform the time integration step using the time integrator
vector_fields[t+1] = ti.step(vector_fields[t])

return vector_fields
Expand Down Expand Up @@ -461,6 +463,7 @@ def rhs_function(x, v):

# perform backward in time integration of the gradient of the energy function
for t in range(self.time_steps-2, -1, -1):
# perform the backward time integration step using the time integrator
delta_v, v_old = ti.step_backwards([delta_v, v_old], {'v': vector_fields[t]})

# return adjoint variable `delta_v` that corresponds to the gradient
Expand Down
135 changes: 86 additions & 49 deletions geodesic_shooting/landmark_shooting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
Allassonnière, Trouvé, Younes, 2005
"""
def __init__(self, kernel=GaussianKernel, kwargs_kernel={}, dim=2, num_landmarks=1,
time_integrator=RK4, time_steps=30, sampler_options={'order': 1, 'mode': 'edge'},
time_integrator=RK4, time_steps=100, sampler_options={'order': 1, 'mode': 'edge'},
log_level='INFO'):
"""Constructor.

Expand Down Expand Up @@ -65,7 +65,7 @@

def register(self, input_landmarks, target_landmarks, landmarks_labeled=True,
kernel_dist=GaussianKernel, kwargs_kernel_dist={},
sigma=1., optimization_method='L-BFGS-B', optimizer_options={'disp': True},
sigma=1., optimization_method='L-BFGS-B', optimizer_options={'disp': True, 'maxiter': 1000},
initial_momenta=None, return_all=False):
"""Performs actual registration according to geodesic shooting algorithm for landmarks using
a Hamiltonian setting.
Expand Down Expand Up @@ -133,7 +133,7 @@
return np.linalg.norm(positions - target_landmarks.flatten())**2

def compute_gradient_matching_function(positions):
return 2. * (positions - target_landmarks.flatten()) / sigma**2
return 2. * (positions - target_landmarks.flatten())
else:
kernel_dist = kernel_dist(**kwargs_kernel_dist, scalar=True)

Expand Down Expand Up @@ -182,17 +182,86 @@

d_positions_1, _ = self.integrate_forward_variational_Hamiltonian(momenta_time_dependent,
positions_time_dependent)
positions = positions_time_dependent[-1]
grad_g = compute_gradient_matching_function(positions)
grad = self.K(initial_positions) @ initial_momenta + d_positions_1.T @ grad_g / sigma**2
grad = self.K(initial_positions) @ initial_momenta + grad_g @ d_positions_1 / sigma**2

return energy, grad.flatten()

# use scipy optimizer for minimizing energy function
with self.logger.block("Perform landmark matching via geodesic shooting ..."):
res = optimize.minimize(energy_and_gradient, initial_momenta.flatten(),
method=optimization_method, jac=True, options=optimizer_options,
callback=save_current_state)
if optimization_method == 'newton' and landmarks_labeled:
# use Newton's method for minimizing energy function
def newton(x0, update_norm_tol=1e-5, rel_func_update_tol=1e-6, maxiter=50, disp=True, callback=None):
assert update_norm_tol >= 0 and rel_func_update_tol >= 0
assert isinstance(maxiter, int) and maxiter > 0

def compute_update_direction(x):
momenta_time_dependent, positions_time_dependent = self.integrate_forward_Hamiltonian(x,
initial_positions)

Check warning on line 198 in geodesic_shooting/landmark_shooting.py

View workflow job for this annotation

GitHub Actions / Lintly/flake8

geodesic_shooting/landmark_shooting.py#L198

E501: line too long (124 > 120 characters)
momenta = momenta_time_dependent[-1]
positions = positions_time_dependent[-1]
d_positions_1, d_momenta_1 = self.integrate_forward_variational_Hamiltonian(momenta_time_dependent,
positions_time_dependent)

Check warning on line 202 in geodesic_shooting/landmark_shooting.py

View workflow job for this annotation

GitHub Actions / Lintly/flake8

geodesic_shooting/landmark_shooting.py#L202

E501: line too long (121 > 120 characters)
mat = d_momenta_1 + 2 * np.eye(self.size) @ d_positions_1 / sigma ** 2
_, grad = energy_and_gradient(x)
update = np.linalg.solve(mat, momenta + (positions - target_landmarks.flatten()) / sigma ** 2)
return update

message = ''
with self.logger.block('Starting optimization using Newton Algorithm ...'):
x = x0.flatten()
func_x, _ = energy_and_gradient(x)
old_func_x = func_x
rel_func_update = rel_func_update_tol + 1
update = compute_update_direction(x)
norm_update = np.linalg.norm(update)
i = 0
if disp:
self.logger.info(f'iter: {i:5d}\tf= {func_x:.5e}\t|update|= {norm_update:.5e}\t'
f'rel.func.upd.= {rel_func_update:.5e}')
try:
while True:
if callback is not None:
callback(np.copy(x))
if norm_update <= update_norm_tol:
message = 'norm of update below tolerance'
break
elif rel_func_update <= rel_func_update_tol:
message = 'relative function value update below tolerance'
break
elif i >= maxiter:
message = 'maximum number of iterations reached'
break

update = compute_update_direction(x)
x = x - update

func_x, _ = energy_and_gradient(x)
if not np.isclose(old_func_x, 0.):
rel_func_update = abs((func_x - old_func_x) / old_func_x)
else:
rel_func_update = 0.
old_func_x = func_x
norm_update = np.linalg.norm(update)
i += 1
if disp:
self.logger.info(f'iter: {i:5d}\tf= {func_x:.5e}\t|update|= {norm_update:.5e}\t'
f'rel.func.upd.= {rel_func_update:.5e}')
except KeyboardInterrupt:
message = 'optimization stopped due to keyboard interrupt'
self.logger.warning('Optimization interrupted ...')

self.logger.info('Finished optimization ...')
result = {'x': x, 'nit': i, 'message': message}
return result

res = newton(initial_momenta, callback=save_current_state, **optimizer_options)
elif optimization_method == 'newton' and not landmarks_labeled:
raise NotImplementedError
else:
# use scipy optimizer for minimizing energy function
with self.logger.block("Perform landmark matching via geodesic shooting ..."):
res = optimize.minimize(energy_and_gradient, initial_momenta.flatten(),
method=optimization_method, jac=True, options=optimizer_options,
callback=save_current_state)

opt['initial_momenta'] = res['x'].reshape(input_landmarks.shape)
momenta_time_dependent, positions_time_dependent = self.integrate_forward_Hamiltonian(res['x'],
Expand Down Expand Up @@ -353,16 +422,17 @@

ti_d_positions = self.time_integrator(rhs_d_positions_function, self.dt)

def rhs_d_momenta_function(d_momentum, position):
return - (d_momentum @ self.DK(position) @ position + position @ self.DK(position) @ d_momentum)
def rhs_d_momenta_function(d_momentum, momentum, position):
return - 0.5 * (d_momentum @ self.DK(position) @ momentum + momentum @ self.DK(position) @ d_momentum)

ti_d_momenta = self.time_integrator(rhs_d_momenta_function, self.dt)

for t in range(self.time_steps-1):
d_positions[t+1] = ti_d_positions.step(d_positions[t], additional_args={'position': positions[t],
'd_momentum': d_momenta[t],
'momentum': momenta[t]})
d_momenta[t+1] = ti_d_momenta.step(d_momenta[t], additional_args={'position': positions[t]})
d_momenta[t+1] = ti_d_momenta.step(d_momenta[t], additional_args={'momentum': momenta[t],
'position': positions[t]})

return d_positions[-1], d_momenta[-1]

Expand Down Expand Up @@ -396,8 +466,8 @@
kernel=self.kernel)

for pos in np.ndindex(spatial_shape):
spatial_pos = mins + (maxs - mins) / np.array(spatial_shape) * np.array(pos)
vector_field[pos] = vf_func(spatial_pos) * np.array(spatial_shape)
spatial_pos = mins + (maxs - mins) * np.array(pos) / np.array(spatial_shape)
vector_field[pos] = vf_func(spatial_pos)

return vector_field

Expand Down Expand Up @@ -435,40 +505,7 @@
for t, (m, p) in enumerate(zip(momenta, positions)):
vector_fields[t] = self.get_vector_field(m, p, mins, maxs, spatial_shape)

flow = self.integrate_forward_flow(vector_fields)

return flow

def integrate_forward_flow(self, vector_fields):
"""Computes forward integration according to given vector fields.

Parameters
----------
vector_fields
`TimeDependentVectorField` containing the sequence of vector fields to integrate
in time.

Returns
-------
`VectorField` containing the flow at the final time.
"""
assert isinstance(vector_fields, TimeDependentVectorField)
assert vector_fields.time_steps == self.time_steps
spatial_shape = vector_fields[0].spatial_shape
# make identity grid
identity_grid = grid.coordinate_grid(spatial_shape)

# initial flow is the identity mapping
flow = identity_grid.copy()

def rhs_function(x, v):
return - sampler.sample(v, x, sampler_options=self.sampler_options)

ti = self.time_integrator(rhs_function, self.dt)

# perform forward integration
for v in vector_fields:
flow = ti.step(flow, additional_args={'v': v})
flow = vector_fields.integrate_backward(sampler_options=self.sampler_options)

return flow

Expand Down
Loading
Loading