CNN畳み込み層の逆伝搬
前提
- 入力画像または特徴マップを以下のように定義する。
\begin{align}
X := \left[\begin{array}{c}
x(1, 1) \dots x(1, W) \\
\vdots \\
x(H, 1) \dots x(H, W)
\end{array}\right]
\end{align}
- 畳み込んだフィルタ(カーネル)を以下のように定義する。
\begin{align}
W := \left[\begin{array}{c}
\omega(1, 1) \dots \omega(1, w) \\
\vdots \\
\omega(h, 1) \dots \omega(h, w)
\end{array}\right]
\end{align}
- 入力とフィルタの畳み込み演算結果の各要素にバイアス$b$を加算した結果を以下のように定義する。
なお、以降、畳み込み演算子は"$*$"と表記する。
また、$B$はすべての要素の値が$b$である$(H - h) \times (W - w)$の行列を表す。
\begin{align}
U := X * W + B = \left[\begin{array}{c}
u(1, 1) \dots u(1, W - w) \\
\vdots \\
u(H - h, 1) \dots u(H - h, W - w)
\end{array}\right]
\end{align}
- 活性化関数$f$をとし、畳み込み層の出力$Y$を以下のように表す。
\begin{align}
Y := f(U) = \left[\begin{array}{c}
f(u(1, 1)) \dots f(u(1, W - w)) \\
\vdots \\
f(u(H - h, 1)) \dots f(u(H - h, W - w))
\end{array}\right]
= \left[\begin{array}{c}
y(1, 1) \dots y(1, W - w) \\
\vdots \\
y(H - h, 1) \dots y(H - h, W - w)
\end{array}\right]
\end{align}
- なお,最終層の出力と教師信号との誤差関数は$E$で表す。
逆伝搬
入力画像のフィルタによる畳み込み演算をニューラルネットワークにモデル化した場合、「重み」はフィルタ行列の各要素$\omega(i,j)$に対応する。
したがって、畳み込み層における誤差逆伝播は$E$をフィルタの各要素について偏微分することに対応する。
偏微分の連鎖律を用いると任意の重み$\omega(m, n)$に対する誤差逆伝搬の結果は以下のように表せる。
\begin{align}
\frac{\partial E}{\partial \omega(m,n)}
&= \sum_{i=1}^{W-w}\sum_{j=1}^{H-h}\frac{\partial E}{\partial u(i,j)}
\cdot \frac{\partial u(i,j)}{\partial \omega(m, n)} \\
&= \sum_{i=1}^{W-w}\sum_{j=1}^{H-h}\frac{\partial E}{\partial u(i,j)}
\cdot x(m + i, n + j)
\end{align}
ここで、以下のパラメータ行列$\Delta$と誤差逆伝搬行列$dE$を定義すると、
\begin{align}
\Delta &:= \left[\begin{array}{c}
\delta(1,1) \cdots \delta(1,W-w) \\
\vdots \\
\delta(H-h,1) \cdots \delta(H-h, W-w)
\end{array}\right]
:= \left[\begin{array}{c}
\frac{\partial E}{u(1,1)} \cdots \frac{\partial E}{u(W-w,1)} \\
\vdots \\
\frac{\partial E}{u(H-h,1)} \cdots \frac{\partial E}{u(H-h,W-w)}
\end{array}\right] \\
dE &:= \left[\begin{array}{c}
\frac{\partial E}{\partial \omega(1,1)} \cdots \frac{\partial E}{\partial \omega(1,W-w)} \\
\vdots \\
\frac{\partial E}{\partial \omega(H-h,1)} \cdots \frac{\partial E}{\partial \omega(H-h,W-w)}
\end{array}\right]
\end{align}
誤差逆伝搬行列もまた、入力とパラメータ行列の畳み込み演算で求めることができる。
\begin{align}
dE = X * \Delta
\end{align}
すなわち、畳み込み層の逆伝搬もまた、畳み込み演算により表現できる。
ただし、上記パラメータ行列$\Delta$の算出にはもう一工夫必要である。
逆伝搬のパラメータ行列
前章で定義したパラメータ行列$\Delta$の任意の要素$\delta(i,j)$は以下となる。
\begin{align}
\delta(i,j) &= \frac{\partial E}{\partial u(i,j)}
\end{align}
ここで誤差関数$E$はCNNの最終層の出力の関数であり、これは$u(i,j)$で直接的に偏微分することができない。
そこで、次の層(順伝搬時に出力を渡す層)の誤差を「逆伝搬」するのである。
次層が畳み込み層である場合
パラメータ行列を求める層を$l$、次層を$l+1$とし、$l$層の$u(i,j)$と$l+1$層の$u^{l+1}(p,q)$の関係を考える。
畳み込み演算の性質から、$l$層の$u(i,j)$は、$l+1$層の$u^{l+1}(p,q)$の以下の範囲に反映される。
\begin{align}
u(i, j) {\rm \quad affects \, to \quad}
\left[\begin{array}{c}
u^{l+1}(i-(h-1), j-(w-1)) \cdots u^{l+1}(i-(h-1), j) \\
\vdots \\
u^{l+1}(i, j-(w-1)) \cdots u^{l+1}(i, j)
\end{array}\right]
\end{align}
ここで、$h,w$は$l+1$層のフィルタのサイズである。
上記から、$\delta(i,j)$は以下のように変形できる。
\begin{align}
\delta(i,j) &= \sum_{p=0}^{h-1}\sum_{q=0}^{w-1}\frac{\partial E}{\partial u^{l+1}(i-p,j-q)}
\cdot \frac{\partial u^{l+1}(i-p,j-q)}{\partial u(i,j)} \\
&= \sum_{p=0}^{h-1}\sum_{q=0}^{w-1}\frac{\partial E}{\partial u^{l+1}(i-p,j-q)}
\cdot (\frac{\partial u^{l+1}(i-p,j-q)}{\partial f(u(i,j))} \cdot \frac{\partial f(u(i,j))}{\partial u(i,j)}) \\
&= \sum_{p=0}^{h-1}\sum_{q=0}^{w-1}\frac{\partial E}{\partial u^{l+1}(i-p,j-q)}
\cdot (w^{l+1}(p,q) \cdot f'(u(i,j))) \\
&= \sum_{p=0}^{h-1}\sum_{q=0}^{w-1} \delta^{l+1}(i-p,j-q)
\cdot (w^{l+1}(p,q) \cdot f'(u(i,j)))
\end{align}
つまり、$l+1$層のパラメータ$\delta^{l+1}$を得ることで、$l$層のパラメータ$\delta$を計算するのである。
したがって、最終層の$\delta$が算出できればすべての層の$\delta$が算出できる。
次層が全結合層である場合
全結合層では、任意の入力ノードはすべての出力ノードに使用されるため、
\begin{align}
\delta(i,j) &= \sum_{k=0}^{K} \frac{\partial E}{\partial u^{l+1}(i,j,k)}
\cdot \frac{\partial u^{l+1}(i,j,k)}{\partial f(u(i,j)}
\cdot \frac{\partial f(u(i,j))}{\partial u(i,j)}) \\
&= \sum_{k=0}^{K} \frac{\partial E}{\partial u^{l+1}(i,j,k)}
\cdot w^{l+1}(i,j,k)
\cdot f'(u(i,j)) \\
&= \sum_{k=0}^{K} \delta^{l+1}(i,j,k)
\cdot w^{l+1}(i,j,k)
\cdot f'(u(i,j))
\end{align}
とできる。ここで、$K$は、出力ノード数を示す。
なお、$l$が最終層の場合は、$E$は出力$y$の関数となるので偏微分と連鎖律により$\delta$を算出可能である。
残された謎
-
パラメータ算出の際、$i - p, j - q$がゼロ以下のときや$i,j$が$l+1$層の出力行列の次元より大きい場合はどう扱うのか?
(この場合,存在しない$\delta^{l+1}$が数式に登場してしまう)
順伝搬時に影響を与えていないと考え、ゼロとする? -
プーリング層により圧縮を行う場合、パラメータをどう紐付ける?