RNN, LSTM and GRU tutorial
を読んで理解したLSTM、GRUのアイデアについてまとめます。
画像は参考サイトからお借りします。
RNN
ポイント
過去の状態は
state=W_xx_t+W_hh_{t−1}+b
予測は
h_t = tanh(state)
で表され、次のシーケンスの学習へと受け渡されていく。
問題は、長期的なシーケンスデータが扱えないのと、勾配消失。
LSTM
RNNのデメリットを改善し
・長期のシーケンスデータを扱える。
・ゲートを使うことで勾配消失を解決した。
というモデル。
LSTMでは、RNNでの
・予測を出す。
・これまでに処理された隠れ層の過去の状態を表す。
という2つの目的を、
h_t, C
という二つの変数に分けて果たす。
キーポイントは
・3つのゲートを使って、セルの状態Cを更新する
ゲート
LSTMには、3つのゲートから入力を受け取り、以下のように表される。
gate_{forget} = σ(W_{fx}X_t + W_{fh}h_{t-1} + b_f) \\
gate_{input} = σ(W_{ix}X_t + W_{ih}h_{t-1} + b_i) \\
gate_{out} = σ(W_{ox}X_t + W_{oh}h_{t-1} + b_o) \\
各入力ゲートでは、活性化関数にシグモイドが使われている。
これは、0に近い=更新しない、1に近い=更新する。というように使えるから。(だと思う。)
(よく見かけるこの画像では、セルの状態を上側のパイプラインで表しており、予測の出力を右下の矢印で表している。)
ゲートのそれぞれの役割は
Forget Gate
・前のセルの状態から、どの部分を保持するか=どの部分を忘却するかを制御する。(状態の忘却)
Input Gate
・新しく計算した情報のどの部分をセルの状態に加えるのかを制御する。(状態の更新)
Out Gate
・セルの状態の、どの部分を予測として出力するかを制御する。
セルの更新
LSTMには、セルの更新と予測の更新があり、以下のように表される。
\tilde{C} = tanh(W_{cx}C_t + W_{ch}h_{t-1} + b_c)\\
C_t = gate_{forget} \cdot C_{t-1} + gate_{input} \cdot \tilde{C}\\
h_t = gate_{out} \cdot tanh(C_t)
忘却
上の式のCtの
gate_{forget} \cdot C_{t-1}
の部分で、前のセルの状態と忘却ゲートの入力をかけて、前のセルのいくつかの部分を忘却する。(=いくつかの部分を保持する。)
更新
\tilde{C} = tanh(W_{cx}C_t + W_{ch}h_{t-1} + b_c)\\
gate_{input} \cdot \tilde{C}\\
入力ゲートと更新元の情報(Cチルド)をかけて、セルの状態に新たな情報を加える。
セルの更新の全体
C_t = gate_{forget} \cdot C_{t-1} + gate_{input} \cdot \tilde{C}\\
前のセルの状態からいくつか忘却したものに新しい情報を加えて、このセルでの状態としている。
予測の出力
h_t = gate_{out} \cdot tanh(C_t)
セルの状態に出力ゲートの情報をかけて、予測を出力する。