Skip to content

Commit

Permalink
fixed and improved landmark matching
Browse files Browse the repository at this point in the history
Implemented Newton's method on transversality condition and fixed some
bugs. Added test for commutativity in computation of differential.
  • Loading branch information
HenKlei committed Aug 17, 2023
1 parent d63f0d2 commit 8cde6c8
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 58 deletions.
23 changes: 14 additions & 9 deletions examples/landmark_shooting/2d_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,38 @@

if __name__ == "__main__":
# define landmark positions
input_landmarks = np.array([[5., 3.], [4., 2.], [1., 0.], [2., 3.]])
target_landmarks = np.array([[6., 2.], [5., 1.], [1., -1.], [2.5, 2.]])
input_landmarks = np.array([[5./7., 5./7.], [4./7., 4./7.], [1./7., 2./7.], [2./7., 5./7.]])
target_landmarks = np.array([[6./7., 4./7.], [5./7., 3./7.], [1./7., 1./7.], [2.5/7., 4./7.]])

# perform the registration using landmark shooting algorithm
gs = geodesic_shooting.LandmarkShooting(kwargs_kernel={'sigma': 2.})
result = gs.register(input_landmarks, target_landmarks, sigma=0.1, return_all=True, landmarks_labeled=True)
gs = geodesic_shooting.LandmarkShooting(kwargs_kernel={'sigma': 0.25})
result = gs.register(input_landmarks, target_landmarks, optimization_method='newton',
sigma=0.1, return_all=True, landmarks_labeled=True)
final_momenta = result['initial_momenta']
registered_landmarks = result['registered_landmarks']

vf = gs.get_vector_field(final_momenta, result["input_landmarks"])
vf.plot("Vector field at initial time", color_length=True)
vf.get_magnitude().plot("Magnitude of vector field at initial time")

# plot results
plot_landmark_matchings(input_landmarks, target_landmarks, registered_landmarks)

plot_initial_momenta_and_landmarks(final_momenta.flatten(), registered_landmarks.flatten(),
min_x=0., max_x=7., min_y=-2., max_y=4.)
min_x=0., max_x=1., min_y=0., max_y=1.)

time_evolution_momenta = result['time_evolution_momenta']
time_evolution_positions = result['time_evolution_positions']
plot_landmark_trajectories(time_evolution_momenta, time_evolution_positions,
min_x=0., max_x=7., min_y=-2., max_y=4.)
min_x=0., max_x=1., min_y=0., max_y=1.)

ani = animate_landmark_trajectories(time_evolution_momenta, time_evolution_positions,
min_x=0., max_x=7., min_y=-2., max_y=4.)
min_x=0., max_x=1., min_y=0., max_y=1.)

nx = 70
ny = 60
mins = np.array([0., -2.])
maxs = np.array([7., 5.])
mins = np.array([0., 0.])
maxs = np.array([1., 1.])
spatial_shape = (nx, ny)
flow = gs.compute_time_evolution_of_diffeomorphisms(final_momenta, input_landmarks,
mins=mins, maxs=maxs, spatial_shape=spatial_shape)
Expand Down
135 changes: 86 additions & 49 deletions geodesic_shooting/landmark_shooting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class LandmarkShooting:
Allassonnière, Trouvé, Younes, 2005
"""
def __init__(self, kernel=GaussianKernel, kwargs_kernel={}, dim=2, num_landmarks=1,
time_integrator=RK4, time_steps=30, sampler_options={'order': 1, 'mode': 'edge'},
time_integrator=RK4, time_steps=100, sampler_options={'order': 1, 'mode': 'edge'},
log_level='INFO'):
"""Constructor.
Expand Down Expand Up @@ -65,7 +65,7 @@ def __str__(self):

