diffusion_regularizer.py
from typing import Optional

import torch
import torch.nn as nn

from ..builder import LOSSES


def spatial_gradient(x: torch.Tensor,
                     dim: int,
                     mode: str = 'forward') -> torch.Tensor:
    """Calculate the gradient along a single dimension of a tensor using a
    finite difference. The tensor is shifted along the dimension to compute
    the approximate gradient, either as a central finite difference

        dx[i] = (x[i+1] - x[i-1]) / 2.

    or as a forward finite difference

        dx[i] = x[i+1] - x[i]

    Adapted from:
        Project-MONAI (https://github.com/Project-MONAI/MONAI/blob/dev/monai/losses/deform.py)

    Args:
        x: the shape should be BCH(WD).
        dim: dimension to calculate the gradient along.
        mode: flag deciding whether to use central or forward finite
            difference, one of ['forward', 'central'].

    Returns:
        gradient_dx: the same shape as ``x``, except the size along ``dim``
            shrinks by 1 in 'forward' mode and by 2 in 'central' mode.
    """
    if mode not in ['forward', 'central']:
        raise ValueError(
            f'Unsupported finite difference method: {mode}, available options are ["forward", "central"].'
        )
    slice_all = slice(None)
    slicing_s, slicing_e = [slice_all] * x.ndim, [slice_all] * x.ndim
    if mode == 'central':
        slicing_s[dim] = slice(2, None)
        slicing_e[dim] = slice(None, -2)
        return (x[tuple(slicing_s)] - x[tuple(slicing_e)]) / 2.0
    # mode == 'forward'; already validated above.
    slicing_s[dim] = slice(1, None)
    slicing_e[dim] = slice(None, -1)
    return x[tuple(slicing_s)] - x[tuple(slicing_e)]
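

# A minimal sketch of the two modes (illustrative comment, not part of the
# original module): for a 1-D ramp, the forward difference keeps 4 of 5
# samples while the central difference keeps 3.
#
#     x = torch.arange(5, dtype=torch.float32).reshape(1, 1, 5)  # BCH
#     spatial_gradient(x, dim=2, mode='forward')  # tensor([[[1., 1., 1., 1.]]])
#     spatial_gradient(x, dim=2, mode='central')  # tensor([[[1., 1., 1.]]])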


@LOSSES.register_module('diffusion')
class GradientDiffusionLoss(nn.Module):
    """Calculate the diffusion regularizer (smoothness regularizer) on the
    spatial gradients of a displacement/velocity field."""

    def __init__(self,
                 penalty: str = 'l1',
                 loss_mult: Optional[float] = None) -> None:
        """
        Args:
            penalty (str): flag deciding whether to compute the l1 or l2 norm
                of the diffusion, one of ['l1', 'l2'].
            loss_mult (float, optional): loss multiplier compensating for a
                downsized displacement/velocity field, e.g.
                loss_mult = int_downsize.
        """
        super().__init__()
        self.penalty = penalty
        self.loss_mult = loss_mult

    def forward(self, pred: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: the shape should be BCH(WD).
        """
        if pred.ndim not in [3, 4, 5]:
            raise ValueError(
                f'Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}'
            )
        if any(size <= 4 for size in pred.shape[2:]):
            raise ValueError(
                f'All spatial dimensions must be > 4, got spatial dimensions {pred.shape[2:]}'
            )
        if pred.shape[1] != pred.ndim - 2:
            raise ValueError(
                f'Number of vector components, {pred.shape[1]}, does not match number of spatial dimensions, {pred.ndim - 2}'
            )
        if self.penalty not in ['l1', 'l2']:
            raise ValueError(
                f'Unsupported norm: {self.penalty}, available options are ["l1", "l2"].'
            )
        # TODO: forward mode and central mode give different results; the
        # reason is still unknown. Using forward mode to be consistent with
        # the VoxelMorph paper.
        first_order_gradient = [
            spatial_gradient(pred, dim, mode='forward')
            for dim in range(2, pred.ndim)
        ]
        loss = torch.tensor(0, dtype=torch.float32, device=pred.device)
        for gradient in first_order_gradient:
            if self.penalty == 'l1':
                loss += torch.mean(torch.abs(gradient))
            else:  # 'l2'
                loss += torch.mean(gradient**2)
        if self.loss_mult is not None:
            loss *= self.loss_mult
        return loss / float(pred.ndim - 2)

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += (f'(penalty=\'{self.penalty}\', '
                     f'loss_mult={self.loss_mult})')
        return repr_str
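

# A minimal usage sketch (an assumption about typical use, not part of the
# original module). With the mm-style LOSSES registry the loss would normally
# be built from a config such as dict(type='diffusion'); here it is
# instantiated directly. Because of the relative import above, this demo only
# runs inside the package, e.g. via ``python -m <package>.diffusion_regularizer``.
if __name__ == '__main__':
    torch.manual_seed(0)
    # Random 2-D displacement field: B=1, 2 vector components, 8x8 spatial grid.
    flow = torch.randn(1, 2, 8, 8)
    loss_fn = GradientDiffusionLoss(penalty='l2', loss_mult=2.0)
    print(loss_fn)        # GradientDiffusionLoss(penalty='l2', loss_mult=2.0)
    print(loss_fn(flow))  # scalar diffusion (smoothness) penalty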