Posted at

im2col関数の理解

書籍「ゼロから作るディープラーニング」をもとに

畳み込みニューラルネットワーク(CNN、Convolutional Neural Network)を勉強しているのですが、途中に出てくるim2col関数というものについて理解が難しかったため、

自分なりの噛み砕きの経緯を書いてみました。

もしもどなたかの参考になれば幸いです。


im2col関数について

CNNの畳み込み演算において複雑なループ処理を避けるため、

フィルター適用領域ごとに一列のデータになるよう変換する関数です。

入力データとフィルターにこの関数を適用することで、

行列のドット演算で一気に畳み込み演算が行えます。


オリジナルの実装

「ゼロから作る〜」で紹介されている実装はこのようなものです。


def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
N, C, H, W = input_data.shape
out_h = (H + 2*pad - filter_h)//stride + 1
out_w = (W + 2*pad - filter_w)//stride + 1

img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

for y in range(filter_h):
y_max = y + stride*out_h
for x in range(filter_w):
x_max = x + stride*out_w
col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
return col

書籍の説明を見ると目指すところは理解できるのですが、

このループ処理部分で行っている事がどうにも理解できませんでした。


引っかかった点



  • y_maxx_maxが何を意味しているのかよくわからない

  • imgのスライスで何を取り出しているのかよくわからない

    y:y_max:stride,x:x_max:stride って何?

  • フィルターサイズについてのループだけで処理できる理由がよくわからない

そこで、まずは自分なりに素朴な発想で実装してみました。


素朴な実装

フィルターを移動させるループ

フィルタ内の各画素をコピーするループ

の順でx方向・y方向、4重ループ処理すればいけるはず...

という発想でforループ部分のみ変更してみたのが以下のim2col_slowです。

def im2col_slow(input_data, filter_h, filter_w, stride=1, pad=0):

N, C, H, W = input_data.shape
out_h = (H + 2*pad - filter_h)//stride + 1
out_w = (W + 2*pad - filter_w)//stride + 1

img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

for move_y in range(out_h):
for move_x in range(out_w):
for y in range(filter_h):
for x in range(filter_w):
col[:, :, y, x, move_y, move_x] = \
img[:, :, y + stride * move_y, x + stride * move_x]

col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
return col

実行してみます。

(見やすくするためにデータ数1、チャンネル数1に絞ってます)

data = np.random.rand(1, 1, 7, 7) * 100 // 1

print('========== input ==========\n', data)
print('=====================')
filter_h = 3
filter_w = 3
stride = 2
pad = 0
col = im2col(data, filter_h=filter_h, filter_w=filter_w, stride=stride, pad=pad)
col2 = im2col_slow(data, filter_h=filter_h, filter_w=filter_w, stride=stride, pad=pad)
print('========== col ==========\n', col)
print('=====================')
print('========== col2 ==========\n', col2)
print('=====================')

同様の結果を得られました。

========== input ==========

[[[[30. 91. 11. 13. 52. 44. 98.]
[99. 6. 35. 41. 97. 72. 79.]
[ 5. 92. 15. 95. 72. 8. 10.]
[68. 5. 86. 25. 69. 46. 70.]
[95. 32. 98. 49. 51. 19. 46.]
[32. 15. 39. 44. 76. 58. 49.]
[43. 47. 95. 1. 1. 12. 21.]]]]
=====================
========== col ==========
[[30. 91. 11. 99. 6. 35. 5. 92. 15.]
[11. 13. 52. 35. 41. 97. 15. 95. 72.]
[52. 44. 98. 97. 72. 79. 72. 8. 10.]
[ 5. 92. 15. 68. 5. 86. 95. 32. 98.]
[15. 95. 72. 86. 25. 69. 98. 49. 51.]
[72. 8. 10. 69. 46. 70. 51. 19. 46.]
[95. 32. 98. 32. 15. 39. 43. 47. 95.]
[98. 49. 51. 39. 44. 76. 95. 1. 1.]
[51. 19. 46. 76. 58. 49. 1. 12. 21.]]
=====================
========== col2 ==========
[[30. 91. 11. 99. 6. 35. 5. 92. 15.]
[11. 13. 52. 35. 41. 97. 15. 95. 72.]
[52. 44. 98. 97. 72. 79. 72. 8. 10.]
[ 5. 92. 15. 68. 5. 86. 95. 32. 98.]
[15. 95. 72. 86. 25. 69. 98. 49. 51.]
[72. 8. 10. 69. 46. 70. 51. 19. 46.]
[95. 32. 98. 32. 15. 39. 43. 47. 95.]
[98. 49. 51. 39. 44. 76. 95. 1. 1.]
[51. 19. 46. 76. 58. 49. 1. 12. 21.]]
=====================


オリジナルとの比較

自分で実装してみたおかげで、オリジナルの実装は

上記の素朴版で行っているフィルター移動分の2重ループをしなくて済むよう効率化したものだと気づきました。

(ストライド幅刻みでスライスすることで、フィルタを移動させて取得する分を一気に取得・コピーしている)

絵にしてみるとこんな感じになるかと思います。


素朴版の実装におけるimgからcolへのコピー

im2col_img_slow.gif


オリジナルim2colにおけるimgからcolへのコピー

im2col_img_org.gif

ループ処理が効率化された結果、あのような実装になっているのだなぁ...

という感じで自分の理解は落ち着きました。