78
61

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

[TensorFlow/Keras] 好きな構造のRNNを組み立てるまでの道のり

Last updated at Posted at 2020-04-11

はじめに

時系列データを入力にとり、今の時刻の入力に加えて前の時刻の「状態」も使って出力を決めるニューラルネットワークの形態に RNN (Recurrent Neural Network) があります。LSTM(Long Short-Term Memory, 長・短期記憶)が有名でしょうか。
時系列データとは、動画やテキストといった、列全体として意味を持つようなデータです。普通のニューラルネットワークは画像や文字といった、形式の決まったある1つのデータを入力に取るわけですが、それらが並んだ動画やテキストを扱うときには、個々の画像(フレーム)や文字はもちろん、その並びにも大きな意味がありますね。このようなデータをうまく扱う構造がRNNというわけです。

ただ普通の全結合層などと違って正直とっつきにくいと思います。私もそうです。

というわけで、まずはRNNが何をするものかを理解して、次に前の時刻の「状態」を使うようなRNNを自分で組み立てられるようになりたいと思います。

検証環境

  • Ubuntu 18.04
  • Python 3.6.9
  • TensorFlow 2.1.0 (CPU)

目標

以下のような悩みを解決できればと思います(私自身が現に悩んでいたので…)。

  • RNNSimpleRNNSimpleRNNCell の違いがよく分からない
  • LSTM とかのレイヤーの中身をちゃんと理解したい
  • 論文の実験を追試するために自分で標準以外のRNNレイヤーを書きたい

参考ページ1に書かれているようなネットワーク構造(グラフ)や数式を見て、Kerasを使って自分でRNNを組み立てられるようになるとよいですね。

逆に、以下のようなことはこのページでは扱いません。もしかすると別の機会に記事にするかもしれませんが。

  • 自分の解きたい問題に合わせてネットワーク構造を考える方法
  • 可変長(サンプルごとに長さがバラバラ)の入力を扱う方法
  • return_state とか stateful とかの使い方
  • Bidirectional RNNの使い方

RNNの基本

RNNの基本は**「出力が、今の時刻の入力と、前の時刻の「状態」に依存する」**ということです。普通の全結合層や畳み込み層では、出力は入力のみに依存して決まりますが、RNNではそれ以前の入力の情報も使えるという違いがあります。次の時刻に持っていきたい「状態」を何にするかは、自分で決めることができます。

以下の左側の図のように、RNNは再帰的な構造を持ったネットワーク(セル)として表現することができるわけですが、その具体的な操作は右側の図のようにループを展開した形で理解することができます(引用元2)。

image.png

  • 入力: $x_1, x_2, ..., x_t, ...$

が与えられたとき、

  • 出力: $o_1, o_2, ..., o_t, ...$
  • 状態: $s_1, s_2, ..., s_t, ...$

\begin{align}
s_t &= f(Ux_t + Ws_{t-1} + b) \\
o_t &= h(Vs_t)
\end{align}

のように定めます。
ここで $U, V, W$ は行列、$b$ は列ベクトルで、レイヤーの重み(学習させるパラメータ)となります。$f, h$ は活性化関数です。入出力および状態 $x_t, o_t, s_t$ も列ベクトルです。

一番簡単なRNN

まずはRNNを触ってみましょう。
簡単なRNNとして、入力された数列の部分和(先頭からその時点までの値をすべて足したもの)を逐次出力していくようなネットワークを考えてみます。このとき、部分和を「状態」と定義し、状態をそのまま出力することにします。例えば、入力に対する出力や状態は以下の表のように推移していきます。

$t$ 1 2 3 4 5 6 7 ...
$x_t$ 1 3 2 4 1 0 1
$s_t$ 1 4 6 10 11 11 12
$o_t$ 1 4 6 10 11 11 12

TensorFlow + Kerasでは、tf.keras.layers.SimpleRNN というレイヤーを使うことで

\begin{align}
o_t = s_t = f(Ux_t + Ws_{t-1} + b) \tag{1}
\end{align}

