Writing U-net concisely

Posted at 2018-03-25

I wrote this because the kernel below writes out every layer by hand, which I thought was pretty bad.
https://www.kaggle.com/keegil/keras-u-net-starter-lb-0-277

Preparation

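layers
# Helper blocks shared by both U-Net builders below.
from keras.layers import (Conv2D, Conv2DTranspose, Dropout,
                          MaxPooling2D, concatenate)

def Conv2D16(s):
    # 3x3 convolution with 16 filters (a fixed-size special case of _Conv2D)
    return Conv2D(16, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same')(s)

def _Conv2D(s, size, dsize=3):
    # dsize x dsize convolution with `size` filters; 'same' padding keeps H and W
    return Conv2D(size, (dsize, dsize), activation='elu', kernel_initializer='he_normal', padding='same')(s)

def CDCP(s, size, dsize=3, droprate=0.1, withpooling=True):
    # One encoder stage: Conv -> Dropout -> Conv (-> MaxPool).
    # With pooling, returns (pooled, pre-pool); the pre-pool tensor is the skip connection.
    # Without pooling (bottleneck), returns only the conv output.
    c = _Conv2D(s, size, dsize)
    c = Dropout(droprate)(c)
    c = _Conv2D(c, size, dsize)
    if withpooling:
        p = MaxPooling2D((2, 2))(c)
        return p, c
    return c

def Ulayer(s, t, size, dsize=3, droprate=0.2):
    # One decoder stage: 2x upsampling with a transposed convolution,
    # concatenation with the skip tensor t, then Conv -> Dropout -> Conv.
    u = Conv2DTranspose(size, (2, 2), strides=(2, 2), padding='same')(s)
    u = concatenate([u, t])
    c = _Conv2D(u, size, dsize)
    c = Dropout(droprate)(c)
    return _Conv2D(c, size, dsize)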
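CDCP halves the spatial size with MaxPooling2D((2, 2)), whose default 'valid' padding floors odd sizes, and Ulayer doubles it again with a stride-2 Conv2DTranspose before concatenating with the skip tensor t. The concatenation only works when both sizes agree, so the input height and width must survive every halving exactly. A stdlib sketch of that size arithmetic, assuming 2x2 pooling and stride-2 upsampling as in the blocks above; `skip_shapes_match` is a hypothetical helper, not part of Keras:

```python
def skip_shapes_match(height, width, depth):
    """Return True if every Ulayer concatenation would see equal H and W."""
    skips = []
    h, w = height, width
    for _ in range(depth):
        skips.append((h, w))          # pre-pool tensor c, kept as the skip
        h, w = h // 2, w // 2         # MaxPooling2D((2, 2)): odd sizes are floored
    for skip in reversed(skips):
        h, w = h * 2, w * 2           # Conv2DTranspose(strides=(2, 2), padding='same')
        if (h, w) != skip:
            return False              # concatenate([u, t]) would fail on this stage
    return True

print(skip_shapes_match(1280, 720, 4))  # True
print(skip_shapes_match(1280, 720, 5))  # False: 45 is pooled to 22, upsampled to 44
```

With a 1280x720 input, four pooling stages (as in genUnet1) line up, while a fifth does not.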

Building the network
The second (generic) version below feels like overkill.

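gennet
import math
from keras.layers import Input, Lambda, Conv2D
from keras.models import Model

IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = 1280, 720, 3

inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))

def genUnet1(inputs):
    # Fixed-depth U-Net: 4 pooling stages (16..128 filters) and a 256-filter bottleneck.
    s = Lambda(lambda x: x / 255)(inputs)  # scale pixel values to [0, 1]
    a0, c0 = CDCP(s, 16, 3, 0.1)
    a1, c1 = CDCP(a0, 32, 3, 0.1)
    a2, c2 = CDCP(a1, 64, 3, 0.2)
    a3, c3 = CDCP(a2, 128, 3, 0.2)
    a4 = CDCP(a3, 256, 3, 0.2, False)  # bottleneck, no pooling

    a5 = Ulayer(a4, c3, 128)
    a6 = Ulayer(a5, c2, 64)
    a7 = Ulayer(a6, c1, 32)
    a8 = Ulayer(a7, c0, 16)
    a9 = Conv2D(1, (1, 1), activation='sigmoid')(a8)  # per-pixel mask probability
    return Model(inputs=[inputs], outputs=[a9])

def genUnet(inputs, dropoutrate=0.1, startwidth=4):
    # Generic U-Net whose depth is derived from the input size.
    # The first stage has 2**startwidth filters; each deeper stage doubles them.
    _, height, width, _ = inputs.shape  # (batch, H, W, C)
    height, width = int(height), int(width)
    s = Lambda(lambda x: x / 255)(inputs)

    # Deepest level allowed by the shorter side.
    Nmax = int(math.log2(min(height, width)))
    depth = Nmax - startwidth
    # MaxPooling2D floors odd sizes, which breaks the skip concatenation,
    # so cap the depth at the number of times both sides halve exactly.
    while height % 2**depth or width % 2**depth:
        depth -= 1

    x, skips = s, []
    for i in range(depth):
        x, c = CDCP(x, 2**(startwidth + i), 3, dropoutrate)  # 3x3 kernels
        skips.append(c)

    x = CDCP(x, 2**(startwidth + depth), 3, dropoutrate, False)  # bottleneck

    for i in reversed(range(depth)):
        # pair each decoder stage with the encoder stage of the same width
        x = Ulayer(x, skips[i], 2**(startwidth + i), 3, dropoutrate)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    return Model(inputs=[inputs], outputs=[outputs])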
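Both builders stack the same stages: encoder filters doubling from 2**startwidth, a bottleneck at the next power of two, and a decoder that mirrors the encoder. The plain-Python sketch below walks that layout and counts trainable weights, using k*k*in*out + out for a k x k Conv2D and 2*2*in*out + out for the 2x2 Conv2DTranspose. It assumes 3x3 kernels, an RGB input and the CDCP/Ulayer layout above (Lambda, Dropout and MaxPooling2D carry no weights); `conv_params` and `unet_param_count` are hypothetical helpers. With depth=4 it describes the layout of genUnet1, so the result is a rough cross-check for the total that model.summary() prints.

```python
def conv_params(k, cin, cout):
    # weights of a k x k convolution plus one bias per output filter
    return k * k * cin * cout + cout

def unet_param_count(depth, startwidth=4, in_ch=3, k=3):
    """Count weights of a U-Net built from CDCP/Ulayer stages.

    Encoder filters: 2**startwidth ... 2**(startwidth+depth-1);
    bottleneck: 2**(startwidth+depth); the decoder mirrors the encoder.
    """
    filters = [2 ** (startwidth + i) for i in range(depth + 1)]
    total, cin = 0, in_ch
    for f in filters:                       # encoder stages + bottleneck (CDCP)
        total += conv_params(k, cin, f) + conv_params(k, f, f)
        cin = f
    for f in reversed(filters[:-1]):        # decoder stages (Ulayer)
        total += conv_params(2, cin, f)     # 2x2 Conv2DTranspose
        total += conv_params(k, 2 * f, f)   # conv over [upsampled, skip] concat
        total += conv_params(k, f, f)
        cin = f
    total += conv_params(1, cin, 1)         # final 1x1 sigmoid conv
    return total

print(unet_param_count(4))  # 1941105
```

The concat convolution sees 2*f input channels because the upsampled tensor and the skip tensor both have f channels.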

Compile

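compile
# mean_iou is the custom metric defined in the Kaggle kernel linked above.
model = genUnet(inputs, 0.1)  # 0.1 = dropout rate for every stage
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[mean_iou])
model.summary()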
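mean_iou is not defined in this article; the compile call takes it from the Kaggle kernel. As an illustration of what an IoU metric measures on a sigmoid output, here is a stdlib-only sketch that binarizes the predicted probabilities at thresholds 0.5, 0.55, ..., 0.95 (an assumption, loosely modeled on threshold-averaged competition scoring) and averages the foreground intersection over union. `iou` and `mean_iou_sketch` are hypothetical names; this is not the kernel's code, which operates on tensors.

```python
def iou(truth, pred):
    """Intersection over union of two binary masks given as flat 0/1 lists."""
    inter = sum(1 for t, p in zip(truth, pred) if t and p)
    union = sum(1 for t, p in zip(truth, pred) if t or p)
    return 1.0 if union == 0 else inter / union  # two empty masks count as a match

def mean_iou_sketch(truth, probs):
    """Average foreground IoU over thresholds 0.5, 0.55, ..., 0.95."""
    thresholds = [0.5 + 0.05 * k for k in range(10)]
    scores = [iou(truth, [p > t for p in probs]) for t in thresholds]
    return sum(scores) / len(scores)

truth = [1, 1, 0, 0]
probs = [0.97, 0.72, 0.30, 0.10]
print(mean_iou_sketch(truth, probs))  # 0.75
```

Here the 0.72 pixel counts as foreground for the five thresholds up to 0.7 (IoU 1.0) and is lost for the other five (IoU 0.5), giving 0.75.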
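fit
from keras.callbacks import EarlyStopping, ModelCheckpoint

# X_train / Y_train: images and masks prepared as in the Kaggle kernel
earlystopper = EarlyStopping(patience=5, verbose=1)  # stop after 5 epochs without val_loss improvement
checkpointer = ModelCheckpoint('model-unet.h5', verbose=1, save_best_only=True)  # keep only the best model
results = model.fit(X_train, Y_train, validation_split=0.1, batch_size=16, epochs=50,
                    callbacks=[earlystopper, checkpointer])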
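With these settings both callbacks watch val_loss, computed on the last 10% of X_train/Y_train that validation_split=0.1 holds out. ModelCheckpoint with save_best_only=True overwrites model-unet.h5 only when val_loss improves, and EarlyStopping(patience=5) ends training once val_loss has not improved for 5 consecutive epochs, or at epochs=50 at the latest. The stdlib simulation below replays that logic on a made-up list of validation losses; `simulate_callbacks` is hypothetical and ignores options such as min_delta and restore_best_weights, which stay at their defaults here.

```python
def simulate_callbacks(val_losses, patience=5, max_epochs=50):
    """Replay save_best_only checkpointing and early stopping on val_loss values.

    Returns (0-based epochs at which the checkpoint is written, epochs run).
    """
    best = float('inf')
    wait = 0
    saved = []
    for epoch, loss in enumerate(val_losses[:max_epochs]):
        if loss < best:              # improvement: the checkpoint file is overwritten
            best, wait = loss, 0
            saved.append(epoch)
        else:
            wait += 1
            if wait >= patience:     # `patience` epochs in a row without improvement
                return saved, epoch + 1
    return saved, min(len(val_losses), max_epochs)

losses = [0.50, 0.40, 0.45, 0.42, 0.41, 0.43, 0.44, 0.30]
print(simulate_callbacks(losses))  # ([0, 1], 7): training stops before the 0.30 epoch
```

The saved file therefore holds the weights from the epoch with the lowest val_loss seen before stopping, not the last epoch trained.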

Results
https://gist.github.com/xiangze/d42d7c8e5ae967fba388d5c639365f3a

It seems to be working properly.
