LoginSignup
26
12

More than 3 years have passed since last update.

RNNの隠れ層の次元とか入力データの構造とかよくわからない人へ

Posted at

はじめに

自分の場合、機会がなくてRNN(リカーレントニューラルネットワーク)ってあまり使用しないんですよね。だから、一回勉強してもパラメータとか忘れちゃうんですよ。ところが、最近時系列の使用場面がありまして。

RNNは初めての人にとっては、データの準備やネットワークの構造が複雑でわかりにくいです。自分も時々混乱するので、基本事項ではありますが復習がてらメモを残しておきます。

ざっくりRNN

RNNは特に時系列データに対して特化したモデルで、入力データを再帰的にネットワークに入力します。

今、$n_i$次元の特徴ベクトルが時間に依存して変化しているとします。これを$\boldsymbol{x}_t, t=0,1,...,N$と書くことにします。原理的には無限の時間のデータ系列を考えられますが、実用的にはある時間範囲のデータになると思います。

さて、時系列の特徴をネットワークに反映させるためには、過去の情報をネットワーク内に保持しておく必要があります。そこで、過去にネットワークに入力・伝搬したデータを再度ネットワークに入力することで、過去の情報を踏まえてネットワークを更新することができます。

image.png

上図を展開して描くと次の図のようになります。

image.png

$U, V, W$は重み行列です。隠れ層$\boldsymbol{h}$の次元を$n_h$とすると行列$U$の形状は$(n_h\times n_i)$、行列$W$は正方行列で形状は$(n_h\times n_h)$、出力の次元を$n_o$とすると、行列$V$の形状は$(n_o\times n_h)$となります。

隠れ層は次のように書けます。

\boldsymbol{h}_t = f(U\boldsymbol{x}_t + W\boldsymbol{h}_{t-1} + \boldsymbol{b})

$f$は活性化関数、$\boldsymbol{b}$はバイアスです。この漸化式を展開していくことで無限の過去に遡っていけます。なお、重み$U, W$は共有されます。

image.png

ただ、実用上は無限の大きさのネットワークは構築できないので、ある程度の時間範囲で切ります。例えば上の図では、2期前の隠れ層まで入力し、3期前の$\boldsymbol{h}_{t-3}$はゼロにしておけば良いです。

上の図の場合は、ネットワークの形からわかるように、[$\boldsymbol{x}_
{t-2}$,$\boldsymbol{x}_
{t-1}$,$\boldsymbol{x}_
{t}$]を入力として$\boldsymbol{y}_
{t}$を出力します。したがって、元の時系列データ[$..., \boldsymbol{x}_
{t-3},\boldsymbol{x}_
{t-2},\boldsymbol{x}_
{t-1},\boldsymbol{x}_
{t}, ...$]を1期ずつずらして3つずつのブロックに加工します。

[[\boldsymbol{x}_
{t-3},\boldsymbol{x}_
{t-2},\boldsymbol{x}_
{t-1}] \rightarrow \boldsymbol{y}_
{t-1} \\
[\boldsymbol{x}_
{t-2},\boldsymbol{x}_
{t-1},\boldsymbol{x}_
{t}] \rightarrow \boldsymbol{y}_
{t} \\
[\boldsymbol{x}_
{t-1},\boldsymbol{x}_
{t},\boldsymbol{x}_
{t+1}] \rightarrow \boldsymbol{y}_
{t+1} \\
...
]

このブロックの大きさ(今の場合3)を$\tau$として、ブロックの数を$M$とすれば、Kerasの場合にはデータのshapeは$(M, \tau, n_i)$となります。

TensorFlowで再帰的な隠れ層を実現するにはSimpleRNNCellを使用します。Kerasはデータを整形しておけば再帰構造を気にせず使うことができます。

from keras.layers import SimpleRNN

model = Sequential()
model.add(SimpleRNN(units=n_hidden,  # 隠れ層の次元
                    input_shape=(tau, n_in), # ブロックの大きさ,入力の次元
                    ))
model.add(Dense(n_out))
model.add(Activation('linear'))

おわりに

入力データの次元、各行列の形状など、それぞれのパラメータ間の関係を確認しました。RNNの発展形は奥が深いのでまたの機会に応用も含めて書ければと思います。

26
12
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
26
12