LoginSignup
1
2

More than 3 years have passed since last update.

KerasのSpatialDropout2Dの動作具体例

Posted at

概要

SpatialDropout2Dというのがあるので画像認識で利用を検討したのですが、字面からイメージしたのと違った動作でした。
先行記事もありますが、画像で使う場合の具体例が出ていないので、こちらでまとめます。

具体例

この画像を入力したとして、

colorbar.png

これが普通のDropout

dropout.png

これがSpatialDropout2Dです。画像全体で特定のチャンネルをDropする動作です。

spatial_dropout.png

私が[Spatial]という字面から想像したのはこういうのでした…。これはdropoutのnoise_shape引数をうまく与えれば実現することができます。

point_dropout.png

再現用コード

check_dropout.py
import numpy as np
import cv2
from tensorflow.keras.layers import *
from tensorflow.keras import backend as K

src = cv2.imread("colorbar.png")
height, width = src.shape[:2]

def point_dropout_function(x, drop_rate):
    s = K.shape(x)
    return K.dropout(x, drop_rate, (s[0], s[1], s[2], 1))

K.set_learning_phase(1)  # 実行時にDropout有効化するために必要
x = Input((height, width, 3))
y_dropout = Dropout(0.2)(x)
y_spatial = SpatialDropout2D(0.2)(x)
y_point = Lambda(lambda x: point_dropout_function(x, 0.2))(x)
f = K.function([x], [y_dropout, y_spatial, y_point])

for i in range(10):
    dst_dropout, dst_spatial, dst_point = f([src[None, :, :, :]])

    cv2.imshow("src", src)
    cv2.imshow("dst_dropout", dst_dropout[0])
    cv2.imshow("dst_spatial", dst_spatial[0])
    cv2.imshow("dst_point", dst_point[0])
    cv2.waitKey()

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2