forked from lppllppl920/DenseDescriptorLearning-Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
268 lines (219 loc) · 12.2 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
'''
Author: Xingtong Liu, Yiping Zheng, Benjamin Killeen, Masaru Ishii, Gregory D. Hager, Russell H. Taylor, and Mathias Unberath
Copyright (C) 2020 Johns Hopkins University - All Rights Reserved
You may use, distribute and modify this code under the
terms of the GNU GENERAL PUBLIC LICENSE Version 3 license for non-commercial usage.
You should have received a copy of the GNU GENERAL PUBLIC LICENSE Version 3 license with
this file. If not, please write to: [email protected] or [email protected]
'''
import torch.nn as nn
import torch
# Removed dropout and changed the transition up layers in the original implementation
# to mitigate the grid patterns of the network output
class DenseLayer(nn.Sequential):
    """Single DenseNet layer: BatchNorm -> in-place ReLU -> 3x3 convolution.

    Maps ``in_channels`` input channels to ``growth_rate`` new channels while
    keeping the spatial size unchanged (stride 1, padding 1).
    """

    def __init__(self, in_channels, growth_rate):
        super().__init__()
        # The child names ('norm', 'relu', 'conv') become the state_dict keys,
        # so they must stay fixed for previously saved weights to load.
        children = (
            ('norm', nn.BatchNorm2d(in_channels)),
            ('relu', nn.ReLU(inplace=True)),
            ('conv', nn.Conv2d(in_channels, growth_rate, kernel_size=3,
                               stride=1, padding=1, bias=True)),
        )
        for child_name, child in children:
            self.add_module(child_name, child)

    def forward(self, x):
        # Plain nn.Sequential behaviour: apply the children in registration order.
        return super().forward(x)
class DenseBlock(nn.Module):
    """A stack of DenseLayers with dense (concatenative) connectivity.

    Layer ``k`` receives the block input concatenated with the outputs of all
    earlier layers (``in_channels + k * growth_rate`` channels) and adds
    ``growth_rate`` new channels.

    Args:
        in_channels: channels of the block input.
        growth_rate: channels produced by each DenseLayer.
        n_layers: number of DenseLayers in the block.
        upsample: if True, ``forward`` returns only the newly produced
            features (``n_layers * growth_rate`` channels); otherwise it
            returns the input concatenated with them
            (``in_channels + n_layers * growth_rate`` channels).
    """

    def __init__(self, in_channels, growth_rate, n_layers, upsample=False):
        super().__init__()
        self.upsample = upsample
        self.layers = nn.ModuleList(
            DenseLayer(in_channels + k * growth_rate, growth_rate)
            for k in range(n_layers)
        )

    def forward(self, x):
        stacked = x          # running concatenation fed to the next layer
        fresh = []           # outputs of the individual layers only
        for dense_layer in self.layers:
            produced = dense_layer(stacked)
            stacked = torch.cat([stacked, produced], 1)  # dim 1 = channels
            fresh.append(produced)
        if not self.upsample:
            return stacked
        return torch.cat(fresh, 1)
class TransitionDown(nn.Sequential):
    """Encoder transition: BatchNorm -> ReLU -> 1x1 conv -> 2x2 max-pool.

    The 1x1 convolution keeps the channel count (``in_channels`` in and out);
    the max-pool halves the spatial size (``nn.MaxPool2d`` floors odd sizes).
    """

    def __init__(self, in_channels):
        super().__init__()
        # Names are part of the state_dict keys; keep them stable.
        stages = [
            ('norm', nn.BatchNorm2d(num_features=in_channels)),
            ('relu', nn.ReLU(inplace=True)),
            ('conv', nn.Conv2d(in_channels, in_channels, kernel_size=1,
                               stride=1, padding=0, bias=True)),
            ('maxpool', nn.MaxPool2d(2)),
        ]
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)

    def forward(self, x):
        # Apply the stages in registration order.
        return super().forward(x)