という形のネットワークを定義することができます。$f(x) = x$ と定めて、数列とその部分和の列を学習させると、

\begin{align}
o_t = s_t = Ux_t + Ws_{t-1} + b
\end{align}

の重みが $U=W=1, ; b=0$ に近づいていくことが期待されます(今回は1次元の値を扱うので、$U, W, b$ はスカラーと考えて構いません)。

以下のコードで実際に学習を試します。長さ30の乱数列と、そこから計算した部分和の列を与えて学習させています。

first.py
import tensorflow as tf 
import numpy as np 
from tensorflow.keras import Sequential 
from tensorflow.keras.layers import SimpleRNN
from tensorflow.keras.optimizers import SGD

tf.random.set_seed(111)
np.random.seed(111)

model = Sequential([
    SimpleRNN(1, activation=None, input_shape=(None, 1), return_sequences=True)
])
model.compile(optimizer=SGD(lr=0.0001), loss="mean_squared_error")

n = 51200
x = np.random.random((n, 30, 1))
y = x.cumsum(axis=1)

model.fit(x, y, batch_size=512, epochs=100)

model.layers[0].weights
# [<tf.Variable 'simple_rnn/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[0.6021545]], dtype=float32)>,
#  <tf.Variable 'simple_rnn/recurrent_kernel:0' shape=(1, 1) dtype=float32, numpy=array([[1.0050855]], dtype=float32)>,
#  <tf.Variable 'simple_rnn/bias:0' shape=(1,) dtype=float32, numpy=array([0.20719269], dtype=float32)>]

model.predict(np.ones((1, 30, 1)) * 0.5).flatten()
# array([ 0.5082699,  1.0191246,  1.5325773,  2.0486412,  2.5673294,
#         3.0886555,  3.6126328,  4.1392746,  4.6685944,  5.2006063,
#         5.7353234,  6.27276  ,  6.8129296,  7.3558464,  7.901524 ,
#         8.449977 ,  9.00122  ,  9.555265 , 10.112128 , 10.6718235,
#        11.2343645, 11.799767 , 12.368044 , 12.939212 , 13.513284 ,
#        14.090276 , 14.670201 , 15.253077 , 15.838916 , 16.427734 ],
#       dtype=float32)

誤差が大きく見えますが、サンプルなので雰囲気がつかめればよいでしょう(出力の精度はここでは追求しません)。ここでは、学習の結果として

\begin{align}
o_t = s_t = 0.6022x_t + 1.0051s_{t-1} + 0.2072
\end{align}

が得られたことになります。

SimpleRNNの解説

SimpleRNN(1, activation=None, input_shape=(None, 1), return_sequences=True)
  • 最初の 1 は、$o_t, s_t$ の次元数です。今回はスカラーなので1を指定しました。
  • activation は、式(1)の $f$ に相当するものです。今回は恒等関数なので None としました。Dense などの場合はデフォルトが恒等関数なのですが、RNN系ではデフォルトが tanh になっていますのでご注意ください。
  • input_shape は、(None, dimension) の形をとります。最初の None は各入力列の長さに対応します(可変長の入力を受け付けるため None になっています)。2つ目は $x_t$ の次元数(今回は1)です。
  • return_sequences=True は、レイヤーの出力として各時刻の出力の列を返すことを指示します。この指定により、例えば長さ30の列に対する出力の形状は (batch_size, 30, 1) となります。これが False の場合、レイヤーの出力は最終時刻の出力(今回なら $o_{30}$)だけとなり、出力の形状が (batch_size, 1) となります。問題設定と学習データのフォーマットによって使い分けてください。

詳細は公式ドキュメントをご覧ください。
tf.keras.layers.SimpleRNN | TensorFlow Core v2.1.0

RNNを用いた書き換え

先ほどのモデルを作るコードは、以下と等価になります。

from tensorflow.keras.layers import RNN, SimpleRNN, SimpleRNNCell

model = Sequential([
    #SimpleRNN(1, activation=None, input_shape=(None, 1), return_sequences=True) 
    RNN(SimpleRNNCell(1, activation=None), input_shape=(None, 1), return_sequences=True)
])

