From db72d4822dbe52a6fc6bf0c379b253b08ba07018 Mon Sep 17 00:00:00 2001 From: Michael Perel Date: Fri, 18 May 2018 10:22:23 -0400 Subject: [PATCH] Updated U-Net implementation to Keras 2 --- unet.py | 245 +++++++++++++++++++++++++++++++------------------------- 1 file changed, 135 insertions(+), 110 deletions(-) diff --git a/unet.py b/unet.py index 9a2233e49..c680c8ba1 100644 --- a/unet.py +++ b/unet.py @@ -2,7 +2,7 @@ #os.environ["CUDA_VISIBLE_DEVICES"] = "0" import numpy as np from keras.models import * -from keras.layers import Input, merge, Conv2D, MaxPooling2D, UpSampling2D, Dropout, Cropping2D +from keras.layers import Input, merge, Conv2D, MaxPooling2D, UpSampling2D, Dropout, Cropping2D, Concatenate from keras.optimizers import * from keras.callbacks import ModelCheckpoint, LearningRateScheduler from keras import backend as keras @@ -24,129 +24,154 @@ def load_data(self): def get_unet(self): - inputs = Input((self.img_rows, self.img_cols,1)) - - ''' - unet with crop(because padding = valid) - - conv1 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(inputs) - print "conv1 shape:",conv1.shape - conv1 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv1) - print "conv1 shape:",conv1.shape - crop1 = Cropping2D(cropping=((90,90),(90,90)))(conv1) - print "crop1 shape:",crop1.shape + inputs = Input((self.img_rows, self.img_cols,1)) + + conv1 = Conv2D(64, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(inputs) + conv1 = Conv2D(64, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv1) pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) - print "pool1 shape:",pool1.shape - - conv2 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(pool1) - print "conv2 shape:",conv2.shape - conv2 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv2) - print "conv2 shape:",conv2.shape - crop2 = Cropping2D(cropping=((41,41),(41,41)))(conv2) - print "crop2 shape:",crop2.shape - pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) - print "pool2 shape:",pool2.shape - - conv3 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(pool2) - print "conv3 shape:",conv3.shape - conv3 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv3) - print "conv3 shape:",conv3.shape - crop3 = Cropping2D(cropping=((16,17),(16,17)))(conv3) - print "crop3 shape:",crop3.shape - pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) - print "pool3 shape:",pool3.shape - - conv4 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(pool3) - conv4 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv4) - drop4 = Dropout(0.5)(conv4) - crop4 = Cropping2D(cropping=((4,4),(4,4)))(drop4) - pool4 = MaxPooling2D(pool_size=(2, 2))(drop4) - - conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(pool4) - conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv5) - drop5 = Dropout(0.5)(conv5) - up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5)) - merge6 = merge([crop4,up6], mode = 'concat', concat_axis = 3) - conv6 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(merge6) - conv6 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv6) - - up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6)) - merge7 = merge([crop3,up7], mode = 'concat', concat_axis = 3) - conv7 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(merge7) - conv7 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv7) - - up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7)) - merge8 = merge([crop2,up8], mode = 'concat', concat_axis = 3) - conv8 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(merge8) - conv8 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv8) - - up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) - merge9 = merge([crop1,up9], mode = 'concat', concat_axis = 3) - conv9 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(merge9) - conv9 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv9) - conv9 = Conv2D(2, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv9) - ''' - - conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs) - print "conv1 shape:",conv1.shape - conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1) - print "conv1 shape:",conv1.shape - pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) - print "pool1 shape:",pool1.shape - - conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1) - print "conv2 shape:",conv2.shape - conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2) - print "conv2 shape:",conv2.shape + conv2 = Conv2D(128, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(pool1) + conv2 = Conv2D(128, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv2) pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) - print "pool2 shape:",pool2.shape - conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2) - print "conv3 shape:",conv3.shape - conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3) - print "conv3 shape:",conv3.shape + conv3 = Conv2D(256, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(pool2) + conv3 = Conv2D(256, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv3) pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) - print "pool3 shape:",pool3.shape - conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3) - conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4) + conv4 = Conv2D(512, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(pool3) + conv4 = Conv2D(512, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv4) drop4 = Dropout(0.5)(conv4) pool4 = MaxPooling2D(pool_size=(2, 2))(drop4) - conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4) - conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5) + conv5 = Conv2D(1024, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(pool4) + conv5 = Conv2D(1024, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv5) drop5 = Dropout(0.5)(conv5) - up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5)) - merge6 = merge([drop4,up6], mode = 'concat', concat_axis = 3) - conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6) - conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6) - - up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6)) - merge7 = merge([conv3,up7], mode = 'concat', concat_axis = 3) - conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7) - conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7) - - up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7)) - merge8 = merge([conv2,up8], mode = 'concat', concat_axis = 3) - conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8) - conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8) - - up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) - merge9 = merge([conv1,up9], mode = 'concat', concat_axis = 3) - conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9) - conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9) - conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9) - conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9) - - model = Model(input = inputs, output = conv10) - - model.compile(optimizer = Adam(lr = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy']) + up6 = UpSampling2D(size=(2, 2))(drop5) + up6 = Conv2D(512, + 2, + activation='relu', + padding='same', + kernel_initializer='he_normal')(up6) + merge6 = Concatenate(axis=3)([drop4, up6]) + conv6 = Conv2D(512, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(merge6) + conv6 = Conv2D(512, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv6) + + up7 = UpSampling2D(size=(2, 2))(conv6) + up7 = Conv2D(256, + 2, + activation='relu', + padding='same', + kernel_initializer='he_normal')(up7) + merge7 = Concatenate(axis=3)([conv3, up7]) + conv7 = Conv2D(256, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(merge7) + conv7 = Conv2D(256, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv7) + + up8 = UpSampling2D(size=(2, 2))(conv7) + up8 = Conv2D(128, + 2, + activation='relu', + padding='same', + kernel_initializer='he_normal')(up8) + merge8 = Concatenate(axis=3)([conv2, up8]) + conv8 = Conv2D(128, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(merge8) + conv8 = Conv2D(128, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv8) + + up9 = UpSampling2D(size=(2, 2))(conv8) + up9 = Conv2D(64, + 2, + activation='relu', + padding='same', + kernel_initializer='he_normal')(up9) + merge9 = Concatenate(axis=3)([conv1, up9]) + conv9 = Conv2D(64, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(merge9) + conv9 = Conv2D(64, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv9) + conv9 = Conv2D(2, + 3, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv9) + conv10 = Conv2D(1, 1, activation='sigmoid')(conv9) + + model = Model(inputs=inputs, outputs=conv10) + model.compile(optimizer=Adam(lr=1e-4), + loss='binary_crossentropy', + metrics=['accuracy']) return model - def train(self): print("loading data")