class TransitionUp(nn.Module):
    """Decoder transition: upsample ``x`` and concatenate it with a skip connection.

    ``x`` is upsampled x2 with nearest-neighbour interpolation and passed
    through a 3x3 convolution. The result is brought to the spatial size of
    ``skip`` and concatenated with it along the channel dimension.

    Args:
        in_channels: channels of the decoder input ``x``.
        out_channels: channels produced by the 3x3 convolution.

    ``forward(x, skip)`` returns a tensor with ``out_channels + skip.size(1)``
    channels and the spatial size of ``skip``.
    """

    def __init__(self, in_channels, out_channels):
        super(TransitionUp, self).__init__()
        # Sequential layout (0: Upsample, 1: Conv2d) is kept so the
        # 'convTrans.1.*' state_dict keys are unchanged.
        self.convTrans = nn.Sequential(nn.Upsample(mode='nearest', scale_factor=2),
                                       nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))

    @staticmethod
    def _match_size(tensor, height, width):
        """Replicate-pad, then center-crop, ``tensor`` to spatial size (height, width)."""
        pad_h = max(height - tensor.size(2), 0)
        pad_w = max(width - tensor.size(3), 0)
        if pad_h or pad_w:
            # If the encoder map had an odd size, max-pooling floored it and the
            # x2 upsampled map is one pixel short; without padding the crop
            # below would use negative offsets and torch.cat would fail.
            tensor = torch.nn.functional.pad(
                tensor,
                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
                mode='replicate')
        top = (tensor.size(2) - height) // 2
        left = (tensor.size(3) - width) // 2
        return tensor[:, :, top:top + height, left:left + width]

    def forward(self, x, skip):
        out = self.convTrans(x)
        out = self._match_size(out, skip.size(2), skip.size(3))
        return torch.cat([out, skip], 1)
class Bottleneck(nn.Sequential):
    """Network bottleneck: a single DenseBlock run in ``upsample`` mode.

    Only the ``growth_rate * n_layers`` channels produced inside the block are
    returned; the block input is not passed through.
    """

    def __init__(self, in_channels, growth_rate, n_layers):
        super().__init__()
        inner = DenseBlock(in_channels, growth_rate, n_layers, upsample=True)
        # Child name 'bottleneck' is part of the state_dict keys.
        self.add_module('bottleneck', inner)

    def forward(self, x):
        return super().forward(x)
