-
Notifications
You must be signed in to change notification settings - Fork 41
/
m_resunet.py
152 lines (116 loc) · 4.27 KB
/
m_resunet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
ResUNet++ architecture in Keras TensorFlow
"""
import os
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
def squeeze_excite_block(inputs, ratio=8):
    """Scale the channels of ``inputs`` by a learned per-channel gate.

    The input is globally average-pooled, passed through a two-layer Dense
    bottleneck (ReLU, then sigmoid) and the resulting gate is multiplied
    back onto ``inputs``.

    Args:
        inputs: Keras tensor. The channel count is read from its last axis
            and it is fed to ``GlobalAveragePooling2D``, so it is treated as
            a channels-last feature map.
        ratio: Divisor applied to the channel count to size the bottleneck
            Dense layer (``channels // ratio`` units).

    Returns:
        The tensor produced by ``Multiply`` of ``inputs`` and the gate.
    """
    n_channels = inputs.shape[-1]

    # Squeeze: one pooled value per channel, reshaped to (1, 1, C) so the
    # gate can be multiplied against the feature map.
    gate = GlobalAveragePooling2D()(inputs)
    gate = Reshape((1, 1, n_channels))(gate)

    # Excite: bottleneck Dense followed by a sigmoid Dense back to C units.
    gate = Dense(
        n_channels // ratio,
        activation='relu',
        kernel_initializer='he_normal',
        use_bias=False,
    )(gate)
    gate = Dense(
        n_channels,
        activation='sigmoid',
        kernel_initializer='he_normal',
        use_bias=False,
    )(gate)

    return Multiply()([inputs, gate])
def stem_block(x, n_filter, strides):
    """First residual unit of the network, applied directly to the input.

    Main path: Conv(3x3, ``strides``) -> BatchNorm -> ReLU -> Conv(3x3).
    Shortcut: Conv(1x1, ``strides``) -> BatchNorm on the untouched input.
    The two paths are summed and passed through ``squeeze_excite_block``.

    Args:
        x: Input Keras tensor.
        n_filter: Number of filters used by every convolution in the block.
        strides: Stride of the first 3x3 convolution and of the 1x1
            shortcut convolution.

    Returns:
        Output tensor of the squeeze-and-excite block.
    """
    # Main path (built first so layer creation order matches the shortcut
    # being created afterwards).
    main = Conv2D(n_filter, (3, 3), padding="same", strides=strides)(x)
    main = BatchNormalization()(main)
    main = Activation("relu")(main)
    main = Conv2D(n_filter, (3, 3), padding="same")(main)

    # Projection shortcut from the original input.
    skip = Conv2D(n_filter, (1, 1), padding="same", strides=strides)(x)
    skip = BatchNormalization()(skip)

    merged = Add()([main, skip])
    return squeeze_excite_block(merged)
def resnet_block(x, n_filter, strides=1):
    """Pre-activation residual unit followed by squeeze-and-excite.

    Main path: two BatchNorm -> ReLU -> Conv(3x3) stages; only the first
    stage uses ``strides``, the second always uses stride 1.
    Shortcut: Conv(1x1, ``strides``) -> BatchNorm on the block input.

    Args:
        x: Input Keras tensor.
        n_filter: Number of filters used by every convolution in the block.
        strides: Stride of the first 3x3 convolution and of the shortcut.

    Returns:
        Output tensor of the squeeze-and-excite block applied to the sum of
        the main path and the shortcut.
    """
    main = x
    # Stage 1 carries the requested stride, stage 2 is always stride 1.
    for stage_strides in (strides, 1):
        main = BatchNormalization()(main)
        main = Activation("relu")(main)
        main = Conv2D(
            n_filter, (3, 3), padding="same", strides=stage_strides
        )(main)

    # Projection shortcut so the residual sum has matching shape.
    skip = Conv2D(n_filter, (1, 1), padding="same", strides=strides)(x)
    skip = BatchNormalization()(skip)

    merged = Add()([main, skip])
    return squeeze_excite_block(merged)
def aspp_block(x, num_filters, rate_scale=1):
    """Sum four parallel 3x3 convolution branches and project with a 1x1 conv.

    Three branches use dilation rates ``6``, ``12`` and ``18`` multiplied by
    ``rate_scale``; the fourth uses the default (undilated) 3x3 convolution.
    Each branch is batch-normalised before the branches are added.

    Args:
        x: Input Keras tensor, fed unchanged to every branch.
        num_filters: Number of filters for every convolution in the block.
        rate_scale: Multiplier applied to the three base dilation rates.

    Returns:
        Output tensor of the final 1x1 convolution.
    """
    branches = []

    # Dilated branches, created in increasing-rate order.
    for base_rate in (6, 12, 18):
        rate = base_rate * rate_scale
        branch = Conv2D(
            num_filters, (3, 3), dilation_rate=(rate, rate), padding="same"
        )(x)
        branches.append(BatchNormalization()(branch))

    # Plain 3x3 branch; independent of rate_scale.
    plain = Conv2D(num_filters, (3, 3), padding="same")(x)
    branches.append(BatchNormalization()(plain))

    fused = Add()(branches)
    return Conv2D(num_filters, (1, 1), padding="same")(fused)
def attetion_block(g, x):
    """Gate ``x`` with a map computed from ``g`` and ``x``.

    Both inputs go through BatchNorm -> ReLU -> Conv(3x3); the ``g`` branch
    is additionally max-pooled 2x2 with stride 2 before the two branches are
    added. The sum goes through another BatchNorm -> ReLU -> Conv(3x3) and
    the result is multiplied with the original ``x``.

    Args:
        g: Output of the parallel encoder block. Because it is pooled by 2
            before the addition, it is expected to have twice the spatial
            size of ``x``.
        x: Output of the previous decoder block. Its last-axis size sets the
            filter count of every convolution here.

    Returns:
        Tensor produced by ``Multiply`` of the gating map and ``x``.
    """
    n_channels = x.shape[-1]

    def bn_relu_conv(tensor):
        # Shared BatchNorm -> ReLU -> 3x3 Conv stage used three times below.
        tensor = BatchNormalization()(tensor)
        tensor = Activation("relu")(tensor)
        return Conv2D(n_channels, (3, 3), padding="same")(tensor)

    # Encoder branch, brought down to the decoder's resolution.
    encoder_feat = bn_relu_conv(g)
    encoder_feat = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(encoder_feat)

    # Decoder branch.
    decoder_feat = bn_relu_conv(x)

    combined = Add()([encoder_feat, decoder_feat])
    gate = bn_relu_conv(combined)

    return Multiply()([gate, x])
class ResUnetPlusPlus:
    """Builder for the ResUNet++ Keras model defined in this file.

    Attributes are limited to ``input_size``, the height and width of the
    square, 3-channel input passed to ``Input`` in ``build_model``.
    """

    def __init__(self, input_size=256):
        """Store the spatial size of the model input.

        Args:
            input_size: Height and width of the input. It is validated when
                ``build_model`` is called, not here.
        """
        self.input_size = input_size

    def build_model(self):
        """Assemble and return the Keras ``Model``.

        Layout: stem block, three stride-2 residual encoder blocks, an ASPP
        bridge, three decoder stages (attention -> 2x upsample -> concatenate
        with the encoder skip -> residual block), a final ASPP block and a
        1x1 convolution with sigmoid activation producing one channel.

        Returns:
            ``Model`` mapping the ``(input_size, input_size, 3)`` input to
            the sigmoid output.

        Raises:
            ValueError: If ``input_size`` is not ``None`` and is not a
                positive multiple of 8.
        """
        size = self.input_size

        # The encoder halves the resolution three times with "same"-padded
        # stride-2 convolutions (rounds up) while each attention block
        # max-pools the skip tensor by 2 (rounds down). The two only agree
        # when every level is even, i.e. when the input is a multiple of 8.
        # Fail here with a clear message instead of a shape mismatch raised
        # from inside an Add layer. None is passed through to Input as-is.
        if size is not None and (size <= 0 or size % 8 != 0):
            raise ValueError(
                "input_size must be a positive multiple of 8, got "
                "{!r}".format(size)
            )

        n_filters = [16, 32, 64, 128, 256]
        inputs = Input((size, size, 3))

        c0 = inputs
        c1 = stem_block(c0, n_filters[0], strides=1)

        ## Encoder
        c2 = resnet_block(c1, n_filters[1], strides=2)
        c3 = resnet_block(c2, n_filters[2], strides=2)
        c4 = resnet_block(c3, n_filters[3], strides=2)

        ## Bridge
        b1 = aspp_block(c4, n_filters[4])

        ## Decoder
        d1 = attetion_block(c3, b1)
        d1 = UpSampling2D((2, 2))(d1)
        d1 = Concatenate()([d1, c3])
        d1 = resnet_block(d1, n_filters[3])

        d2 = attetion_block(c2, d1)
        d2 = UpSampling2D((2, 2))(d2)
        d2 = Concatenate()([d2, c2])
        d2 = resnet_block(d2, n_filters[2])

        d3 = attetion_block(c1, d2)
        d3 = UpSampling2D((2, 2))(d3)
        d3 = Concatenate()([d3, c1])
        d3 = resnet_block(d3, n_filters[1])

        ## output
        outputs = aspp_block(d3, n_filters[0])
        outputs = Conv2D(1, (1, 1), padding="same")(outputs)
        outputs = Activation("sigmoid")(outputs)

        ## Model
        model = Model(inputs, outputs)
        return model