ディープラーニングにおける畳み込み層に関するメモです。「ゼロから作るDeep Learning」 は CNN の解説においても秀逸ですが、CNN の実装で用いられている im2col 関数は少し複雑で理解が難しいところがあったので、備忘録メモを残したいと思いました。
im2col 自体の実装は以下の記事に詳しいのですが、ここでは im2col の仕様と利用方法について追いたいと思います。
深層学習/im2colの実装の工夫に驚いた件
numpy備忘録2/transposeは行と列を入れ替えるだけじゃない
プログラムコードの難解な理由は、Numpy の多重配列に関して、ダイナミックに reshape と transpose を適用してデータ構造を変形させているところにあります。理解の一助としてここでは shape の変遷を追っていきたいと思います。抽象的な Haskell プログラムを読むためには 型 を追うことが必要なように、ここでは shape を追うことで理解していきたいと思います。
【関連記事】
MNIST手書き数字のCNN画像認識 - Qiita
CNN 畳み込み層のメモ - Qiita
Softmax+CrossEntropy の実装 - Qiita
機械学習のウォーミングアップ(Numpy) - Qiita
1. numpy.reshape
reshape は多重配列の shape を変更するために用いられますが、order を意識した変形なので、不可逆的でなく可逆的な操作といえるでしょう。
numpy.reshape(a, newshape, order='C')
指定した順番に配列の要素を読み込み、並び替えていきます。’C’は、読み込みと並び替えをC言語方式で行います。この場合、最後の軸を最も早くインデックスし、最初の軸を最も遅くインデックスします。(C 以外は略)
少し試してみます。
import numpy as np
x = np.array([[0, 1], [2, 3], [4,5]])
print(x.shape)
(3, 2)
y = x.reshape(2,3)
print(y.shape)
print(y)
(2, 3)
[[0 1 2]
[3 4 5]]
z = x.reshape(-1)
print(z.shape)
print(z)
w = z.reshape(3,2)
print(w.shape)
print(w)
(6,)
[0 1 2 3 4 5]
(3, 2)
[[0 1]
[2 3]
[4 5]]
2. im2col の仕様
「ゼロから作るDeep Learning」 では畳み込み層とプーリング層の処理に im2col が使われています。im2col の実装は本からは割愛されているのですけど、git にあるソースコードにはちゃんと載っています。ここでも、以下の im2col を前提として話を進めていきます。
im2col は4次元の入力データを2次元の行列に展開するものです。カーネルとの演算を簡単にします。
col = im2col(x, FH, FW, S, P)
この時 col.shape = (N*OH*OW, C*FH*FW)
但し、
S = stride, P = pad
N, C, H, W = x.shape
OH = (H + 2*P - FH) / S + 1
OW = (W + 2*P - FW) / S + 1
特に、S=1, P=0 の場合は以下のようになる。
OH = H - FH + 1
OW = W - FW + 1
確認してみます。
以下、S=1, P=0 の場合
◎ N=1, C=3, H=7, W=7, FH=5, FW=5
OH = 7-5+1 = 3
OW = 7-5+1 = 3
col.shape = (1x3x3, 3x5x5) = (9, 75)
◎ N=10, C=3, H=7, W=7, FH=5, FW=5
OH = 7-5+1 = 3
OW = 7-5+1 = 3
col.shape = (10x3x3, 3x5x5) = (90, 75)
これは「ゼロから作るDeep Learning」「7.4.3 Convolutionレイヤの実装」にある例題と一致しています。
3. 畳み込み層( Convolutionレイヤ)
畳み込み層とプーリング層で定義された model の処理を理解するためには、入力の shape と各層から出力される shape を追う必要があります。
畳み込み層では、(N, C, H, W) の入力に対して、出力が (N, FN, OH, OW) となることを確認します。入力も出力も 4次元データ です。
以下が 畳み込み処理 です。
N, C, H, W = x.shape # x: 入力データ、4次元
FN, C, FH, FW = W.shape # W: フィルター(カーネル)
OH = (H + 2*P - FH) / S + 1 # OH = (H - FH) + 1 if S=1, P=0
OW = (W + 2*P - FW) / S + 1 # OW = (W - FW) + 1 if S=1, P=0
col = im2col(x, FH, FW, S, P) # col.shape = (N*OH*OW, C*FH*FW) (1)
col_w0 = self.W.reshape(FN,-1) # col_w0.shape = (FN, C*FH*FW)
col_w = col_w0.T # col_w.shape = (C*FH*FW, FN) (2)
out = np.dot(col, col_W) + b # out.shape = (N*OH*OW, FN) (3)
out = out.reshape(N, OH, OW, -1) # out.shape = (N, OH, OW, FN)
out = transpose(0, 3, 1, 2) # out.shape = (N, FN, OH, OW)
(3) では、(1)の 行列 (N*OH*OW, C*FH*FW) と(2)の 行列 (C*FH*FW, FN) の掛け算を行うことで、 一度で畳み込み処理を行っています。(2)の C*FH*FW はチャンネルを含めたカーネル行列要素に相当していることに注意してください。
特に、S=1, P=1, FH=FW=3 の場合、画像サイズは変わらない。
OH = H
OW = W
4. プーリング層(Pooling レイヤ)
プーリング層でも、(N, C, H, W) の入力に対して、出力が (N, C, OH, OW) となることを確認します。入力も出力も 4次元データ です。また P = 0 です。
以下が プーリング処理 です。
N, C, H, W = x.shape # x: 入力データ、4次元
PH, PW # プーリング領域の縦、横
OH = 1 + (H - PH) / S # OH = H/2 if PH = S = 2
OW = 1 + (W - PW) / S # OW = W/2 if PW = S = 2
col = im2col(x, PH, PW, 1, 0) # col.shape = (N*OH*OW, C*PH*PW)
col = col.reshape(-1, PH*PW) # col.shape = (N*OH*OW*C, PH*PW) チャンネル独立
out = np.max(col, axis=1) # out.shape = (N*OH*OW*C, 1)
out = out.reshape(N, OH, OW, C) # out.shape = (N, OH, OW, C)
out = out.transpose(0, 3, 1, 2) # out.shape = (N, C, OH, OW)
畳み込み層の output をプーリング層 の input として、変数を共通化すれば、プーリング層の出力は以下のように表記できます。
N, FN, OH, OW = x.shape # x = 畳み込み層の output
PH, PW # プーリング領域の縦、横
OHP = 1 + (OH - PH) / S # OH2 = OH/2 if PH = S = 2
OWP = 1 + (OW - PW) / S # OW2 = OW/2 if PW = S = 2
out.shape = (N, FN, OHP, OWP) # (N, FN, OH/2, OW/2) if PH = PW = S = 2
特に、 PH = PW = S = 2 の場合は以下の通り
out.shape = (N, FN, OH/2, OW/2)
(N, FN, OH2, OW2) の4次元の output は次のAffine層に流される前に、2次元データ (N, FN*OH2*OW2) に変形される必要があります。
今回は以上です