Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated U-Net implementation to Keras 2 #39

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 135 additions & 110 deletions unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
from keras.models import *
from keras.layers import Input, merge, Conv2D, MaxPooling2D, UpSampling2D, Dropout, Cropping2D
from keras.layers import Input, merge, Conv2D, MaxPooling2D, UpSampling2D, Dropout, Cropping2D, Concatenate
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras import backend as keras
Expand All @@ -24,129 +24,154 @@ def load_data(self):

def get_unet(self):

inputs = Input((self.img_rows, self.img_cols,1))

'''
unet with crop(because padding = valid)

conv1 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(inputs)
print "conv1 shape:",conv1.shape
conv1 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv1)
print "conv1 shape:",conv1.shape
crop1 = Cropping2D(cropping=((90,90),(90,90)))(conv1)
print "crop1 shape:",crop1.shape
inputs = Input((self.img_rows, self.img_cols,1))

conv1 = Conv2D(64,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(inputs)
conv1 = Conv2D(64,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
print "pool1 shape:",pool1.shape

conv2 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(pool1)
print "conv2 shape:",conv2.shape
conv2 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv2)
print "conv2 shape:",conv2.shape
crop2 = Cropping2D(cropping=((41,41),(41,41)))(conv2)
print "crop2 shape:",crop2.shape
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
print "pool2 shape:",pool2.shape

conv3 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(pool2)
print "conv3 shape:",conv3.shape
conv3 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv3)
print "conv3 shape:",conv3.shape
crop3 = Cropping2D(cropping=((16,17),(16,17)))(conv3)
print "crop3 shape:",crop3.shape
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
print "pool3 shape:",pool3.shape

conv4 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(pool3)
conv4 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv4)
drop4 = Dropout(0.5)(conv4)
crop4 = Cropping2D(cropping=((4,4),(4,4)))(drop4)
pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(pool4)
conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv5)
drop5 = Dropout(0.5)(conv5)

up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
merge6 = merge([crop4,up6], mode = 'concat', concat_axis = 3)
conv6 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(merge6)
conv6 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv6)

up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
merge7 = merge([crop3,up7], mode = 'concat', concat_axis = 3)
conv7 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(merge7)
conv7 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv7)

up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
merge8 = merge([crop2,up8], mode = 'concat', concat_axis = 3)
conv8 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(merge8)
conv8 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv8)

up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
merge9 = merge([crop1,up9], mode = 'concat', concat_axis = 3)
conv9 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(merge9)
conv9 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv9)
conv9 = Conv2D(2, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv9)
'''

conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
print "conv1 shape:",conv1.shape
conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
print "conv1 shape:",conv1.shape
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
print "pool1 shape:",pool1.shape

conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
print "conv2 shape:",conv2.shape
conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
print "conv2 shape:",conv2.shape
conv2 = Conv2D(128,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(pool1)
conv2 = Conv2D(128,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
print "pool2 shape:",pool2.shape

conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
print "conv3 shape:",conv3.shape
conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
print "conv3 shape:",conv3.shape
conv3 = Conv2D(256,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(pool2)
conv3 = Conv2D(256,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv3)
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
print "pool3 shape:",pool3.shape

conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
conv4 = Conv2D(512,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(pool3)
conv4 = Conv2D(512,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv4)
drop4 = Dropout(0.5)(conv4)
pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
conv5 = Conv2D(1024,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(pool4)
conv5 = Conv2D(1024,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv5)
drop5 = Dropout(0.5)(conv5)

up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
merge6 = merge([drop4,up6], mode = 'concat', concat_axis = 3)
conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)

up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
merge7 = merge([conv3,up7], mode = 'concat', concat_axis = 3)
conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)

up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
merge8 = merge([conv2,up8], mode = 'concat', concat_axis = 3)
conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)

up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
merge9 = merge([conv1,up9], mode = 'concat', concat_axis = 3)
conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)

model = Model(input = inputs, output = conv10)

model.compile(optimizer = Adam(lr = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy'])
up6 = UpSampling2D(size=(2, 2))(drop5)
up6 = Conv2D(512,
2,
activation='relu',
padding='same',
kernel_initializer='he_normal')(up6)
merge6 = Concatenate(axis=3)([drop4, up6])
conv6 = Conv2D(512,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(merge6)
conv6 = Conv2D(512,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv6)

up7 = UpSampling2D(size=(2, 2))(conv6)
up7 = Conv2D(256,
2,
activation='relu',
padding='same',
kernel_initializer='he_normal')(up7)
merge7 = Concatenate(axis=3)([conv3, up7])
conv7 = Conv2D(256,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(merge7)
conv7 = Conv2D(256,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv7)

up8 = UpSampling2D(size=(2, 2))(conv7)
up8 = Conv2D(128,
2,
activation='relu',
padding='same',
kernel_initializer='he_normal')(up8)
merge8 = Concatenate(axis=3)([conv2, up8])
conv8 = Conv2D(128,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(merge8)
conv8 = Conv2D(128,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv8)

up9 = UpSampling2D(size=(2, 2))(conv8)
up9 = Conv2D(64,
2,
activation='relu',
padding='same',
kernel_initializer='he_normal')(up9)
merge9 = Concatenate(axis=3)([conv1, up9])
conv9 = Conv2D(64,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(merge9)
conv9 = Conv2D(64,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv9)
conv9 = Conv2D(2,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv9)
conv10 = Conv2D(1, 1, activation='sigmoid')(conv9)

model = Model(inputs=inputs, outputs=conv10)
model.compile(optimizer=Adam(lr=1e-4),
loss='binary_crossentropy',
metrics=['accuracy'])

return model


def train(self):

print("loading data")
Expand Down