def center_crop_(layer, max_height, max_width):
    """Center-crop the last two dims of a 4-D tensor to at most max_height x max_width.

    A dimension that is already no larger than the requested size is kept
    whole. The start offsets are clamped at zero because a negative start
    index would be read by Python slicing as counting from the end, silently
    returning a wrong, too-small window.

    Args:
        layer: tensor of shape (N, C, H, W).
        max_height: maximum output height.
        max_width: maximum output width.

    Returns:
        A view of ``layer`` with shape
        (N, C, min(H, max_height), min(W, max_width)).
    """
    _, _, height, width = layer.size()
    top = max((height - max_height) // 2, 0)
    left = max((width - max_width) // 2, 0)
    return layer[:, :, top:top + max_height, left:left + max_width]
class FCDenseNet(nn.Module):
    """Fully-convolutional DenseNet encoder/decoder producing a dense descriptor map.

    For an input of shape B x in_channels x H x W, ``forward`` returns
    B x feature_length x H x W. Each per-pixel descriptor is L2-normalised
    along the channel dimension.

    Args:
        in_channels: channels of the input image.
        down_blocks: DenseLayer count of each encoder DenseBlock. Each block is
            followed by a TransitionDown that halves the resolution.
        up_blocks: DenseLayer count of each decoder DenseBlock. Must have the
            same length as ``down_blocks``: decoder stage ``i`` consumes the
            skip connection of encoder stage ``len(down_blocks) - 1 - i``.
        bottleneck_layers: DenseLayer count of the bottleneck block.
        growth_rate: channels added by every DenseLayer.
        out_chans_first_conv: output channels of the initial 3x3 convolution.
        feature_length: channels of the output descriptors.

    Raises:
        ValueError: if ``down_blocks`` is empty or ``len(up_blocks)`` differs
            from ``len(down_blocks)``. Such configurations cannot be wired
            to matching skip connections.
    """

    def __init__(self, in_channels=3, down_blocks=(5, 5, 5, 5, 5),
                 up_blocks=(5, 5, 5, 5, 5), bottleneck_layers=5,
                 growth_rate=16, out_chans_first_conv=48, feature_length=256):
        super(FCDenseNet, self).__init__()
        if len(down_blocks) == 0:
            raise ValueError("down_blocks must contain at least one block")
        if len(up_blocks) != len(down_blocks):
            raise ValueError(
                "up_blocks and down_blocks must have the same length, got {} and {}".format(
                    len(up_blocks), len(down_blocks)))
        self.down_blocks = down_blocks
        self.up_blocks = up_blocks
        # Channel count of every encoder skip connection, deepest first.
        skip_connection_channel_counts = []

        # First convolution (3x3, stride 1, padding 1: spatial size preserved).
        self.add_module('firstconv', nn.Conv2d(in_channels=in_channels,
                                               out_channels=out_chans_first_conv, kernel_size=3,
                                               stride=1, padding=1, bias=True))
        cur_channels_count = out_chans_first_conv

        #####################
        # Downsampling path #
        #####################
        self.denseBlocksDown = nn.ModuleList([])
        self.transDownBlocks = nn.ModuleList([])
        for n_layers in down_blocks:
            self.denseBlocksDown.append(
                DenseBlock(cur_channels_count, growth_rate, n_layers))
            # A non-upsample DenseBlock returns its input plus the new features.
            cur_channels_count += growth_rate * n_layers
            skip_connection_channel_counts.insert(0, cur_channels_count)
            self.transDownBlocks.append(TransitionDown(cur_channels_count))

        #####################
        #     Bottleneck    #
        #####################
        self.add_module('bottleneck', Bottleneck(cur_channels_count,
                                                 growth_rate, bottleneck_layers))
        # The bottleneck returns only its newly produced features.
        prev_block_channels = growth_rate * bottleneck_layers

        #######################
        #   Upsampling path   #
        #######################
        self.transUpBlocks = nn.ModuleList([])
        self.denseBlocksUp = nn.ModuleList([])
        last_index = len(up_blocks) - 1
        for i, n_layers in enumerate(up_blocks):
            self.transUpBlocks.append(TransitionUp(prev_block_channels, prev_block_channels))
            # TransitionUp concatenates the upsampled features with the skip.
            cur_channels_count = prev_block_channels + skip_connection_channel_counts[i]
            # Only the final block also passes its input through, to feed finalConv.
            self.denseBlocksUp.append(DenseBlock(
                cur_channels_count, growth_rate, n_layers,
                upsample=(i != last_index)))
            prev_block_channels = growth_rate * n_layers
        # Output channels of the final (non-upsample) block: input + new features.
        cur_channels_count += prev_block_channels

        self.finalConv = nn.Conv2d(in_channels=cur_channels_count,
                                   out_channels=feature_length, kernel_size=1, stride=1,
                                   padding=0, bias=True)

    def forward(self, x):
        out = self.firstconv(x)
        skip_connections = []
        for dense_block, trans_down in zip(self.denseBlocksDown, self.transDownBlocks):
            out = dense_block(out)
            skip_connections.append(out)
            out = trans_down(out)
        out = self.bottleneck(out)
        for trans_up, dense_block in zip(self.transUpBlocks, self.denseBlocksUp):
            # Skips are consumed deepest-first.
            out = trans_up(out, skip_connections.pop())
            out = dense_block(out)
        out = self.finalConv(out)
        # L2-normalise each descriptor over channels. normalize() clamps the
        # norm at 1e-12, so an all-zero descriptor yields zeros instead of NaN.
        return nn.functional.normalize(out, p=2, dim=1)
class FeatureResponseGenerator(nn.Module):
    """Build normalised response maps of sampled source descriptors over a target map.

    ``forward`` takes a 4-tuple
    ``(source_feature_map, target_feature_map, source_feature_1D_locations, boundaries)``:

    - source_feature_map: B x C x H x W.
    - target_feature_map: viewed per sample as 1 x C x H x W, i.e. it must
      match the source's C, H, W.
    - source_feature_1D_locations: B x Sampling_size x 1 indices into the
      flattened H*W grid (cast to long).
    - boundaries: unpacked but not used.

    For every sampled location, the source descriptor is dotted with every
    target descriptor, mapped by ``0.5 * d + 0.5``, and turned into a
    spatial distribution::

        softmax over H*W of  scale * (0.5 * d + 0.5 - threshold)

    The result is B x Sampling_size x H x W, with each map summing to 1. The
    ``0.5 * d + 0.5`` mapping lands in [0, 1] only if descriptors are unit
    length, which is assumed and not checked here.

    Args:
        scale: sharpness multiplier applied before the softmax.
        threshold: offset subtracted before the softmax. Because every map is
            normalised to sum to 1, this offset cancels out mathematically.
    """

    def __init__(self, scale=20.0, threshold=0.9):
        super(FeatureResponseGenerator, self).__init__()
        self.scale = scale
        self.threshold = threshold

    def forward(self, x):
        source_feature_map, target_feature_map, source_feature_1D_locations, boundaries = x
        batch_size, channel, height, width = source_feature_map.shape
        _, sampling_size, _ = source_feature_1D_locations.shape

        # B x C x Sampling_size: every location index repeated across channels.
        gather_index = source_feature_1D_locations.view(
            batch_size, 1, sampling_size).expand(-1, channel, -1).long()
        # B x C x Sampling_size: source descriptors at the sampled locations.
        sampled_feature_vectors = torch.gather(
            source_feature_map.view(batch_size, channel, height * width), 2, gather_index)
        # B x Sampling_size x C x 1 x 1: one 1x1 conv kernel per sampled location.
        kernels = sampled_feature_vectors.view(
            batch_size, channel, sampling_size, 1, 1).permute(0, 2, 1, 3, 4)

        # A 1x1 convolution with these kernels computes the dot product of each
        # sampled descriptor with every target descriptor.
        # B x Sampling_size x H x W
        responses = torch.cat([
            torch.nn.functional.conv2d(input=target_feature_map[b].view(1, channel, height, width),
                                       weight=kernels[b], padding=0)
            for b in range(batch_size)], dim=0)
        cosine_distance_map = 0.5 * responses + 0.5

        logits = self.scale * (cosine_distance_map - self.threshold)
        # Softmax equals exp(logits) / sum(exp(logits)) but subtracts the max
        # first. A large `scale` therefore cannot overflow to inf (inf/inf = NaN)
        # or underflow every entry to zero (0/0 = NaN).
        heat = torch.softmax(logits.reshape(batch_size, sampling_size, height * width), dim=2)
        return heat.view(batch_size, sampling_size, height, width)
class FeatureResponseGeneratorNoSoftThresholding(nn.Module):
    """Raw dot-product response maps of sampled source descriptors over a target map.

    ``forward`` takes a 4-tuple
    ``(source_feature_map, target_feature_map, source_feature_1D_locations, boundaries)``:

    - source_feature_map: B x C x H x W.
    - target_feature_map: viewed per sample as 1 x C x H x W using the
      source's C, H, W.
    - source_feature_1D_locations: B x Sampling_size x 1 indices into the
      flattened H*W grid (cast to long).
    - boundaries: unpacked but not used.

    Returns a B x Sampling_size x H x W tensor. Entry [b, s, i, j] is the dot
    product of the source descriptor at sampled location s with the target
    descriptor at (i, j). No offset, scaling or normalisation is applied.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        source_feature_map, target_feature_map, source_feature_1D_locations, boundaries = x
        n_batch, n_chan, rows, cols = source_feature_map.shape
        _, n_samples, _ = source_feature_1D_locations.shape

        # Broadcast each flat location index across the channel axis:
        # B x C x Sampling_size.
        gather_index = (source_feature_1D_locations
                        .view(n_batch, 1, n_samples)
                        .expand(-1, n_chan, -1)
                        .long())
        flat_source = source_feature_map.view(n_batch, n_chan, rows * cols)
        picked = torch.gather(flat_source, 2, gather_index)  # B x C x Sampling_size

        # Reorder to B x Sampling_size x C x 1 x 1 so that, per sample, the
        # sampled descriptors act as Sampling_size 1x1 convolution kernels.
        kernels = picked.view(n_batch, n_chan, n_samples, 1, 1).permute(0, 2, 1, 3, 4)

        per_sample = [
            torch.nn.functional.conv2d(input=target_feature_map[b].view(1, n_chan, rows, cols),
                                       weight=kernels[b], padding=0)
            for b in range(n_batch)
        ]
        # B x Sampling_size x H x W
        return torch.cat(per_sample, dim=0)