def register(self, input_landmarks, target_landmarks, landmarks_labeled=True,
kernel_dist=GaussianKernel, kwargs_kernel_dist={},
sigma=1., optimization_method='L-BFGS-B', optimizer_options={'disp': True},
sigma=1., optimization_method='L-BFGS-B', optimizer_options={'disp': True, 'maxiter': 1000},
initial_momenta=None, return_all=False):
"""Performs actual registration according to geodesic shooting algorithm for landmarks using
a Hamiltonian setting.
Expand Down Expand Up @@ -133,7 +133,7 @@ def compute_matching_function(positions):
return np.linalg.norm(positions - target_landmarks.flatten())**2

def compute_gradient_matching_function(positions):
return 2. * (positions - target_landmarks.flatten()) / sigma**2
return 2. * (positions - target_landmarks.flatten())
else:
kernel_dist = kernel_dist(**kwargs_kernel_dist, scalar=True)

Expand Down Expand Up @@ -182,17 +182,86 @@ def energy_and_gradient(initial_momenta):

d_positions_1, _ = self.integrate_forward_variational_Hamiltonian(momenta_time_dependent,
positions_time_dependent)
positions = positions_time_dependent[-1]
grad_g = compute_gradient_matching_function(positions)
grad = self.K(initial_positions) @ initial_momenta + d_positions_1.T @ grad_g / sigma**2
grad = self.K(initial_positions) @ initial_momenta + grad_g @ d_positions_1 / sigma**2

return energy, grad.flatten()

# use scipy optimizer for minimizing energy function
with self.logger.block("Perform landmark matching via geodesic shooting ..."):
res = optimize.minimize(energy_and_gradient, initial_momenta.flatten(),
method=optimization_method, jac=True, options=optimizer_options,
callback=save_current_state)
if optimization_method == 'newton' and landmarks_labeled:
# use Newton's method for minimizing energy function
def newton(x0, update_norm_tol=1e-5, rel_func_update_tol=1e-6, maxiter=50, disp=True, callback=None):
assert update_norm_tol >= 0 and rel_func_update_tol >= 0
assert isinstance(maxiter, int) and maxiter > 0

def compute_update_direction(x):
momenta_time_dependent, positions_time_dependent = self.integrate_forward_Hamiltonian(x,
initial_positions)

Check warning on line 198 in geodesic_shooting/landmark_shooting.py

View workflow job for this annotation

GitHub Actions / Lintly/flake8

geodesic_shooting/landmark_shooting.py#L198

E501: line too long (124 > 120 characters)
momenta = momenta_time_dependent[-1]
positions = positions_time_dependent[-1]
d_positions_1, d_momenta_1 = self.integrate_forward_variational_Hamiltonian(momenta_time_dependent,
positions_time_dependent)

Check warning on line 202 in geodesic_shooting/landmark_shooting.py

View workflow job for this annotation

GitHub Actions / Lintly/flake8

geodesic_shooting/landmark_shooting.py#L202

E501: line too long (121 > 120 characters)
mat = d_momenta_1 + 2 * np.eye(self.size) @ d_positions_1 / sigma ** 2
_, grad = energy_and_gradient(x)
update = np.linalg.solve(mat, momenta + (positions - target_landmarks.flatten()) / sigma ** 2)
return update

message = ''
with self.logger.block('Starting optimization using Newton Algorithm ...'):
x = x0.flatten()
func_x, _ = energy_and_gradient(x)
old_func_x = func_x
rel_func_update = rel_func_update_tol + 1
update = compute_update_direction(x)
norm_update = np.linalg.norm(update)
i = 0
if disp:
self.logger.info(f'iter: {i:5d}\tf= {func_x:.5e}\t|update|= {norm_update:.5e}\t'
f'rel.func.upd.= {rel_func_update:.5e}')
try:
while True:
if callback is not None:
callback(np.copy(x))
if norm_update <= update_norm_tol:
message = 'norm of update below tolerance'
break
elif rel_func_update <= rel_func_update_tol:
message = 'relative function value update below tolerance'
break
elif i >= maxiter:
message = 'maximum number of iterations reached'
break

update = compute_update_direction(x)
x = x - update

func_x, _ = energy_and_gradient(x)
if not np.isclose(old_func_x, 0.):
rel_func_update = abs((func_x - old_func_x) / old_func_x)
else:
rel_func_update = 0.
old_func_x = func_x
norm_update = np.linalg.norm(update)
i += 1
if disp:
self.logger.info(f'iter: {i:5d}\tf= {func_x:.5e}\t|update|= {norm_update:.5e}\t'
f'rel.func.upd.= {rel_func_update:.5e}')
except KeyboardInterrupt:
message = 'optimization stopped due to keyboard interrupt'
self.logger.warning('Optimization interrupted ...')

self.logger.info('Finished optimization ...')
result = {'x': x, 'nit': i, 'message': message}
return result

res = newton(initial_momenta, callback=save_current_state, **optimizer_options)
elif optimization_method == 'newton' and not landmarks_labeled:
raise NotImplementedError
else:
# use scipy optimizer for minimizing energy function
with self.logger.block("Perform landmark matching via geodesic shooting ..."):
res = optimize.minimize(energy_and_gradient, initial_momenta.flatten(),
method=optimization_method, jac=True, options=optimizer_options,
callback=save_current_state)

opt['initial_momenta'] = res['x'].reshape(input_landmarks.shape)
momenta_time_dependent, positions_time_dependent = self.integrate_forward_Hamiltonian(res['x'],
Expand Down Expand Up @@ -353,16 +422,17 @@ def rhs_d_positions_function(d_position, position, d_momentum, momentum):

ti_d_positions = self.time_integrator(rhs_d_positions_function, self.dt)

def rhs_d_momenta_function(d_momentum, position):
return - (d_momentum @ self.DK(position) @ position + position @ self.DK(position) @ d_momentum)
def rhs_d_momenta_function(d_momentum, momentum, position):
return - 0.5 * (d_momentum @ self.DK(position) @ momentum + momentum @ self.DK(position) @ d_momentum)

ti_d_momenta = self.time_integrator(rhs_d_momenta_function, self.dt)

for t in range(self.time_steps-1):
d_positions[t+1] = ti_d_positions.step(d_positions[t], additional_args={'position': positions[t],
'd_momentum': d_momenta[t],
'momentum': momenta[t]})
d_momenta[t+1] = ti_d_momenta.step(d_momenta[t], additional_args={'position': positions[t]})
d_momenta[t+1] = ti_d_momenta.step(d_momenta[t], additional_args={'momentum': momenta[t],
'position': positions[t]})

return d_positions[-1], d_momenta[-1]

Expand Down Expand Up @@ -396,8 +466,8 @@ def get_vector_field(self, momenta, positions,
kernel=self.kernel)

for pos in np.ndindex(spatial_shape):
spatial_pos = mins + (maxs - mins) / np.array(spatial_shape) * np.array(pos)
vector_field[pos] = vf_func(spatial_pos) * np.array(spatial_shape)
spatial_pos = mins + (maxs - mins) * np.array(pos) / np.array(spatial_shape)
vector_field[pos] = vf_func(spatial_pos)

return vector_field

Expand Down Expand Up @@ -435,40 +505,7 @@ def compute_time_evolution_of_diffeomorphisms(self, initial_momenta, initial_pos
for t, (m, p) in enumerate(zip(momenta, positions)):
vector_fields[t] = self.get_vector_field(m, p, mins, maxs, spatial_shape)

flow = self.integrate_forward_flow(vector_fields)

return flow

def integrate_forward_flow(self, vector_fields):
"""Computes forward integration according to given vector fields.
Parameters
----------
vector_fields
`TimeDependentVectorField` containing the sequence of vector fields to integrate
in time.
Returns
-------
`VectorField` containing the flow at the final time.
"""
assert isinstance(vector_fields, TimeDependentVectorField)
assert vector_fields.time_steps == self.time_steps
spatial_shape = vector_fields[0].spatial_shape
# make identity grid
identity_grid = grid.coordinate_grid(spatial_shape)

# initial flow is the identity mapping
flow = identity_grid.copy()

def rhs_function(x, v):
return - sampler.sample(v, x, sampler_options=self.sampler_options)

ti = self.time_integrator(rhs_function, self.dt)

# perform forward integration
for v in vector_fields:
flow = ti.step(flow, additional_args={'v': v})
flow = vector_fields.integrate_backward(sampler_options=self.sampler_options)

return flow

Expand Down
14 changes: 14 additions & 0 deletions tests/test_landmark_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,20 @@
import geodesic_shooting


def test_differential_computation():
input_landmarks = np.array([[5., 3.], [4., 2.], [1., 0.], [2., 3.]])

Check warning on line 7 in tests/test_landmark_matching.py

View workflow job for this annotation

GitHub Actions / Lintly/flake8

tests/test_landmark_matching.py#L7

F841: local variable 'input_landmarks' is assigned to but never used
target_landmarks = np.array([[6., 2.], [5., 1.], [1., -1.], [2.5, 2.]])

Check warning on line 8 in tests/test_landmark_matching.py

View workflow job for this annotation

GitHub Actions / Lintly/flake8

tests/test_landmark_matching.py#L8

F841: local variable 'target_landmarks' is assigned to but never used

gs = geodesic_shooting.LandmarkShooting()

w = np.zeros(gs.dim)
for i in range(gs.dim):
ei = np.zeros(gs.dim)
ei[i] = 1.
w[i] = ((gs.DK(position) @ ei) @ momentum).dot(momentum)

Check warning on line 16 in tests/test_landmark_matching.py

View workflow job for this annotation

GitHub Actions / Lintly/flake8

tests/test_landmark_matching.py#L16

F821: undefined name 'position'

Check warning on line 16 in tests/test_landmark_matching.py

View workflow job for this annotation

GitHub Actions / Lintly/flake8

tests/test_landmark_matching.py#L16

F821: undefined name 'momentum'

Check warning on line 16 in tests/test_landmark_matching.py

View workflow job for this annotation

GitHub Actions / Lintly/flake8

tests/test_landmark_matching.py#L16

F821: undefined name 'momentum'
assert np.allclose(w, (momentum.T @ gs.DK(position) @ momentum))

Check warning on line 17 in tests/test_landmark_matching.py

View workflow job for this annotation

GitHub Actions / Lintly/flake8

tests/test_landmark_matching.py#L17

F821: undefined name 'momentum'

Check warning on line 17 in tests/test_landmark_matching.py

View workflow job for this annotation

GitHub Actions / Lintly/flake8

tests/test_landmark_matching.py#L17

F821: undefined name 'position'

Check warning on line 17 in tests/test_landmark_matching.py

View workflow job for this annotation

GitHub Actions / Lintly/flake8

tests/test_landmark_matching.py#L17

F821: undefined name 'momentum'


def test_landmark_matching():
def compute_average_distance(target_landmarks, registered_landmarks):
dist = 0.
Expand Down

0 comments on commit 8cde6c8

Please sign in to comment.