38
22

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

深層学習/im2colの実装の工夫に驚いた件

Last updated at Posted at 2020-02-06

#1.はじめに
 im2col4次元配列 (ミニバッチサイズ、チャンネル数、縦幅、横幅)を、行列 (画像枚数*畳み込み回数、フィルターの要素数)に変換し、フィルター行列(フィルターの要素数、フィルターの枚数)に変換することで、畳み込み演算をただの行列積の計算にし、処理速度を劇的に改善するアルゴリズムです。

 im2colは有名なアルゴリズムで、ゼロから作る Deep Learningリポジトリ(common/util.py)に載っているので今回コードを見たところ、実装の工夫に驚いたので、備忘録として残します。

#2.im2col関数

def im2col(input_data, filter_h, filter_w, stride_h=1, stride_w=1, pad_h=0, pad_w=0):

    N, C, H, W = input_data.shape
    out_h = (H + 2*pad_h - filter_h)//stride_h + 1
    out_w = (W + 2*pad_w - filter_w)//stride_w + 1
    
    img = np.pad(input_data, [(0,0), (0,0), (pad_h, pad_h), (pad_w, pad_w)], 'constant')
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

    for y in range(filter_h):
        y_max = y + stride_h*out_h
        for x in range(filter_w):
            x_max = x + stride_w*out_w
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride_h, x:x_max:stride_w]

    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
    return col, out_h, out_w

 ゼロから作る Deep Learning のリポジトリにある、im2col関数です(一部表現を補足しています)。
 この関数の原理を具体的な例で説明します。input_dataのシェイプは(1,1,4,4)、フィルターは(2,2)、ストライド(1,1)、パディングはなし、という設定にすると、

スクリーンショット 2020-02-06 15.23.07.png
 スライシングを9回行うと、9行×4列の行列が出来上がりますので、これにフィルターの行列との内積をとれば、一気に畳み込み計算が完了となるわけです。

 ただし、実装はこれをそのまま行っておらず驚きの工夫がされています。それをこれから見て行きます。

def im2col(input_data, filter_h, filter_w, stride_h=1, stride_w=1, pad_h=0, pad_w=0):

    # input_dataから、バッチサイズN, チャンネル数C, 画像高さH, 画像幅W を取得
    N, C, H, W = input_data.shape

    # 畳み込み処理後の出力の高さ out_h, 幅 out_w を計算
    out_h = (H + 2*pad_h - filter_h)//stride_h + 1
    out_w = (W + 2*pad_w - filter_w)//stride_w + 1
    
    # 画像のパディング(高さ方向 pad_h, 幅方向 pad_w)
    img = np.pad(input_data, [(0,0), (0,0), (pad_h, pad_h), (pad_w, pad_w)], 'constant')

    # 計算結果を保存するためのゼロ行列の作成
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

 コードの冒頭部分です。np.padの詳細を知りたい場合はこちらを参照して下さい。後は、特に目を引く部分はないと思います。問題は、この後です。

    for y in range(filter_h):

        # 縦方向のスライシング範囲の最大値y_maxを求める
        y_max = y + stride_h*out_h

        for x in range(filter_w):

            # 横方向のスライシング範囲の最大値x_maxを求める
            x_max = x + stride_w*out_w

            # スライシング結果をゼロ行列に格納
            # y から y_max まで stride_h 間隔でスライシング
            # x から x_max まで stride_w 間隔でスライシング
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride_h, x:x_max:stride_w]
    
    # col の各行列の要素を並べ替え、リシェイプ
    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
    return col, out_h, out_w

 先程同様、input_dataのシェイプは(1,1,4,4)、フィルターは(2,2)、ストライド(1,1)、パディングはなし、という設定とすると、ここはどういう動きになるかと言うと、

 filter_h = 2, filter_w = 2 なので、2重のforループは4回しか回りません。そして、スライシングの範囲は y:y+filter_h:stride_h, x:x+filter_w:stride_w ではなく、もっと範囲は大きくy:y_max:stride_h, x:x_max:stride_wなのです。

 一体何をやっているか簡単に言うと、
スクリーンショット 2020-02-06 16.15.06.png
 何これ、凄い!!! 9回のスライシングが4回になっています。処理時間の掛かる for文 を出来るだけ少なくする工夫がされているわけです。

 ストライドが2だとおかしくならない?とか思いますよね。試しにやってみますね。

スクリーンショット 2020-02-06 16.18.43.png
 凄いです。 ちゃんとやれてます!

 つまり、ファイルターの要素数分だけ forループを回せば、必要な計算が出来てしまうわけです。これ、ピクセルサイズが大きくなると絶大な効果を発揮することは、簡単にお分かり頂けると思います。
 
 なお、コードの**col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)**の詳細については、こちらを参照して下さい。

 im2colのアルゴリズムを聞いた時、それは凄いアイディアだと驚いたのですが、その実装を見て再び驚きました。

 最後に、畳み込み演算全体のサンプルコードを書いておきます。

import numpy as np

def im2col(input_data, filter_h, filter_w, stride_h=1, stride_w=1, pad_h=0, pad_w=0):

    N, C, H, W = input_data.shape
    out_h = (H + 2*pad_h - filter_h)//stride_h + 1
    out_w = (W + 2*pad_w - filter_w)//stride_w + 1
    
    img = np.pad(input_data, [(0,0), (0,0), (pad_h, pad_h), (pad_w, pad_w)], 'constant')
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

    for y in range(filter_h):
        y_max = y + stride_h*out_h
        for x in range(filter_w):
            x_max = x + stride_w*out_w
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride_h, x:x_max:stride_w]

    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
    return col, out_h, out_w 


# 畳み込み演算する画像 ( img.shape = (1, 1, 4, 4) )
img =np.array([[[
    [0, 1, 2, 3],
    [4, 5, 6, 7],
    [8, 9, 0, 1],
    [2, 3, 4, 5],
]]])

# フィルター画像 ( kernel.shape = (2, 2) )
kernel = np.array([
    [1, 0],
    [0, 1]
])

# im2col関数を使って、画像をスライシング
img_col, out_h, out_w= im2col(img, 2, 2, 1, 1, 0, 0)

# フィルター個数=1
k_n = 1 

# フィルターを転置
kernel_col = kernel.reshape(k_n, -1).T 

# 画像のスライシング結果と転置したフィルターの内積を取る
conv = np.dot(img_col, kernel_col)

# 内積結果の行列を整形
conv = conv.reshape(img.shape[0], out_h, out_w, -1).transpose(0, 3, 2, 1)
print(conv)


# [[[[ 5. 13. 11.]
#    [ 7.  5. 13.]
#    [ 9.  7.  5.]]]]

 画像は、im2col関数で展開しましたが、フィルターの方は、kernel_col = kernel.reshape(k_n, -1).T だけでOKです。

38
22
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
38
22

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?