Help us understand the problem. What is going on with this article?

pythonでRNN実装

More than 3 years have passed since last update.

はじめに

pythonでRNNを実装しました.
教科書として『深層学習』を使いました.

本記事の構成

  • はじめに
  • RNN
    • 順伝播計算
    • 逆伝播計算
    • 重みの更新
  • pythonでの実装
  • 結果
    • ロス
    • 系列データの予測
    • sin波の予測
  • おわりに

RNN

RNNとは,系列データを扱う再帰型のニューラルネットワークです.
系列データの例として,音声や言語,動画像などが挙げられます.
このような系列データは,サンプルごとに系列の長さが異なり,系列内の要素の順番に意味があることが特徴です.
RNNは,系列データの特徴をうまく扱うことを可能にします.

順伝播計算

まずは,文字の定義をします.
入力層,中間層,出力層の各ユニットのインデックスをそれぞれ $i, j, k$ で表します.
また,時刻 $t$ における入力,中間層の入出力,出力層の入出力,教師を以下のように表します.

  • 入力: $\boldsymbol x^t = (x_i^t)$
  • 中間層の入出力: $\boldsymbol u^t = (u_j^t)$,$\boldsymbol z^t = (z_j^t)$
  • 出力層の入出力: $\boldsymbol v^t = (v_k^t)$,$\boldsymbol y^t = (y_k^t)$
  • 教師: $\boldsymbol d^t = (d_k^t)$

