U-netを短く書く

More than 1 year has passed since last update.

以下のKaggleカーネルではU-Netの各層がベタ書きされていて冗長だと感じたので、関数を使って短く書き直しました。

https://www.kaggle.com/keegil/keras-u-net-starter-lb-0-277

準備部分


layers

from keras.layers.convolutional import Conv2D

def Conv2D16(s):
    """Apply a 16-filter, 3x3, same-padded ELU convolution (He-normal init) to ``s``."""
    layer = Conv2D(
        16,
        (3, 3),
        activation='elu',
        kernel_initializer='he_normal',
        padding='same',
    )
    return layer(s)

def _Conv2D(s, size, dsize=3):
    """Apply a same-padded ``dsize`` x ``dsize`` ELU convolution with ``size`` filters to ``s``.

    Weights use He-normal initialisation; the spatial size of ``s`` is kept.
    """
    kernel = (dsize, dsize)
    layer = Conv2D(size, kernel, activation='elu',
                   kernel_initializer='he_normal', padding='same')
    return layer(s)

def CDCP(s, size, dsize=3, droprate=0.1, withpooling=True):
    """One U-Net encoder level: Conv -> Dropout -> Conv, optionally followed by 2x2 max pooling.

    Args:
        s: input tensor.
        size: number of convolution filters.
        dsize: convolution kernel size.
        droprate: dropout rate between the two convolutions.
        withpooling: whether to append a 2x2 max-pooling layer.

    Returns:
        ``(pooled, features)`` when ``withpooling`` is true, where ``features``
        is the pre-pooling tensor (used as the skip connection); otherwise
        just ``features``.
    """
    first = _Conv2D(s, size, dsize)
    dropped = Dropout(droprate)(first)
    features = _Conv2D(dropped, size, dsize)
    if not withpooling:
        # Bottleneck level: no downsampling, single tensor returned.
        return features
    pooled = MaxPooling2D((2, 2))(features)
    return pooled, features

def Ulayer(s, t, size, dsize=3, droprate=0.2):
    """One U-Net decoder level.

    Upsamples ``s`` by 2 with a transposed convolution, concatenates the skip
    tensor ``t``, then applies Conv -> Dropout -> Conv with ``size`` filters.
    """
    upsampled = Conv2DTranspose(size, (2, 2), strides=(2, 2), padding='same')(s)
    merged = concatenate([upsampled, t])
    hidden = _Conv2D(merged, size, dsize)
    hidden = Dropout(droprate)(hidden)
    return _Conv2D(hidden, size, dsize)


ネットワークを作る部分

下の `genUnet`（画像サイズから層の深さを自動で決める版）は少しやりすぎ感がある。


gennet

# Input image size used by the networks below: (height, width, channels).
IMG_HEIGHT = 1280
IMG_WIDTH = 720
IMG_CHANNELS = 3

inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))

def genUnet1(inputs):
    """Build a fixed-depth U-Net: four encoder levels (16/32/64/128 filters), a 256-filter bottleneck, and a mirrored decoder.

    Pixel values are scaled by 1/255 first; the output is a 1-channel sigmoid map.
    """
    scaled = Lambda(lambda x: x / 255)(inputs)

    # Encoder: (filters, dropout rate) per level, shallowest first.
    skips = []
    x = scaled
    for filters, rate in ((16, 0.1), (32, 0.1), (64, 0.2), (128, 0.2)):
        x, skip = CDCP(x, filters, 3, rate)
        skips.append(skip)

    bottleneck = CDCP(x, 256, 3, 0.2, False)

    # Decoder: deepest skip connection first.
    x = bottleneck
    for filters, skip in zip((128, 64, 32, 16), reversed(skips)):
        x = Ulayer(x, skip, filters)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    return Model(inputs=[inputs], outputs=[outputs])

import math
def genUnet(inputs, dropoutrate=0.1, dsize=3, startwidth=4):
    """Build a U-Net whose depth is derived from the input image size.

    The first level has ``2**startwidth`` filters and every deeper level doubles
    that; the decoder mirrors the encoder.  The number of pooling levels is
    ``floor(log2(min(height, width))) - startwidth``.  It is further limited so
    that both spatial dimensions are exactly divisible by 2 at every pooling
    step.  Otherwise the upsampled tensor and its skip connection would differ
    in size and ``concatenate`` would fail.  For a 1280x720 input this gives the
    same 16-32-64-128 / 256 filter layout as ``genUnet1``.

    Args:
        inputs: Keras input tensor of shape (batch, height, width, channels)
            whose height and width are known.
        dropoutrate: dropout rate used in every encoder and decoder level.
        dsize: convolution kernel size.  The original code passed the channel
            count here instead, which only happened to be 3 for RGB input.
        startwidth: log2 of the number of filters in the first level.

    Returns:
        A Keras ``Model`` mapping ``inputs`` to a 1-channel sigmoid output.

    Raises:
        ValueError: if the height or width of ``inputs`` is unknown.
    """
    # Shape entries are TF1 ``Dimension`` objects (with ``.value``) or plain
    # ints/None in newer versions; normalise them.  Index 0 is the batch axis.
    dims = [getattr(d, 'value', d) for d in inputs.shape]
    height, width = dims[1], dims[2]
    if height is None or width is None:
        raise ValueError('genUnet needs an input with a fixed height and width')

    s = Lambda(lambda x: x / 255)(inputs)

    # floor(log2(min side)) computed exactly on ints, minus the first-level width.
    depth = (min(height, width).bit_length() - 1) - startwidth
    # Never pool a dimension that is odd, so that skip tensors line up.
    depth = max(0, min(depth, _count_halvings(height), _count_halvings(width)))

    # Encoder: keep each level's pre-pooling tensor as a skip connection.
    skips = []
    x = s
    for level in range(depth):
        x, skip = CDCP(x, 2 ** (startwidth + level), dsize, dropoutrate)
        skips.append(skip)

    # Bottleneck: withpooling=False, so CDCP returns a single tensor.
    x = CDCP(x, 2 ** (startwidth + depth), dsize, dropoutrate, False)

    # Decoder: walk back up, concatenating the skip from the same level.
    for level in reversed(range(depth)):
        x = Ulayer(x, skips[level], 2 ** (startwidth + level), dsize, dropoutrate)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    return Model(inputs=[inputs], outputs=[outputs])


def _count_halvings(n):
    """Return how many times the positive int ``n`` can be halved exactly."""
    count = 0
    while n > 0 and n % 2 == 0:
        n //= 2
        count += 1
    return count


コンパイル


compile

# Build the automatically sized U-Net.  genUnet's dropout rate is passed
# explicitly; the original call omitted it, and that raised a TypeError.
model = genUnet(inputs, 0.1)

# NOTE(review): mean_iou is not defined in this snippet; it presumably comes
# from the Kaggle kernel linked at the top of the post — confirm before running.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[mean_iou])
model.summary()


fit


# Stop when the monitored metric (Keras default: val_loss) has not improved for
# 5 epochs, and write the model to disk only when it improves.
earlystopper = EarlyStopping(patience=5, verbose=1)
checkpointer = ModelCheckpoint('model-unet.h5', verbose=1, save_best_only=True)

# NOTE(review): X_train / Y_train are not defined in this snippet; presumably
# prepared as in the Kaggle kernel linked at the top of the post.
results = model.fit(
    X_train,
    Y_train,
    validation_split=0.1,
    batch_size=16,
    epochs=50,
    callbacks=[earlystopper, checkpointer],
)

結果

https://gist.github.com/xiangze/d42d7c8e5ae967fba388d5c639365f3a

結果を見る限り、ちゃんと学習できているようだ。