The cell is the inside of the for loop of a RNN layer. Wrapping a cell inside a tf.keras.layers.RNN layer gives you a layer capable of processing batches of sequences, e.g. RNN(LSTMCell(10)).

Recurrent Neural Networks (RNN) with Keras | TensorFlow Core

SimpleRNNCell で単一のサンプルに対する操作(セル)を定義し、それを RNN() で囲むことによってバッチを処理するレイヤーを定義しています。

言い換えれば、SimpleRNNCell に相当するサンプル単位の処理を自分で定義して RNN() で囲むことにより、好きな構造のRNNを定義できるはずです。

ここまでの内容で、最初に述べた「RNNSimpleRNNSimpleRNNCell の違い」が分かってきました。

SimpleRNNCell の中身を見てみよう

自分でイメージしたRNNを作るための準備として、まずは既存の SimpleRNNCell が何をしているのか見てみましょう。自分で書くときには、まずは既存の処理を真似して作るのが近道のはずだからです。

(また、tf.keras.layers.RNN | TensorFlow Core v2.1.0 の Example も参考になると思います)

SimpleRNNCell のソースコードは以下にあります。
tensorflow/recurrent.py at v2.1.0 · tensorflow/tensorflow · GitHub

この中から抜粋して中身を見ていきましょう。

まずはクラス定義で Layer を継承します。DropoutRNNCellMixin は、Dropoutをサポートするために継承しているようですが、本筋から外れますのでここでは触れません。

recurrent.py
class SimpleRNNCell(DropoutRNNCellMixin, Layer):

build() では、レイヤーに必要な重みを add_weight で定義しています。RNNに限らず Dense() などでも同じことをしています。
式(1) との対応としては、kernel が $U$、recurrent_kernel が $W$, bias が $b$ に相当します。

そして call() で実際の処理を定義します。ここが最も重要です。

recurrent.py
  def call(self, inputs, states, training=None):
    prev_output = states[0]
    dp_mask = self.get_dropout_mask_for_cell(inputs, training)
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
        prev_output, training)

    if dp_mask is not None:
      h = K.dot(inputs * dp_mask, self.kernel)
    else:
      h = K.dot(inputs, self.kernel)
    if self.bias is not None:
      h = K.bias_add(h, self.bias)

    if rec_dp_mask is not None:
      prev_output = prev_output * rec_dp_mask
    output = h + K.dot(prev_output, self.recurrent_kernel)
    if self.activation is not None:
      output = self.activation(output)

    return output, [output]

inputs に入力 $x_t$ が、states に(前の時刻で生成された)状態 $s_{t-1}$ が入ってきます。
状態を複数持たせられるようにするため、states は各変数のリストとして渡ってきます。そこで、最初に states[0] だけを取り出しています。
Dropout関係の処理は置いておくとして、本筋の部分を抜き出してみると

h = K.dot(inputs, self.kernel)
if self.bias is not None:
  h = K.bias_add(h, self.bias)
output = h + K.dot(prev_output, self.recurrent_kernel)
if self.activation is not None:
  output = self.activation(output)
return output, [output]

これだけですね。
TensorFlowでは入出力の各サンプルは行ベクトルで表されるので、行列積の順序が逆になっていますが、式(1)と対応が取れることが分かります。
最後の return で、レイヤーの出力 $o_t$ と、次の時刻に持っていきたい状態 $s_t$ を返しています。ここで渡した状態が、次の時刻の call() で受け取れるというわけです。状態は引数と同様にリストで返します。ここで複数の状態を返すと、次の時刻に複数の状態が受け取れます。

LSTMを見てみよう

次はもう少し複雑な例として、LSTMのレイヤーを見ていきます。
まずは先ほどと同じ問題設定で、SimpleRNN の代わりに LSTM を使ってみます。

lstm.py
import tensorflow as tf 
import numpy as np 
from tensorflow.keras import Sequential 
from tensorflow.keras.layers import LSTM
from tensorflow.keras.optimizers import SGD

