16
16

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

U-Netで異常検知!

Last updated at Posted at 2019-10-01

kaggleなどのコンペでは異常検知の問題にU-Netが多く使われています。そこで自分もU-Netを異常検知の分野で実装してみました。

使用したデータ

MVtec-adのmetal-nutを使用しました。
このデータセットのテスト用ディレクトリには70枚の異常画像が含まれていますが、そのうち8割をトレーニング用画像として利用しました。つまり教師あり学習となります。(MVtec-adは元々、教師なし学習のためのデータセット)

U-Netとは

U-Netとは画像セグメンテーションのために開発されたFCNモデルです。U-Netはセグメンテーションが得意なので異常箇所の可視化が簡単にできることなどがメリットとしてあります。もともと医療画像用に提案されたモデルですが、それを異常検知の分野に応用させています。
U-Netの構造は以下のようになっています。
image.png
kerasでの実装例がこちらです。


def build_model(input_shape):
    inputs = Input(input_shape)

    c1 = Conv2D(8, (3, 3), activation='elu', padding='same') (inputs)
    c1 = Conv2D(8, (3, 3), activation='elu', padding='same') (c1)
    p1 = MaxPooling2D((2, 2)) (c1)

    c2 = Conv2D(16, (3, 3), activation='elu', padding='same') (p1)
    c2 = Conv2D(16, (3, 3), activation='elu', padding='same') (c2)
    p2 = MaxPooling2D((2, 2)) (c2)

    c3 = Conv2D(32, (3, 3), activation='elu', padding='same') (p2)
    c3 = Conv2D(32, (3, 3), activation='elu', padding='same') (c3)
    p3 = MaxPooling2D((2, 2)) (c3)

    c4 = Conv2D(64, (3, 3), activation='elu', padding='same') (p3)
    c4 = Conv2D(64, (3, 3), activation='elu', padding='same') (c4)
    p4 = MaxPooling2D(pool_size=(2, 2)) (c4)

    c5 = Conv2D(64, (3, 3), activation='elu', padding='same') (p4)
    c5 = Conv2D(64, (3, 3), activation='elu', padding='same') (c5)
    p5 = MaxPooling2D(pool_size=(2, 2)) (c5)

    c55 = Conv2D(128, (3, 3), activation='elu', padding='same') (p5)
    c55 = Conv2D(128, (3, 3), activation='elu', padding='same') (c55)

    u6 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same') (c55)
    u6 = concatenate([u6, c5])
    c6 = Conv2D(64, (3, 3), activation='elu', padding='same') (u6)
    c6 = Conv2D(64, (3, 3), activation='elu', padding='same') (c6)

    u71 = Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same') (c6)
    u71 = concatenate([u71, c4])
    c71 = Conv2D(32, (3, 3), activation='elu', padding='same') (u71)
    c61 = Conv2D(32, (3, 3), activation='elu', padding='same') (c71)

    u7 = Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same') (c61)
    u7 = concatenate([u7, c3])
    c7 = Conv2D(32, (3, 3), activation='elu', padding='same') (u7)
    c7 = Conv2D(32, (3, 3), activation='elu', padding='same') (c7)

    u8 = Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same') (c7)
    u8 = concatenate([u8, c2])
    c8 = Conv2D(16, (3, 3), activation='elu', padding='same') (u8)
    c8 = Conv2D(16, (3, 3), activation='elu', padding='same') (c8)

    u9 = Conv2DTranspose(8, (2, 2), strides=(2, 2), padding='same') (c8)
    u9 = concatenate([u9, c1], axis=3)
    c9 = Conv2D(8, (3, 3), activation='elu', padding='same') (u9)
    c9 = Conv2D(8, (3, 3), activation='elu', padding='same') (c9)

    outputs = Conv2D(1, (1, 1), activation='sigmoid') (c9)

    model = Model(inputs=[inputs], outputs=[outputs])
    model.compile(optimizer='adam', loss=bce_dice_loss, metrics=[dice_coef])
    
    return model

###損失関数
損失関数にはDISE係数を改良したbce-dise-lossを用いました。このサイトで詳しく解説されています。
DISE係数とは画像の類似度を測る関数です。
image.png
これとbce(binary cross entropy)を合わせたものがbce-dise-lossです。

##結果
トレーニング画像112枚で学習させた結果が下の図です。
上段が元画像、中段が予測画像、下段が教師データに使用したマスク画像です。
image.png
二枚目のように誤検知している部分もありますが、ある程度うまく検出できています。
元画像に予測したマスクをかけてみると下のようになります。
image.png
うまく可視化できています。

##まとめ
U-Netを使って異常検知を行うことができた。
今回は画像水増しなどを行っていない。画像水増しを行うことで誤検知の抑制やより少ない教師データでの学習が可能になることが期待できる。
次は転移学習を利用したU-Netを試してみたい。

##最後に
今回がQiita初投稿です。記事の改善すべきところなど、ご指摘頂ければ幸いです。

16
16
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
16

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?