背景
現在自然言語の生成AIを学習しており、 ゼロから作るDeep Learning ❷ でRNNについて勉強していました。
勾配消失の対策を施したLSTM等を学習するにあたり、RNNでの勾配爆発・勾配消失が起こる原因は重要な要素です。
本の中で勾配爆発・勾配消失が起こる原因は言及されていますが、自分の中でのRNNの知識整理を兼ねて備忘録として残します。
RNNとは
RNNは、直前の隠れ状態と現在の入力から新しい隠れ状態と出力を計算し、その隠れ状態を次の時刻へ受け渡す処理を繰り返すモデルです。
今回シンプルなRNNとして以下のような構造を扱います。
パラメータの説明
- $x_t$: 時刻tの入力
- $h_t$: 時刻tの隠れ状態
- $h_{t-1}$: 直前の隠れ状態
- $W_x$: 入力に対する重み行列
- $W_h$: 隠れ状態に対する重み行列
- $b$: バイアス項
- $\tanh$: 活性化関数(双曲線正接関数)
上記のRNNは隠れ状態 $h$ 用と入力 $x$ 用に2種類の重み $W_h$ 、 $W_x$ 、およびバイアス項 $b$ を持ちます。
順伝播のプログラム例
RNNの出力および $h_t$ の計算は以下のように実装できます。
def forward(x_t, h_prev, W_h, W_x, b):
"""
RNNの順伝播を計算する
Args:
x_t: 時刻tの入力
h_prev: 前の隠れ状態 h_{t-1}
W_h: 隠れ状態に対する重み行列
W_x: 入力に対する重み行列
b: バイアス項
Returns:
h_t: 現在の隠れ状態
"""
# 隠れ状態の計算: W_h * h_{t-1}
h_term = np.dot(W_h, h_prev)
# 入力の計算: W_x * x_t
x_term = np.dot(W_x, x_t)
# 線形結合: W_h * h_{t-1} + W_x * x_t + b
linear_combination = h_term + x_term + b
# 活性化関数を適用: tanh(線形結合)
h_t = np.tanh(linear_combination)
return h_t
RNNのパラメータ更新を除いた逆伝播のフロー
逆伝播時は最終出力から始まって、後ろから伝わってきた勾配を使って連鎖律により各パラメータに対する勾配を計算します。
前側に伝わる部分のみに絞ると以下のような経路で伝わっていきます。
足し算のステップでは基本的にそのまま値を流すだけなので、前側に伝わる部分のみに絞った計算のステップは以下の通りです。
- $\tanh$の逆伝播 $dh_t \times (1 - h_t^2)$ を計算する
-
- の結果と重み行列 $W_h$ の転置で行列積をとる
逆伝播のプログラム例
上記の計算フローに沿った逆伝播の実装例を以下に示します。
def backward(dh_t, h_t, W_h):
"""
RNNの1ステップの逆伝播を計算する
Args:
dh_t: 時刻tの隠れ状態に対する勾配
h_t: 時刻tの隠れ状態(順伝播で計算済み)
W_h: 隠れ状態に対する重み行列
Returns:
dh_prev: 前の隠れ状態 h_{t-1} に対する勾配
"""
# tanhの逆伝播を計算: dh_t * (1 - h_t^2)
d_linear = dh_t * (1 - h_t**2)
# 重み行列の転置との行列積で前の隠れ状態への勾配を計算
dh_prev = np.dot(W_h.T, d_linear)
return dh_prev
上記のフローとコード内の計算の要素を確認すると、勾配消失・勾配爆発が起こる原因は、主に$\tanh$の微分を使った計算と重み行列の計算にありそうです。
勾配消失・勾配爆発とは
勾配消失と勾配爆発は、ニューラルネットワークの学習において深刻な問題となる現象です。
勾配消失とは、逆伝播において勾配が層を遡るにつれて指数的に小さくなってしまう現象です。この結果、初期の層(RNNでは過去の時刻)でのパラメータ更新が極めて小さくなり、長期依存関係を学習できなくなります。
勾配爆発とは、逆伝播において勾配が層を遡るにつれて指数的に大きくなってしまう現象です。この結果、パラメータの更新が不安定になり、学習が収束しなくなります。
これらの問題により、RNNは系列の長期的な依存関係を効果的に学習することが困難になります。
勾配消失・勾配爆発が起こる理由
勾配消失
勾配消失は繰り返しの中で何度も小さい値をかけられることで起こります。
シンプルなRNNにおける勾配消失の原因は以下があげられます。
- $\tanh$の微分では0~1の値を取るため、掛け算した際に基本的に勾配が小さくなる
- 勾配に重みをかける部分では、重みが小さいほど勾配が小さくなる
勾配消失の実験コード
以下のコードで勾配消失の様子を確認できます。
import numpy as np
def backward(dh_t, h_t, W_h):
"""
RNNの1ステップの逆伝播を計算する
"""
d_linear = dh_t * (1 - h_t**2)
dh_prev = np.dot(W_h.T, d_linear)
return dh_prev
if __name__ == "__main__":
print("=== 勾配消失・爆発のシンプルなデモ ===")
# パラメータ設定
hidden_size = 3
time_steps = 10
# 重み行列(勾配消失を起こしやすい小さな値)
W_h = np.array([[0.3, 0.1, 0.05],
[0.1, 0.2, 0.1],
[0.05, 0.1, 0.3]])
# ダミーの隠れ状態(tanhの出力なので-1から1の範囲)
h_states = np.array([
[0.5, -0.3, 0.8],
[0.2, 0.7, -0.4],
[0.9, -0.1, 0.6],
[-0.3, 0.4, 0.2],
[0.6, -0.8, 0.1],
[0.1, 0.3, -0.7],
[-0.5, 0.8, 0.4],
[0.7, -0.2, 0.9],
[-0.1, 0.5, -0.3],
[0.4, -0.6, 0.7]
])
# 初期勾配
dh_current = np.array([1.0, 1.0, 1.0])
print("時刻 | 勾配の大きさ")
print("-" * 20)
# 逆伝播を実行
for t in reversed(range(time_steps)):
magnitude = np.linalg.norm(dh_current)
print(f"{t:4d} | {magnitude:.6f}")
# 次のステップの勾配を計算
dh_current = backward(dh_current, h_states[t], W_h)
print(f"最終 | {np.linalg.norm(dh_current):.6f}")
勾配消失の実験結果
=== 勾配消失・爆発のシンプルなデモ ===
時刻 | 勾配の大きさ
--------------------
9 | 1.732051
8 | 0.502864
7 | 0.196388
6 | 0.045217
5 | 0.012520
4 | 0.004473
3 | 0.001267
2 | 0.000502
1 | 0.000126
0 | 0.000041
最終 | 0.000011
勾配爆発
勾配爆発は繰り返しの中で何度も大きい値をかけられることで起こります。
シンプルなRNNにおける勾配爆発の原因は以下が挙げられます。
- $\tanh$の微分は最大1の値を取るため、それ自体は勾配を抑制する傾向があるが、入力が0付近の場合は1に近い値となる
- 特に重みが非常に大きい場合、$\tanh$の微分を適用した後の勾配に重み行列をかけることで、$\tanh$の微分による抑制効果を上回り、勾配が爆発的に増大する
勾配爆発の実験コード
勾配消失のコードに利用している重みを以下のように変更すると勾配爆発の様子が確認できます。
...
if __name__ == "__main__":
...
# 重み行列(勾配爆発を起こしやすい大きな値)
W_h = np.array([[2.0, 1.5, 1.0],
[1.5, 3.0, 2.0],
[1.0, 2.0, 2.5]])
...
勾配爆発の実験結果
=== 勾配消失・爆発のシンプルなデモ ===
時刻 | 勾配の大きさ
--------------------
9 | 1.732051
8 | 6.252082
7 | 30.060867
6 | 104.301982
5 | 342.062643
4 | 1523.238278
3 | 5296.997016
2 | 26764.259254
1 | 108428.716262
0 | 426360.436556
最終 | 1662759.072585
終わりに
今回はRNNの勾配消失・勾配爆発の原因について調査しました。
1や2程度の大きさの値が入った重みで勾配爆発が起こるというのは少し驚きましたが、実際にコードで試してみると様々な発見があるため、今後も気になったら実験するようにしたいと思います。
アテンションの仕組みも興味があるので、こちらも順を追って調査してみようと思います。

