Keras

U-netを短く書く

下記の Kaggle カーネルでは U-Net の各層がベタ書きされており、読みにくく保守もしづらいと感じたので、関数にまとめて短く書き直しました。
https://www.kaggle.com/keegil/keras-u-net-starter-lb-0-277

準備部分

layers
from keras.layers.convolutional import Conv2D

def Conv2D16(s):
    """Apply a 16-filter 3x3 convolution to ``s`` and return the result.

    The layer uses ELU activation, He-normal kernel initialisation and
    'same' padding.
    """
    layer = Conv2D(
        filters=16,
        kernel_size=(3, 3),
        activation='elu',
        kernel_initializer='he_normal',
        padding='same',
    )
    return layer(s)

def _Conv2D(s,size,dsize=3):
    """Apply a square convolution with ``size`` filters to ``s``.

    Args:
        s: Tensor the convolution is applied to.
        size: Number of output filters.
        dsize: Edge length of the square kernel (``dsize`` x ``dsize``).

    Returns:
        The output of the convolution layer (ELU activation, He-normal
        kernel initialisation, 'same' padding).
    """
    conv = Conv2D(
        filters=size,
        kernel_size=(dsize, dsize),
        activation='elu',
        kernel_initializer='he_normal',
        padding='same',
    )
    return conv(s)

def CDCP(s,size,dsize=3,droprate=0.1,withpooling=True):
    """Conv -> Dropout -> Conv, optionally followed by 2x2 max pooling.

    Args:
        s: Tensor fed into the first convolution.
        size: Number of filters for both convolutions.
        dsize: Kernel edge length for both convolutions.
        droprate: Rate handed to the Dropout layer between the convolutions.
        withpooling: When true, also max-pool the convolved features.

    Returns:
        ``(pooled, features)`` when ``withpooling`` is true, otherwise just
        ``features``. Note that the return shape therefore depends on the flag.
    """
    features = _Conv2D(s, size, dsize)
    features = Dropout(droprate)(features)
    features = _Conv2D(features, size, dsize)

    if not withpooling:
        return features

    pooled = MaxPooling2D((2, 2))(features)
    return pooled, features

def Ulayer(s,t,size,dsize=3,droprate=0.2):
    """Upsample ``s``, merge it with ``t`` and refine with Conv -> Dropout -> Conv.

    Args:
        s: Tensor to upsample with a stride-2 transposed convolution.
        t: Tensor concatenated with the upsampled ``s``.
        size: Number of filters for the transposed and regular convolutions.
        dsize: Kernel edge length of the two regular convolutions.
        droprate: Rate handed to the Dropout layer between the convolutions.

    Returns:
        The output tensor of the second convolution.
    """
    upsampled = Conv2DTranspose(size, (2, 2), strides=(2, 2), padding='same')(s)
    merged = concatenate([upsampled, t])
    refined = _Conv2D(merged, size, dsize)
    refined = Dropout(droprate)(refined)
    refined = _Conv2D(refined, size, dsize)
    return refined

ネットワークを作る部分
2つ目の関数(層の数を入力サイズから自動的に決める genUnet)は、やりすぎ感がある。

gennet
# Geometry of the input images.
# NOTE(review): height 1280 x width 720 reads like a portrait frame - confirm
# the orientation matches the training data.
IMG_HEIGHT = 1280
IMG_WIDTH = 720
IMG_CHANNELS = 3

# Symbolic input tensor that is handed to the model builder.
inputs = Input(shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))

def genUnet1(inputs):
    """Build a fixed-depth U-Net (four pooling levels) on top of ``inputs``.

    The input is divided by 255, passed through four CDCP encoder levels
    (16, 32, 64, 128 filters), a 256-filter CDCP block without pooling, and
    four Ulayer decoder levels (128, 64, 32, 16 filters) that reuse the
    encoder features as skip connections. A 1x1 sigmoid convolution produces
    the single-channel output.

    Args:
        inputs: Keras input tensor.

    Returns:
        A ``Model`` mapping ``inputs`` to the sigmoid output.
    """
    net = Lambda(lambda x: x / 255) (inputs)

    # Encoder: remember the pre-pooling features of each level for the skips.
    skips = []
    for filters, rate in ((16, 0.1), (32, 0.1), (64, 0.2), (128, 0.2)):
        net, features = CDCP(net, filters, 3, rate)
        skips.append(features)

    # Bottleneck: same conv/dropout/conv block, but without pooling.
    net = CDCP(net, 256, 3, 0.2, False)

    # Decoder: consume the skips from the deepest level back to the first.
    for filters in (128, 64, 32, 16):
        net = Ulayer(net, skips.pop(), filters)

    mask = Conv2D(1, (1, 1), activation='sigmoid')(net)
    return Model(inputs=[inputs],outputs=[mask])

