「Deep Learning」本のLSTMに関する状態更新式
Ian Goodfellow, Yoshua Bengio, Aaron Courville著「Deep Learning」の10.10 The Long Short-Term Memory and Other Gated RNNsの中に、LSTMの状態更新式は以下です。
forget gate(f), state(s), external input gate(g), output(h), output gate(q)
f_i^{(t)} = \sigma \Big(
b_i^f + \sum_j U_{i,j}^f x_j^{(t)} +
\sum_j W_{i,j}^f h_j^{(t-1)}
\Big)
s_i^{(t)} =
f_i^{(t)} s_i^{(t-1)} + g_i^{(t)} \sigma \Big(
b_i + \sum_j U_{i,j} x_j^{(t)} +
\sum_j W_{i,j} h_j^{(t-1)}
\Big)
g_i^{(t)} = \sigma \Big(
b_i^g + \sum_j U_{i,j}^g x_j^{(t)} +
\sum_j W_{i,j}^g h_j^{(t-1)}
\Big)
h_i^{(t)}=tanh(s_i^{(t)}) q_i^{(t)}
q_i^{(t)} = \sigma \Big(
b_i^q + \sum_j U_{i,j}^q x_j^{(t)} +
\sum_j W_{i,j}^q h_j^{(t-1)}
\Big)
KerasのLSTM実現方法
Keras発明者Francois Chollet著、「Deep Learning with Python」の6.2.2 Understanding the LSTM and GRU layersにLSTMについて下記の状態更新式が書いてあります。
output_t = activation(dot(state_t, Uo) + dot(input_t, Wo) + dot(C_t, Vo) + bo)
i_t = activation(dot(state_t, Ui) + dot(input_t, Wi) + bi)
f_t = activation(dot(state_t, Uf) + dot(input_t, Wf) + bf)
k_t = activation(dot(state_t, Uk) + dot(input_t, Wk) + bk)
c_t+1 = i_t * k_t + c_t * f_t
これらの式とDeep Learning本の式が似てますが、少なくとも違いは2箇所があると思います。
まず、Francois氏はc_tとstate_tを分けています。
しかし、「Deep Learning」だと、c_tとstate_tは、s(state)だけで十分です。
c_tとstate_tが分けられているから、唯一違う形の式(activation(dot(state_t, Uo) + dot(input_t, Wo) + dot(C_t, Vo) + bo))が出ました。
また、output_tの計算方法が違います。
もしi_tがinput gate, f_tがforget gateだとしたら、k_tはoutput_gateになります。
つまりoutput_tを計算する時、k_tを使うべきです。
Deep Learningの方は、h(ここのoutput_t)を計算する時、output gate q(ここのk_t)を使います。
この理由で、KerasのLSTMの実装は正しいかどうかに疑問を持っています。
Francois氏の式とKerasソースコードが一致しているかを確認するために、Kerasのソースコードも見ましたが、まだ確認できていないです。
Kerasのrecurrent.pyの2015行目~2022行目に更新式がありそうです。後ほどまた確認します。
PyTorchのLSTM実装方法
念の為PyTorchのソースコードも見てみました。
PyTorchのLSTMクラスに綺麗に数学式コメントが書いてありますので、Deep Learning本の式と完全に一致していることを確認しました!
PyTorchのLSTMコメントは以下になります。
PyTorchのLSTMコメントと比べると、KerasのLSTMコメントはあんまりないです。
この理由で今後PyTorchを使おうという気持ちになっています。