はじめに
この記事はRNNの一つであるGRUの誤差逆伝播を計算グラフをもとに読み解くものです。噂によるとゼロつく2の巻末に載っているらしいのですが、僕みたいに手元に用意できない人へ向けて書きます。
GRUの概略
GRUはゲート付きRNNのうちの一つであり、ResetゲートとUpdateゲートの二つの構造を持つことが特徴です。同じくゲートつきRNNであるLSTMとの違いはCEC(記憶セル)が無いということにあります。このことにより、過去の情報が残りにくく、また、忘却と記憶がトレードオフになってしまうというデメリットを生じるものの、計算コストをLSTMより低くすることができます。
構造
GRUユニットは下図のようになります。
$$\begin{aligned}
z &= \sigma(x_t W_{xz} + h_{prev}W_{hz} + b_z) \\
r &= \sigma(x_t W_{xr} + h_{prev}W_{hr} + b_r) \\
\hat{h} &= \tanh (x_t W_{xh} + (r \odot h_{prev})W_{hh} + b_h) \\
h_{next} &= (1-z) \odot h_{prev} + z \odot \hat{h}
\end{aligned}$$
このユニット一つで学習することもできますし、並列に複数並べることもできます。また、双方向RNNのユニットとしても用いることができます。
計算グラフ
それではGRUユニットの計算グラフから誤差逆伝播を計算します。
求めたいのは時刻$t$における入力値$x_t$の勾配$dx$と時刻$t-1$からの入力$h_{prev}$の勾配$dh_{prev}$、さらに各全結合層における重み$W_{xr}, W_{xz}, W_{xh}$、$W_{hr}, W_{hz}, W_{hh}$の勾配$dW_{xr}, dW_{xz}, dW_{xh}$、$dW_{hr}, dW_{hz}, dW_{hh}$です。
まずは計算グラフの全貌です。グラフの中では簡略化のため二つのシグモイドノード($\sigma$)の出力を$r, z$、$tanh$の出力を$\hat{h}$とします。
全てをいっぺんに考えるのは大変なので後ろから順番に区切って計算していきます。
出力部の逆伝播
まずは$d\hat{h}$を求めます。
$$ \begin{equation} \begin{split}
d\hat{h} &= \frac{\partial L}{\partial h_{next}} \frac{\partial h_{next}}{\partial \hat{h}}\
&= dh_{next} \times z
\end{split} \end{equation}
$$
最初の+ノードの勾配は$dh_{next}$がそのまま流れるだけ、×ノードは反対側からの入力の掛け算になります。
また、この部分の$dh_{prev}$は以下のようになりますが、$h_{prev}$は$x$とともに複数個所の入力がありますので、最終的にはそれらの合計となります。そのため$dh_{prev}, dx$には番号を振って区別しておきます。
$$ \begin{equation} \begin{split}
dh_{prev1} &= \frac{\partial L}{\partial h_{next}} \frac{\partial h_{next}}{\partial h_{prev}}\
&= dh_{next} \times (1 - z)
\end{split} \end{equation} $$
tanhノードの逆伝播
次にtanh周りを計算します。$f(x) = \tanh x$の微分は、証明は省きますが以下の形になります。
$$ f'(x) = 1 - \tanh^2 x $$
簡略化のため$dt$という勾配を図の位置に導入します(この後何回か$dt$が出てきますが、全て別物です)。つまり$t$はtanhノードへの入力値です。
繰り返しますが、求めたいのは全結合層の重みの勾配と二つの入力値の勾配です。
$$\begin{equation}\begin{split}
dt &= \frac{\partial L}{\partial \hat{h}} \frac{\partial \hat{h}}{\partial t}\
&= d\hat{h} \times (1 - \tanh^2 t)\
&= d\hat{h} \times (1 - \hat{h}^2)
\end{split} \end{equation}$$
$d\hat{h}$は前節で求めていました。あとはdot(行列積)ノードの逆伝播を求めれば全結合層の重みと二つの入力値の一部分が求まります。
$$ \begin{equation} \begin{split}
dW_{hh} &= \frac{\partial L}{\partial t} \frac{\partial t}{\partial W_{hh}}\
&= (r \odot h_{prev})^T * dt
\end{split} \end{equation} $$
dotノードの逆伝播は掛け算の拡張と考えられますが、行列積ですので掛ける方向と形に注意が必要です。この場合は、ノードへの入力値が転置になります。また、このノードへの$h_{prev}$からの入力値を$h_r$とすると、
$$ \begin{equation} \begin{split}
dh_r &= \frac{\partial L}{\partial t} \frac{\partial t}{\partial h_r}\
&= dt * W_{hh}^T
\end{split} \end{equation} $$
となります。これを利用して$dh_{prev}$の一部分を求めます。
$$
dh_{prev2} = \frac{\partial L}{\partial h_r} \frac{\partial h_r}{\partial h_{prev}}
= dh_r \times r
$$
$x$側も同様に計算します。
$$ \begin{aligned}
dW_{xh} &= \frac{\partial L}{\partial t} \frac{\partial t}{\partial W_{xh}} = x^T * dt\\
dx_1 &= \frac{\partial L}{\partial t} \frac{\partial t}{\partial x} = dt * W_{xh}^T
\end{aligned} $$
シグモイドノード(z)の逆伝播
次はz周辺の逆伝播を計算します。$dz$は二つの勾配が合流してくるため、和を取ります。
$$ \begin{aligned}
dz &= \frac{\partial L}{\partial h_{next}} \frac{\partial h_{next}}{\partial z} \\
&= dh_{next} \times \hat{h} + (- dh_{next} \times h_{prev})
\end{aligned} $$
今節ではtはシグモイドノードへの入力となります。したがって$dt$は
$$ \begin{aligned}
dt &= \frac{\partial L}{\partial z} \frac{\partial z}{\partial t}\\
&= dz \times z \times (1 - z)
\end{aligned} $$
と計算できます(1-ノードでは符号が反対になります)。また、シグモイド関数$f(x) = \frac{1}{1 + exp(-x)}$の微分が以下のようになることを利用しました。
$$ f'(x) = (1 - f(x)) \cdot f(x) $$
後はdotノードのみの計算なので結果のみを書きます。
$$ \begin{aligned}
dW_{hz} &= h_{prev}^T * dt\\
dW_{xz} &= x^T * dt\\
dh_{prev3} &= dt * W_{hz}^T\\
dx_2 &= dt * W_{xz}^T
\end{aligned} $$
右肩の$T$は転置を意味します。
シグモイドノード(r)の逆伝播
最後にr周辺の逆伝播を計算します。こちらはtanhノードの計算で定義した$dh_r$を使います。それ以外はz周辺と同じなので結果のみ記します。
$$\begin{aligned}
dr &= dh_r \times h_prev\\
dt &= dr \times r \times (1 - r)\\
dW_{hr} &= h_{prev}^T * dt\\
dW_{xr} &= x^T * dt\\
dh_{prev4} &= dt * W_{hr}^T\\
dx_3 &= dt * W_{xr}^T\\
\end{aligned}$$
入力値の逆伝播
これまで求めてきた$dh_prev$と$dx$は各ノードへの入力値でした。計算グラフ上は別物に見えますが、実際には同一の入力ですので、これらを合計する必要があります。
$$\begin{aligned}
dh_{prev} &= dh_{prev1} + dh_{prev2} + dh_{prev3} + dh_{prev4}\\
dx &= dx_1 + dx_2 + dx_3
\end{aligned}$$
これでGRUユニットの逆伝播は終わりです。お疲れさまでした。