tf.random.set_seed(111)
np.random.seed(111)

model = Sequential([
    LSTM(1, activation=None, input_shape=(None, 1), return_sequences=True)
])
model.compile(optimizer=SGD(lr=0.0001), loss="mean_squared_error")

n = 51200
x = np.random.random((n, 30, 1))
y = x.cumsum(axis=1)

model.fit(x, y, batch_size=512, epochs=100)

model.layers[0].weights                                                   
# [<tf.Variable 'lstm/kernel:0' shape=(1, 4) dtype=float32, numpy=
#  array([[ 0.11471224, -0.15296884,  0.82662594, -0.14256166]],
#        dtype=float32)>,
#  <tf.Variable 'lstm/recurrent_kernel:0' shape=(1, 4) dtype=float32, numpy=
#  array([[ 0.10575113,  0.16468772, -0.05777477,  0.20210776]],
#        dtype=float32)>,
#  <tf.Variable 'lstm/bias:0' shape=(4,) dtype=float32, numpy=array([0.4812489, 1.6566612, 1.1815464, 0.4349145], dtype=float32)>]

model.predict(np.ones((1, 30, 1)) * 0.5).flatten()
# array([ 0.59412843,  1.1486205 ,  1.6723596 ,  2.1724625 ,  2.6546886 ,
#         3.1237347 ,  3.5834525 ,  4.0370073 ,  4.486994  ,  4.93552   ,
#         5.38427   ,  5.8345466 ,  6.2873073 ,  6.7431927 ,  7.20255   ,
#         7.6654577 ,  8.131752  ,  8.601054  ,  9.072805  ,  9.546291  ,
#        10.0206785 , 10.495057  , 10.968457  , 11.439891  , 11.908364  ,
#        12.372919  , 12.832628  , 13.286626  , 13.734106  , 14.174344  ],
#       dtype=float32)

実は、使うだけなら以前の記事で使ってみたことはあるのです(問題設定まで同じ…)。
KerasのRNN (LSTM) で return_sequences=True を試してみる - Qiita
今回は、もう少し実装部分を掘り下げて見ていきましょう。

こちらも、SimpleRNN と同様に、RNN とセルに分離して等価な処理を実現できます3

from tensorflow.keras.layers import LSTM, RNN, LSTMCell

model = Sequential([
    # LSTM(1, activation=None, input_shape=(None, 1), return_sequences=True)
    RNN(LSTMCell(1, activation=None), input_shape=(None, 1), return_sequences=True)
])

この後、セル部分 LSTMCell の処理に注目していきます。

LSTMの式

実装を見る前に、まずはLSTMの処理を確認しておきましょう。各ゲートの理論的な意味などにはここでは触れません。
(数式と図は引用元1

20170506172239.png

\begin{align}
o_t &= σ \left( W_ox_t + R_oh_{t-1} + b_o \right) \tag{2.1}\\
f_t &= σ \left( W_fx_t + R_fh_{t-1} + b_f \right) \tag{2.2}\\
i_t &= σ \left( W_ix_t + R_ih_{t-1} + b_i \right) \tag{2.3}\\
z_t &= \tanh \left( W_zx_t + R_zh_{t-1} + b_z \right) \tag{2.4}\\
c_t &= i_t \otimes z_t+c_{t-1} \otimes f_t  \tag{2.5}\\
h_t &= o_t \otimes \tanh(c_t) \tag{2.6}
\end{align}

ただし $\otimes$ は要素ごとの積、$\sigma$ はシグモイド関数、$\tanh$ は双曲線正接(ハイパボリックタンジェント)です。

LSTMCellの解読

式(2.1)から(2.6)を踏まえて、LSTMCell の実装を見ていきましょう。
tensorflow/recurrent.py at v2.1.0 · tensorflow/tensorflow · GitHub

クラスの書き出しは SimpleRNN と同じです。

recurrent.py
class LSTMCell(DropoutRNNCellMixin, Layer):

build() で重みを定義しています。

recurrent.py
  def build(self, input_shape):
    default_caching_device = _caching_device(self)
    input_dim = input_shape[-1]
    self.kernel = self.add_weight(
        shape=(input_dim, self.units * 4),
        name='kernel',
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        caching_device=default_caching_device)
    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units * 4),
        name='recurrent_kernel',
        initializer=self.recurrent_initializer,
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint,
        caching_device=default_caching_device)

    if self.use_bias:
      if self.unit_forget_bias:

        def bias_initializer(_, *args, **kwargs):
          return K.concatenate([
              self.bias_initializer((self.units,), *args, **kwargs),
              initializers.Ones()((self.units,), *args, **kwargs),
              self.bias_initializer((self.units * 2,), *args, **kwargs),
          ])
      else:
        bias_initializer = self.bias_initializer
      self.bias = self.add_weight(
          shape=(self.units * 4,),
          name='bias',
          initializer=bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          caching_device=default_caching_device)
    else:
      self.bias = None
    self.built = True