さらに,入力 - 中間層間の重み,中間 - 中間層の帰還路の重み,中間 - 出力層間の重みを以下のように表します.

  • 入力 - 中間層間の重み: $\boldsymbol W^{(in)} = (w_{ji}^{(in)})$
  • 中間 - 中間層の帰還路の重み: $\boldsymbol W = (w_{jj'})$
  • 中間 - 出力層間の重み: $\boldsymbol W^{(out)} = (w_{kj}^{(out)})$

上記の定義を含めたネットワークの構造を下図に示します.

rnn.png

では,順伝播の計算について説明します.
中間層のユニット $j$ の入出力 $u_j^t$ および $z_j^t$ は,入力 $\boldsymbol x^t$ および一つ前の時刻での中間層の値 $\boldsymbol z^{t-1}$ を用いて以下のように表せます.
$f$ は,活性化関数です.

\begin{align}
& u_j^t = \sum_i w_{ji}^{(in)}x_i^t + \sum_{j'} w_{jj'}z_{j'}^{t-1} \tag{1} \\
& z_j^t = f(u_j^t) \tag{2}
\end{align}

ただし,$t = 1$ のとき,それまでの入力は存在しないため,$\boldsymbol z^0 = \boldsymbol 0$ とします.
式(1), (2)をまとめ,行列表現すると中間層の出力 $\boldsymbol z^t$ は以下のようになります.

\boldsymbol z^t = \boldsymbol f(\boldsymbol W^{(in)}\boldsymbol x^t + \boldsymbol W \boldsymbol z^{t-1}) \tag{3}

出力層のユニット $k$ の入出力 $v_k^t$ および $y_k^t$ は,中間層の出力 $\boldsymbol z^t$ を用いて以下のように表せます.

\begin{align}
& v_k^t = \sum_j w_{kj}^{(out)}z_j^t \tag{4} \\
& y_k^t = f^{(out)}(v_k^t) \tag{5}
\end{align}

式(4), (5)をまとめ,行列表現すると出力 $\boldsymbol y^t$ は以下のようになります.

\boldsymbol y^t = \boldsymbol f^{(out)}(\boldsymbol W^{(out)}\boldsymbol z^t) \tag{6}

逆伝播計算

誤差逆伝播法により各層の各ユニットにおける誤差から勾配を計算します.
今回 BPTT法(backpropagation through time) を説明します.
BPTT法は,RNNを下図のように時間方向に展開し,誤差逆伝播の計算を行います.

bptt.png

中間 - 出力層間における勾配

\begin{align}
\frac{\partial E}{\partial w_{kj}^{(out)}} &= \sum_{t=1}^T \frac{\partial E}{\partial v_k^t} \frac{\partial v_k^t}{\partial w_{kj}^{(out)}} \\
&= \sum_{t=1}^T \frac{\partial E}{\partial y_k^t} \frac{\partial y_k^t}{\partial v_k^t} \frac{\partial v_k^t}{\partial w_{kj}^{(out)}} \\
&= \sum_{t=1}^T \frac{\partial E}{\partial y_k^t} f^{(out)'}(v_k^t)z_j^t \\
&= \sum_{t=1}^T \delta_k^{(out), t} z_j^t \tag{7}
\end{align}

中間 - 中間層の帰還路における勾配

\begin{align}
\frac{\partial E}{\partial w_{jj'}} &= \sum_{t=1}^T \frac{\partial E}{\partial u_j^t} \frac{\partial u_j^t}{\partial w_{jj'}} \\
&= \sum_{t=1}^T \biggl(\sum_{k'} \frac{\partial E}{\partial v_{k'}^t} \frac{\partial v_{k'}^t}{\partial z_j^t} \frac{\partial z_j^t}{\partial u_j^t} + \sum_{j''} \frac{\partial E}{\partial u_{j''}^{t+1}} \frac{\partial u_{j''}^{t+1}}{\partial z_j^t} \frac{\partial z_j^t}{\partial u_j^t} \biggr) \frac{\partial u_j^t}{\partial w_{jj'}} \\
&= \sum_{t=1}^T \biggl(\sum_{k'} \frac{\partial E}{\partial v_{k'}^t} \frac{\partial v_{k'}^t}{\partial z_j^t} + \sum_{j''} \frac{\partial E}{\partial u_{j''}^{t+1}} \frac{\partial u_{j''}^{t+1}}{\partial z_j^t} \biggr) \frac{\partial z_j^t}{\partial u_j^t} \frac{\partial u_j^t}{\partial w_{jj'}} \\
&= \sum_{t=1}^T \biggl(\sum_{k'} \delta_{k'}^{(out), t} w_{k'j} + \sum_{j''} \delta_{j''}^{t+1} w_{j''j}  \biggr) f'(u_j^t) z_j^{t-1} \\
&= \sum_{t=1}^T \delta_j^t z_j^{t-1} \tag{8}
\end{align}

入力 - 中間層間における勾配

\begin{align}
\frac{\partial E}{\partial w_{ji}^{(in)}} &= \sum_{t=1}^T \frac{\partial E}{\partial u_j^t} \frac{\partial u_j^t}{\partial w_{ji}^{(in)}} \\
&= \sum_{t=1}^T \delta_j^t x_i^t \tag{9}
\end{align}

中間 - 中間層の帰還路における勾配では,一つ先の時刻における中間層の誤差 $\boldsymbol \delta^{t+1}$ の項が存在します.
$t = T$ の場合,$\boldsymbol \delta^{T+1} = \boldsymbol 0$ とします.
$t = T, T - 1, \cdots, 2, 1$ の順に誤差 $\delta$ を伝播させ,各時刻における勾配の総和をとります.
この勾配の総和を使って重みを更新します.

重みの更新

式(7), (8), (9)より,各層間における勾配が求められます.
この勾配を使って,以下の式より重みを更新します.

w_{new} = w_{old} - \varepsilon \frac{\partial E}{\partial w_{old}} \tag{10}

pythonでの実装

サイン波の予測を実装しました.
コードは ここ にあげてあります.
$x^t = \sin \bigl(\theta + (t - 1)\cdot\Delta \theta\bigr)$ なる $x^t$ を入力し,一つ先の時刻の $y^t = \sin \bigl(\theta + t\cdot\Delta \theta\bigr)$ なる $y^t$ を予測します.
ある入力 $x^t$ に対して $y^t$ を出力するとき,入力 $x^t$ だけでは予測方向を決定することができません.
しかし,$x^t$ および $z^{t-1}$ を入力とした場合,一つ前の時刻の情報を利用できるため,予測方向を決定できます.

結果

学習データは,$\boldsymbol x = (x^1, x^2, \cdots, x^T)$ の系列データを $1$ サンプルとしています.
$\theta$ は,$[0, \pi]$ の範囲から乱数で決定します.

\begin{align}
& \{\theta \,|\, 0 \leq \theta \leq \pi\}, \Delta \theta = \frac{\pi}{6} \\
& x^1 = \sin \theta \\
& x^2 = \sin \bigl(\theta + \Delta \theta\bigr) \\
& \cdots \\
& x^t = \sin \bigl(\theta + (t - 1)\cdot \Delta \theta\bigr) \tag{11} \\
& \cdots \\
& x^T = \sin \bigl(\theta + (T - 1)\cdot \Delta \theta\bigr) \\
\end{align}

学習のパラメータは以下になります.
学習データ:$7000$ サンプル
学習率:$\varepsilon = 0.05$
正則化係数:$\lambda = 0.001$
エポック:$30$

ロス

学習時のロスは以下のようになりました.

loss.png

系列データの予測

次に,時系列データを入力したとき正しく予測できたかを確かめます.
入力を $x$,出力を $y$,理想出力を $d$ で表します.

\begin{align}
& x^1 = \sin\Bigl(\frac{1}{6}\pi\Bigr) \to y^1 = 0.43831960 \ \ (d^1 = 0.86602540) \tag{12} \\
& x^1 = \sin\Bigl(\frac{5}{6}\pi\Bigr) \to y^1 = 0.43831960 \ \ (d^1 = 0.0) \tag{13}
\end{align}

$x$ の値が等しいため,$y$ の値は一致しています.
また,系列の情報がないため,予測 $y$ は理想出力 $d$ から離れています.
そこで,一つ前の時刻のデータも入力してみます.

\begin{align}
& \boldsymbol x = (x^1, x^2) = \biggl(\sin\Bigl(\frac{0}{6}\pi\Bigr), \sin\Bigl(\frac{1}{6}\pi\Bigr)\biggr) \to y^2 = 0.84290507 \ \ (d^2 = 0.86602540) \tag{14} \\
& \boldsymbol x = (x^1, x^2) = \biggl(\sin\Bigl(\frac{4}{6}\pi\Bigr), \sin\Bigl(\frac{5}{6}\pi\Bigr)\biggr) \to y^2 = -0.02663726 \ \ (d^2 = 0.0) \tag{15}
\end{align}

時系列データとして入力すると,正しい予測結果が得られています.
これは,時系列の情報を保持する再帰的な構造を持つRNNだからできることです.

sin波の予測

最後に,sin波の予測を行います.
以下に簡単なアルゴリズムを示します.

  1. 短い系列データ $\boldsymbol x = (x^1, x^2, \cdots, x^{T'})$ を入力し,出力 $y^{T'}$ を得ます.
  2. $y^{T'}$ を $x^{T'+1}$ として入力し,出力 $y^{T'+1}$ を得ます.
  3. 2.を繰り返し,順次結果を得ます.

以下の図が結果になります.
青点が短い系列データ,緑点が予測により得られた結果になります.
グラフを見るとsin波を正しく予測できていることが分かります.

sin.png

おわりに

RNNを実装できました.
次は,LSTMを勉強したいと思います.

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away