2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

tf.keras.layers.SimpleRNNの仕組みを手組みで確認する

Last updated at Posted at 2023-06-02

目的

Tensorflow.Kerasに用意されているRNNレイヤーの構造を、手組みの場合と比較しながら理解します。

対象

本稿では、RNNレイヤーの1つであるSimpleRNNを対象とします。
https://www.tensorflow.org/api_docs/python/tf/keras/layers/SimpleRNN

前提

公式ドキュメント、専門書およびweb上の情報等を参考にリバースエンジニアリングを行いながら、結果を比較することで計算式、および重み行列とKerasレイヤーのパラメータとの対応関係を抽出しました。ソースコードを見て厳密に確認したわけではないため、100%正しい記載になっていない可能性がありますが、理解のためには十分かなと思います。

実行結果は以下の環境で計算した結果となります。
Python: 3.10.6
Tensorflow: 2.12.0 (tensorflow-aarch64)
Keras: 2.12.0
Numpy: 1.23.5

結果

天下り的ですが、SimpleRNNの計算式は以下のように表されることがわかりました。

\begin{align}
s_t &= f(W x_t + U s_{t-1} + b) \\
y_t &= s_t
\end{align}

ここで

  • $x_t$:入力値($K$次元ベクトル)
  • $s_t$:内部メモリ($N$次元ベクトル)(SimpleRNNではそのまま出力になる)
  • $y_t$:出力値($N$次元ベクトル)
  • $f$:Activation(デフォルトはtanh)
  • $W$:入力に対する重み($N\times K$行列)
  • $U$:内部メモリに対する重み($N\times N$行列)
  • $b$:バイアス項($N$次元ベクトル)

となります。

実証

SimpleRNNの計算が上式となることを、Kerasの出力結果と手組みの出力結果を比較することで確かめます。

まず入力値をランダムに生成します。

# 時刻ステップTのK次元データをS個生成
S = 3
T = 2
K = 4
input = np.random.randn(S, T, K)
print(input)

[[[-1.00559132 2.19790855 1.22426653 0.43110155]
[ 2.47757924 1.42999603 0.22461884 0.52406924]]

[[ 1.58063093 -0.65496626 0.66998741 -2.05034626]
[ 1.31754524 0.40744175 0.07545818 1.48894891]]

[[-0.46522181 2.16090139 0.71952939 -0.45572353]
[-1.99252205 0.38433699 0.49557174 -2.04901821]]]

Kerasの結果は次のようになります。return_sequences=Trueとして全ての時刻ステップの結果を取得します。また、デフォルトでゼロに初期化されてしまうbiasをbias_initializer='random_normal'として値を入れます。

# 内部メモリはN次元とする
N = 2
rnn = tf.keras.layers.SimpleRNN(N, bias_initializer='random_normal', return_sequences=True)
output = rnn(input)
print(output)

tf.Tensor(
[[[-0.9842801 0.7969481 ]
[ 0.93499136 -0.7715645 ]]

[[ 0.44224197 -0.9944355 ]
[ 0.7529196 -0.6964871 ]]

[[-0.9916976 -0.10067604]
[-0.999102 0.8626111 ]]], shape=(3, 2, 2), dtype=float32)

出力は$(S, T, N)$次元のテンソルとなります。

次に、SimpleRNNを手組みするために、$W,U,b$の重みをkeras.layerから取得します。

W = rnn.weights[0].numpy()
print(W)

[[ 0.83396554 -0.95233345]
[-0.9864013 -0.09958982]
[ 0.21518016 0.05979133]
[ 0.78888416 0.73306966]]

U = rnn.weights[1].numpy()
print(U)

[[ 0.19022751 -0.9817401 ]
[ 0.9817401 0.19022739]]

b = rnn.weights[2].numpy()
print(b)

[-0.01591891 -0.03780531]

最後に、各サンプルごとに、SimpleRNNの数式に従って計算し、各時刻ステップ毎に結果を出力します。このとき、行列積の順序に注意します。(入力値や重み行列が数式とは転置になっているため)

for x_s in input:
    s = np.zeros(N)
    for x_t in x_s:
        s = np.tanh(x_t @ W + s @ U + b)
        print(s)

[-0.98428005 0.79694807]
[ 0.93499136 -0.77156436]

[ 0.44224219 -0.99443547]
[ 0.75291959 -0.69648714]

[-0.99169759 -0.10067606]
[-0.99910198 0.86261118]

kerasの結果と一致していることが確認できました。

参考

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?