ここで細かいところは置いておいて、self.units * 4 という記述が数箇所にある点にご注意ください。
実は kernel には $W_o, W_f, W_i, W_z$ の4つの行列を連結したものがまとめて入っています4。同様に、recurrent_kernel は $R_o, R_f, R_i, R_z$ の4つをまとめて持っていますし、biasb_o, b_f, b_i, b_z の4つをまとめて持っています。もちろん、それぞれ4つの変数に分けて(全部で12個)持っておいても別に間違いではありません。
例によって数式で書くときとは行・列が逆になっているので、数式を見ると行の数が4倍になりそうですが、コード上では列の数を4倍にしています。

call() が本体部分です。説明を簡単にするため、implementation=1 の方の処理だけ示します。

recurrent.py
  def call(self, inputs, states, training=None):
    h_tm1 = states[0]  # previous memory state
    c_tm1 = states[1]  # previous carry state
    ()
      if 0 < self.dropout < 1.:
        ()
      else:
        inputs_i = inputs
        inputs_f = inputs
        inputs_c = inputs
        inputs_o = inputs
      k_i, k_f, k_c, k_o = array_ops.split(
          self.kernel, num_or_size_splits=4, axis=1)
      x_i = K.dot(inputs_i, k_i)
      x_f = K.dot(inputs_f, k_f)
      x_c = K.dot(inputs_c, k_c)
      x_o = K.dot(inputs_o, k_o)
      if self.use_bias:
        b_i, b_f, b_c, b_o = array_ops.split(
            self.bias, num_or_size_splits=4, axis=0)
        x_i = K.bias_add(x_i, b_i)
        x_f = K.bias_add(x_f, b_f)
        x_c = K.bias_add(x_c, b_c)
        x_o = K.bias_add(x_o, b_o)

      if 0 < self.recurrent_dropout < 1.:
        ()
      else:
        h_tm1_i = h_tm1
        h_tm1_f = h_tm1
        h_tm1_c = h_tm1
        h_tm1_o = h_tm1
      x = (x_i, x_f, x_c, x_o)
      h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)
      c, o = self._compute_carry_and_output(x, h_tm1, c_tm1)
    ()
    h = o * self.activation(c)
    return h, [h, c]

式(2.1)から(2.6)では、**前の時刻の情報として $h_{t-1}, c_{t-1}$ が使われています。よって、この2つは両方とも状態として持たなければなりません。**前述のように、状態をリストで受け渡しすることにより、2つ以上の状態を扱うことができます。

前半部分では、$W_ox_t + b_o, W_fx_t + b_f, W_ix_t + b_i, W_zx_t + b_z$ の4つの値を計算しています。(式(2.4)の $W_z, b_z$ は、コードでは k_c, b_c と添字が違う点にご注意ください)
後半部分では、_compute_carry_and_output() を使って $c_t, o_t$ の値を計算しています。
最後に $h_t$ を計算します。ここで $h_t$ をそのまま出力するとともに、$h_t, c_t$ は次の時刻での計算に使用するため状態として返しておきます。activation のデフォルト値は式(2.6)のように tanh です。

