#1.はじめに
im2col は 4次元配列 (ミニバッチサイズ、チャンネル数、縦幅、横幅)を、行列 (画像枚数*畳み込み回数、フィルターの要素数)に変換し、フィルターも行列(フィルターの要素数、フィルターの枚数)に変換することで、畳み込み演算をただの行列積の計算にし、処理速度を劇的に改善するアルゴリズムです。
im2colは有名なアルゴリズムで、ゼロから作る Deep Learning のリポジトリ
(common/util.py)に載っているので今回コードを見たところ、実装の工夫に驚いたので、備忘録として残します。
#2.im2col関数
def im2col(input_data, filter_h, filter_w, stride_h=1, stride_w=1, pad_h=0, pad_w=0):
N, C, H, W = input_data.shape
out_h = (H + 2*pad_h - filter_h)//stride_h + 1
out_w = (W + 2*pad_w - filter_w)//stride_w + 1
img = np.pad(input_data, [(0,0), (0,0), (pad_h, pad_h), (pad_w, pad_w)], 'constant')
col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))
for y in range(filter_h):
y_max = y + stride_h*out_h
for x in range(filter_w):
x_max = x + stride_w*out_w
col[:, :, y, x, :, :] = img[:, :, y:y_max:stride_h, x:x_max:stride_w]
col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
return col, out_h, out_w
ゼロから作る Deep Learning のリポジトリにある、im2col関数です(一部表現を補足しています)。
この関数の原理を具体的な例で説明します。input_dataのシェイプは(1,1,4,4)、フィルターは(2,2)、ストライド(1,1)、パディングはなし、という設定にすると、
スライシングを9回行うと、9行×4列の行列が出来上がりますので、これにフィルターの行列との内積をとれば、一気に畳み込み計算が完了となるわけです。
ただし、実装はこれをそのまま行っておらず驚きの工夫がされています。それをこれから見て行きます。
def im2col(input_data, filter_h, filter_w, stride_h=1, stride_w=1, pad_h=0, pad_w=0):
# input_dataから、バッチサイズN, チャンネル数C, 画像高さH, 画像幅W を取得
N, C, H, W = input_data.shape
# 畳み込み処理後の出力の高さ out_h, 幅 out_w を計算
out_h = (H + 2*pad_h - filter_h)//stride_h + 1
out_w = (W + 2*pad_w - filter_w)//stride_w + 1
# 画像のパディング(高さ方向 pad_h, 幅方向 pad_w)
img = np.pad(input_data, [(0,0), (0,0), (pad_h, pad_h), (pad_w, pad_w)], 'constant')
# 計算結果を保存するためのゼロ行列の作成
col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))
コードの冒頭部分です。np.padの詳細を知りたい場合はこちらを参照
して下さい。後は、特に目を引く部分はないと思います。問題は、この後です。
for y in range(filter_h):
# 縦方向のスライシング範囲の最大値y_maxを求める
y_max = y + stride_h*out_h
for x in range(filter_w):
# 横方向のスライシング範囲の最大値x_maxを求める
x_max = x + stride_w*out_w
# スライシング結果をゼロ行列に格納
# y から y_max まで stride_h 間隔でスライシング
# x から x_max まで stride_w 間隔でスライシング
col[:, :, y, x, :, :] = img[:, :, y:y_max:stride_h, x:x_max:stride_w]
# col の各行列の要素を並べ替え、リシェイプ
col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
return col, out_h, out_w
先程同様、input_dataのシェイプは(1,1,4,4)、フィルターは(2,2)、ストライド(1,1)、パディングはなし、という設定とすると、ここはどういう動きになるかと言うと、
filter_h = 2, filter_w = 2 なので、2重のforループは4回しか回りません。そして、スライシングの範囲は y:y+filter_h:stride_h, x:x+filter_w:stride_w
ではなく、もっと範囲は大きくy:y_max:stride_h, x:x_max:stride_w
なのです。
一体何をやっているか簡単に言うと、
何これ、凄い!!! 9回のスライシングが4回になっています。処理時間の掛かる for文 を出来るだけ少なくする工夫がされているわけです。
ストライドが2だとおかしくならない?とか思いますよね。試しにやってみますね。
つまり、ファイルターの要素数分だけ forループを回せば、必要な計算が出来てしまうわけです。これ、ピクセルサイズが大きくなると絶大な効果を発揮することは、簡単にお分かり頂けると思います。
なお、コードの**col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
**の詳細については、こちらを参照して下さい。
im2colのアルゴリズムを聞いた時、それは凄いアイディアだと驚いたのですが、その実装を見て再び驚きました。
最後に、畳み込み演算全体のサンプルコードを書いておきます。
import numpy as np
def im2col(input_data, filter_h, filter_w, stride_h=1, stride_w=1, pad_h=0, pad_w=0):
N, C, H, W = input_data.shape
out_h = (H + 2*pad_h - filter_h)//stride_h + 1
out_w = (W + 2*pad_w - filter_w)//stride_w + 1
img = np.pad(input_data, [(0,0), (0,0), (pad_h, pad_h), (pad_w, pad_w)], 'constant')
col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))
for y in range(filter_h):
y_max = y + stride_h*out_h
for x in range(filter_w):
x_max = x + stride_w*out_w
col[:, :, y, x, :, :] = img[:, :, y:y_max:stride_h, x:x_max:stride_w]
col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
return col, out_h, out_w
# 畳み込み演算する画像 ( img.shape = (1, 1, 4, 4) )
img =np.array([[[
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 0, 1],
[2, 3, 4, 5],
]]])
# フィルター画像 ( kernel.shape = (2, 2) )
kernel = np.array([
[1, 0],
[0, 1]
])
# im2col関数を使って、画像をスライシング
img_col, out_h, out_w= im2col(img, 2, 2, 1, 1, 0, 0)
# フィルター個数=1
k_n = 1
# フィルターを転置
kernel_col = kernel.reshape(k_n, -1).T
# 画像のスライシング結果と転置したフィルターの内積を取る
conv = np.dot(img_col, kernel_col)
# 内積結果の行列を整形
conv = conv.reshape(img.shape[0], out_h, out_w, -1).transpose(0, 3, 2, 1)
print(conv)
# [[[[ 5. 13. 11.]
# [ 7. 5. 13.]
# [ 9. 7. 5.]]]]
画像は、im2col関数で展開しましたが、フィルターの方は、kernel_col = kernel.reshape(k_n, -1).T
だけでOKです。