import math
def genUnet(inputs, dropoutrate=0.1):
    """Build a U-Net whose depth is derived from the spatial size of ``inputs``.

    The first encoder level uses 16 filters and every further level doubles
    that. The number of pooling levels starts at
    ``floor(log2(min(height, width))) - 4`` and is reduced until both spatial
    dimensions are divisible by ``2 ** depth``; otherwise the upsampled
    decoder tensors would not match the encoder features they are
    concatenated with.

    Args:
        inputs: Keras input tensor. Assumed to be 4-D and channels-last,
            i.e. (batch, height, width, channels), with static height and
            width.
        dropoutrate: Dropout rate used in every encoder, bottleneck and
            decoder block.

    Returns:
        A ``Model`` mapping ``inputs`` to a single-channel sigmoid output.

    Raises:
        ValueError: If ``inputs`` is not 4-D or its spatial dimensions are
            not statically known.
    """
    # TF1 shapes hold Dimension objects exposing `.value`; newer versions
    # yield plain ints (or None), so accept both.
    dims = [getattr(d, "value", d) for d in inputs.shape]
    if len(dims) != 4 or dims[1] is None or dims[2] is None:
        raise ValueError(
            "genUnet expects a 4-D input with static spatial dimensions, "
            "got shape %r" % (dims,)
        )
    height, width = dims[1], dims[2]

    startwidth = 4  # the first encoder level has 2**4 = 16 filters
    kernel = 3      # convolution kernel size (independent of the channel count)

    nmax = int(math.log2(min(height, width)))
    depth = max(nmax - startwidth, 0)
    # Pooling halves each spatial dimension (rounding down) and the decoder
    # doubles it back, so the skips only line up when both dimensions are
    # divisible by 2**depth.
    while depth > 0 and (height % (2 ** depth) or width % (2 ** depth)):
        depth -= 1

    net = Lambda(lambda x: x / 255) (inputs)

    # Encoder: keep the pre-pooling features of every level as skips.
    skips = []
    for level in range(depth):
        net, features = CDCP(net, 2 ** (startwidth + level), kernel, dropoutrate)
        skips.append(features)

    # Bottleneck: no pooling, so CDCP returns a single tensor here.
    net = CDCP(net, 2 ** (startwidth + depth), kernel, dropoutrate, False)

    # Decoder: walk the levels back up, pairing each with its own skip.
    for level in reversed(range(depth)):
        net = Ulayer(net, skips[level], 2 ** (startwidth + level), kernel, dropoutrate)

    mask = Conv2D(1, (1, 1), activation='sigmoid')(net)
    return Model(inputs=[inputs], outputs=[mask])

コンパイル

compile
# genUnet takes the dropout rate as its second argument; calling it with
# only `inputs` raises TypeError. 0.1 is the rate genUnet1 uses for its first
# encoder levels.
model = genUnet(inputs, 0.1)
# NOTE(review): mean_iou is not defined in this snippet - presumably the
# metric function from the referenced Kaggle kernel; confirm it is in scope.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[mean_iou])
model.summary()
fit
# Callbacks: stop training after 5 epochs without improvement of the
# monitored quantity, and keep only the best model on disk.
earlystopper = EarlyStopping(verbose=1, patience=5)
checkpointer = ModelCheckpoint('model-unet.h5', save_best_only=True, verbose=1)

# Train with 10% of the training data held out for validation.
# NOTE(review): X_train / Y_train are not defined in this snippet - they are
# presumably prepared as in the referenced Kaggle kernel.
results = model.fit(
    X_train,
    Y_train,
    batch_size=16,
    epochs=50,
    validation_split=0.1,
    callbacks=[earlystopper, checkpointer],
)

結果
https://gist.github.com/xiangze/d42d7c8e5ae967fba388d5c639365f3a

ちゃんとできてるっぽい。