_compute_carry_and_output() は以下のように定義されています。

recurrent.py
  def _compute_carry_and_output(self, x, h_tm1, c_tm1):
    """Computes carry and output using split kernels."""
    x_i, x_f, x_c, x_o = x
    h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
    i = self.recurrent_activation(
        x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]))
    f = self.recurrent_activation(x_f + K.dot(
        h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]))
    c = f * c_tm1 + i * self.activation(x_c + K.dot(
        h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
    o = self.recurrent_activation(
        x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]))
    return c, o

それぞれ、recurrent_kernel の一部だけを取り出したものを使って行列積を計算しています。行列積の部分は要するに $R_ih_{t-1}, R_fh_{t-1}, R_zh_{t-1}, R_oh_{t-1}$ です。
x_i, x_f, x_c, x_o には、計算済みの $W_ix_t + b_i, W_fx_t + b_f, W_zx_t + b_z, W_ox_t + b_o$ の値が入っているので、これで活性化関数の中身が計算できます。
活性化関数 recurrent_activation は式(2.1)から(2.3)のシグモイド関数に相当しますが、デフォルト値としては hard_sigmoid5 となっているようです。あとは式の定義通りですかね。

自分でRNNを組み立ててみる

ここまでの内容を踏まえて、論文などで提案されているLSTMの派生形などを自分で実装して試してみたい!と思ったときの流れを一例紹介します。

簡単な例がよいと思うので、Wu (2016)6 で提案されている Simplified LSTM (S-LSTM) を試してみます。

まず元論文の数式を引用します。ただし、このページの他の式と記法を合わせるため、添字の位置を変更したり、$\delta, g$ と一般化して書かれているものを $\sigma, \tanh$ と書くなどの変更を行っています。

\begin{align}
f_t &=\sigma(W_fx_t+R_fh_{t−1}+b_f) \\
c_t &=f_t \otimes c_{t−1}+ (1−f_t) \otimes \tanh (W_c x_t+R_ch_{t−1}+b_c) \\
h_t &=\tanh (c_t)
\end{align}

状態はどれ?

まずは数式から状態として持つべきものを探しましょう。
前の時刻の情報を使っている、すなわち添字を $t-1$ で参照している変数を状態として持っておかなければいけません。よって、今回は $h_t, c_t$ を状態として持ちます。

重みはどれ?

重み(学習させたいパラメータ)として定義するものは $W_f, R_f, b_f, W_c, R_c, b_c$ ですね。
普通のLSTMと比べて、重みの数が半分になっています。

実装

LSTMCellLayer を継承していますが、自分で作るときには tf.keras.layers.AbstractRNNCell を継承するのがよいようです。
tf.keras.layers.AbstractRNNCell | TensorFlow Core v2.1.0

This is the base class for implementing RNN cells with custom behavior.

build() については、LSTMの実装をベースに改変するとこんな感じでしょうか。* 4* 2 に変えて、Dropoutなど直接関係しない部分は除いています。

  def build(self, input_shape):
    input_dim = input_shape[-1]
    self.kernel = self.add_weight(
        shape=(input_dim, self.units * 2),
        name='kernel',
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint)
    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units * 2),
        name='recurrent_kernel',
        initializer=self.recurrent_initializer,
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)

    if self.use_bias:
      self.bias = self.add_weight(
          shape=(self.units * 2,),
          name='bias',
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None
    self.built = True

call() については implementation=1 相当の処理のみ実装します。たぶんこんな感じ。
なお、inputs とか states などは tf.Tensor になっているので、np.dot のような ndarray 向けの処理は使わずに、tf.math, tf.linalg, tf.keras.backend などに含まれている Tensor を扱う関数を使って処理してください。
TensorFlowのTensorオブジェクトに慣れたい - Qiita

  def call(self, inputs, states, training=None):
    h_tm1 = states[0]  # previous memory state
    c_tm1 = states[1]  # previous carry state

    k_f, k_c = array_ops.split(
          self.kernel, num_or_size_splits=2, axis=1)
    x_f = K.dot(inputs, k_f)
    x_c = K.dot(inputs, k_c)
    if self.use_bias:
      b_f, b_c = array_ops.split(
          self.bias, num_or_size_splits=2, axis=0)
      x_f = K.bias_add(x_f, b_f)
      x_c = K.bias_add(x_c, b_c)

    f = self.recurrent_activation(x_f + K.dot(
        h_tm1, self.recurrent_kernel[:, :self.units]))
    c = f * c_tm1 + (1 - f) * self.activation(x_c + K.dot(
        h_tm1, self.recurrent_kernel[:, self.units:]))

    h = self.activation(c)
    return h, [h, c]

全体のコード

今回の部分和学習タスクでは $h_t =\tanh (c_t)$ ではなく $h_t = c_t$ として使っています。

slstm.py
import tensorflow as tf 
import numpy as np 
from tensorflow.keras import Sequential 
from tensorflow.keras.layers import RNN, AbstractRNNCell
from tensorflow.keras.optimizers import SGD
from tensorflow.python.keras import activations, constraints, initializers, regularizers
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops

class SLSTMCell(AbstractRNNCell):
  def __init__(self,
               units,
               activation='tanh',
               recurrent_activation='hard_sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               **kwargs):

    super(SLSTMCell, self).__init__(**kwargs)
    self.units = units
    self.activation = activations.get(activation)
    self.recurrent_activation = activations.get(recurrent_activation)
    self.use_bias = use_bias

    self.kernel_initializer = initializers.get(kernel_initializer)
    self.recurrent_initializer = initializers.get(recurrent_initializer)
    self.bias_initializer = initializers.get(bias_initializer)

    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)

    self.kernel_constraint = constraints.get(kernel_constraint)
    self.recurrent_constraint = constraints.get(recurrent_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

  @property
  def state_size(self):
    return [self.units, self.units]

  def build(self, input_shape):
    input_dim = input_shape[-1]
    self.kernel = self.add_weight(
        shape=(input_dim, self.units * 2),
        name='kernel',
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint)
    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units * 2),
        name='recurrent_kernel',
        initializer=self.recurrent_initializer,
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)

    if self.use_bias:
      self.bias = self.add_weight(
          shape=(self.units * 2,),
          name='bias',
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs, states, training=None):
    h_tm1 = states[0]  # previous memory state
    c_tm1 = states[1]  # previous carry state

    k_f, k_c = array_ops.split(
          self.kernel, num_or_size_splits=2, axis=1)
    x_f = K.dot(inputs, k_f)
    x_c = K.dot(inputs, k_c)
    if self.use_bias:
      b_f, b_c = array_ops.split(
          self.bias, num_or_size_splits=2, axis=0)
      x_f = K.bias_add(x_f, b_f)
      x_c = K.bias_add(x_c, b_c)

    f = self.recurrent_activation(x_f + K.dot(
        h_tm1, self.recurrent_kernel[:, :self.units]))
    c = f * c_tm1 + (1 - f) * self.activation(x_c + K.dot(
        h_tm1, self.recurrent_kernel[:, self.units:]))

    h = self.activation(c)
    return h, [h, c]

tf.random.set_seed(111)
np.random.seed(111)

model = Sequential([
    RNN(SLSTMCell(1, activation=None), input_shape=(None, 1), return_sequences=True)
])
model.compile(optimizer=SGD(lr=0.0001), loss="mean_squared_error")

n = 51200
x = np.random.random((n, 30, 1))
y = x.cumsum(axis=1)

model.fit(x, y, batch_size=512, epochs=100)

model.layers[0].weights
# [<tf.Variable 'rnn/kernel:0' shape=(1, 2) dtype=float32, numpy=array([[-0.79614836,  0.03041089]], dtype=float32)>,
#  <tf.Variable 'rnn/recurrent_kernel:0' shape=(1, 2) dtype=float32, numpy=array([[0.08143749, 1.0668359 ]], dtype=float32)>,
#  <tf.Variable 'rnn/bias:0' shape=(2,) dtype=float32, numpy=array([0.6330045, 1.0431471], dtype=float32)>]

model.predict(np.ones((1, 30, 1)) * 0.5).flatten()
# array([ 0.47944844,  0.96489847,  1.4559155 ,  1.9520411 ,  2.4527955 ,
#         2.9576783 ,  3.466171  ,  3.9777386 ,  4.4918313 ,  5.007888  ,
#         5.5253367 ,  6.0435996 ,  6.5620937 ,  7.0802336 ,  7.597435  ,
#         8.113117  ,  8.626705  ,  9.13763   ,  9.645338  , 10.149284  ,
#        10.648943  , 11.143805  , 11.633378  , 12.117197  , 12.594816  ,
#        13.065814  , 13.529797  , 13.986397  , 14.435274  , 14.876117  ],
#       dtype=float32)

実装が合っているか少々自信がありませんが、それっぽく動いているのでOKとしましょう。

あれ、初期状態は?

実は、入力の先頭における状態(初期状態)は $h_0 = c_0 = 0$ となっています。これは AbstractRNNCell の中で定義されています。

recurrent.py
  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
    return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)

今回の問題ではそれで大丈夫なのですが、もし初期状態を違う値にしたい場合には、継承先の SLSTMCell クラスで get_initial_state() をオーバーライドすることにより変更できます。例えば $h_0=1$ から始めたい場合は以下のようになります。

(2021/11/20訂正) $h_t = c_t$ ですから、$c_0$ も1にしなければなりません。以下のコードを訂正しました。

  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
    h_0 = tf.ones([batch_size, self.units], dtype)
    c_0 = tf.ones([batch_size, self.units], dtype) # 2021/11/20訂正
    return [h_0, c_0]

学習データのラベルも「部分和+1」になるように変更して学習してみます。

n = 51200
x = np.random.random((n, 30, 1))
y = x.cumsum(axis=1) + 1

model.fit(x, y, batch_size=512, epochs=100)

model.predict(np.ones((1, 30, 1)) * 0.5).flatten()                      
# array([ 1.4924545,  1.9888796,  2.4888983,  2.9921224,  3.4981523,
#         4.0065784,  4.516983 ,  5.0289407,  5.542021 ,  6.0557885,
#         6.569805 ,  7.0836315,  7.596828 ,  8.108956 ,  8.619583 ,
#         9.128277 ,  9.634614 , 10.138178 , 10.63856  , 11.135363 ,
#        11.628199 , 12.116695 , 12.600491 , 13.079238 , 13.5526085,
#        14.020285 , 14.481971 , 14.937386 , 15.386269 , 15.828373 ],
#       dtype=float32)

なんとなくそれらしい結果が得られました。

まとめ

TensorFlow + KerasでのRNNの使い方と、論文を追試するときなどのためにRNNをカスタマイズする方法を書きました。
RNN, LSTMをブラックボックスとして使うだけなら難しくありませんが、内部処理を理解しようとすると(特に日本語の)参考資料が意外とないのですね。この記事でRNNやLSTMの扱いに対するハードルが下がれば幸いです。

  1. 今更聞けないLSTMの基本 - HELLO CYBERNETICS 2

  2. LeCun, Yann & Bengio, Y. & Hinton, Geoffrey. (2015). Deep Learning. Nature. 521. 436-44. 10.1038/nature14539.

  3. 実際には、LSTM を使うと高速なCuDNN実装が利用できる(場合がある)ので、単にLSTMを使いたいだけならわざわざ分離する意味はありません。

  4. 変数の数を減らす効果の他に、式(2.1)から(2.4)の活性化関数の中身をまとめて計算できるというメリットがあります。LSTMCell() を作成するときに implementation=2 を与えると、まとめて計算する実装が使われるようです。

  5. Kerasのhard_sigmoidが max(0, min(1, (0.2 * x) + 0.5)) である話 - Qiita

  6. Zhizheng Wu, Simon King. (2016). Investigating gated recurrent neural networks for speech synthesis, Proceedings of ICASSP 2016. (arXiv:1601.02539)

78
61
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
78
61

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?