3
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

【Python】LSTMのpredictと同じ計算をfrom scratchでやってみた【Keras】

Last updated at Posted at 2020-01-19

1. 概要

  • Kerasのmodel.predict()が具体的にどういう計算をしているかを理解することが目的です。
  • 学習済みのモデルからmodel.get_weights()を使ってウエイトを取得し、学習データにこれを適用してモデルの出力を計算します。
  • model.predict()の出力と同じ結果を得るのが目標です!
  • back propagationについては一切扱っていません。
  • Kerasで作ったモデルをPython以外の言語で実装したい場面があったのがきっかけです。LSTMはなんとなく理解していたつもりだったのですが、実装でハマったので笑

2. データとモデルの準備

データ

  • model.predict()と同じ結果を手動で得ることが目的なので、モデル自体の完成度や複雑さは必要ないので、まずは以下の様なサンプルデータを作成しました。
X = np.arange(24).reshape(4,3,2)
y = np.array([[0,1],[0,1],[0,1],[1,0]])
print(X.shape)
# => (4, 3, 2)
print(y.shape)
# => (4, 2)
print(X[0])
# => [[0 1]
#    [2 3]
#    [4 5]]
  • Xは学習データ、yは対応するラベルです。
  • X、yは4つのサンプルから成っています。
  • ひとつめのサンプルを見てみると、3×2の行列になっています。
  • これはデータとして、特徴量が2つあり、3つの時系列のデータを持っているということです。
  • データは古いものから新しいものの順に並んでおり、まず[0, 1]というデータが得られ、その後[2, 3]、最後に[4, 5]という観測が得られたというイメージです。

モデル

  • 以下の様なシンプルなモデルを作成しました。
from keras.layers import Input, Dense
from keras.models import Model
from keras.layers.recurrent import LSTM
import tensorflow as tf
from keras import backend

tf.reset_default_graph()
backend.clear_session()

inputs = Input(shape=[3,2])
x = LSTM(8, activation='tanh', recurrent_activation='sigmoid')(inputs)
outputs = Dense(2, activation='softmax')(x)
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam',loss='categorical_crossentropy')
model.summary()
Screen Shot 2020-01-18 at 18.56.23.png
  • デフォルトではrecurrent_activation='hard_sigmoid'なのですが、いろいろと調べたところ、sigmoidが使われている文献・図表が多かったため、一旦recurrent_activation='sigmoid'としております。
  • あとは学習させればモデルの完成です。
history = model.fit(X, y, epochs=2, verbose=1)

3. get_weights()の理解

  • ここが分かりにくい所のひとつです。get_weights()でモデルのパラメータを取得できるのですが出力の意味が直感的にすぐ分かるというものではなかったです。(少なくとも僕にとって)
for weight in model.get_weights():
    print(weight.shape)
# => (2, 32)
#   (8, 32)
#   (32,)
#   (8, 2)
#   (2,)
  • KerasのDocumentationは以下の通りでした。もう少し詳しく説明が欲しい・・・

model.get_weights(): モデルの全ての重みテンソルをNumpy 配列を要素にもつリスト返します.

  • 配列の形からなんとなく想像はつくもののいろいろ調べたところ、始めの3つがLSTMのパラメータ、残りの2つがDenseレイヤのパラメータです。
  • このLSTMのパラメータを説明する前に、LSTMセルを再度確認します。
Screen Shot 2020-01-18 at 19.23.40.png
  • 以下では可能な限り上図の表記を借りてコードを書いていきます。実際に行われる計算は以下の様になります。
    $\quad i_t = \sigma(x_t W^i + h_{t-1} U^i + b^i)$
    $\quad f_t = \sigma(x_t W^f + h_{t-1} U^f + b^f)$
    $\quad \tilde{C}_t = {\rm tanh} (x_t W^g + h_{t-1} U^g + b^g)$
    $\quad o_t = \sigma(x_t W^o + h_{t-1} U^o + b^o)$
    $\quad C_t = f_t C_{t-1} + i_t \tilde{C}_t$
    $\quad h_t = {\rm tanh}(C_t)o_t$
  • get_weights()で得られるパラメータについての結論は以下の通りです。
W = model.get_weights()[0]
U = model.get_weights()[1]
b = model.get_weights()[2]
dense_W = model.get_weights()[3]
dense_b = model.get_weights()[4]

W_i, W_f, W_tC, W_o = W[:,:8], W[:,8:16], W[:,16:24], W[:,:24:]
U_i, U_f, U_tC, U_o = U[:,:8], U[:,8:16], U[:,16:24], U[:,:24:]
b_i, b_f, b_tC, b_o = b[:8], b[8:16], b[16:24], b[24:]
  • $W^i$、$W^f$、$W^g$、$W^o$がまとめてmodel.get_weights()[0]に格納されています。ここで8つずつパラメータを区切っているのは、LSTMのcell数を8つとしてモデルを作っているためです。

4. 計算!

  • ここまでくれば後は計算するのみです。
# ひとつめのサンプルを使って計算していきます。
_X = X[0]

# 活性化関数を定義しておきます。
def sigmoid(x):
    return(1.0/(1.0+np.exp(-x)))
def relu(x):
    ret_x = x
    ret_x[ret_x<0] = 0
    return ret_x

# 最初のC,hの値は全て0となっています。
C = np.zeros((1,8))
h = np.zeros((1,8))

# LSTM部分
for i in range(len(_X)):
    x_t = _X[i]
    i_t = sigmoid(np.dot(x_t,W_i) + np.dot(h,U_i) + b_i)
    f_t = sigmoid(np.dot(x_t,W_f) + np.dot(h,U_f) + b_f)
    tC = np.tanh(np.dot(x_t,W_g) + np.dot(h,U_g) + b_g)
    o_t = sigmoid(np.dot(x_t,W_o) + np.dot(h,U_o) + b_o)
    C = f_t*C + i_t*tC
    h = np.tanh(C) * o_t

# Dense部分
output = np.dot(h,dense_W) + dense_b

# softmax計算
E = []
Esum = 0
for i in range(2):
    E.append(np.exp(output[0,i]))
    Esum += np.exp(output[0,i])
result = []
for i in range(2):
    result.append(E[i]/Esum)
    
print(result)
# => [0.5211381547054326, 0.4788618452945675]
  • model.predict()の結果を確認します。
print(model.predict(_X.reshape(1,3,2)))
# => [[0.5211382  0.47886187]]
  • 見事に一致!!
  • ポイントはget_weights()の戻り値の意味を理解することに尽きました。戻り値のそれぞれのパラメータについてモデルのどのパラメータを指しているのかを理解すること、そしてLSTMのパラメータについては4つのパラメータがくっついてひとつのパラメータになっていることがポイントでした。

5. [番外編]activationとrecurrent_activation

  • KerasのLSTMではactivationとrecurrent_activationを指定することができますが、それぞれどこの活性化関数を指しているか分かりにくいというのが正直な感想です。
  • せっかくmodel.predictをnumpyで手計算できるようになったので、この機会にactivationとrecurrent_activationが具体的にどこで使われる非線形変換なのかを確認してみます。
activationとrecurrent_activationを調べてみる。
  • まずはactivationについて調べます。
# activationをreluに変えたモデルを作成します。
tf.reset_default_graph()
backend.clear_session()

inputs = Input(shape=[3,2])
x = LSTM(8, activation='relu', recurrent_activation='sigmoid')(inputs) # relu!
outputs = Dense(2, activation='softmax')(x)
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam',loss='categorical_crossentropy')
history = model.fit(X, y, epochs=2, verbose=1)


# numpyを用いた予測の出力を計算します。
C = np.zeros((1,8))
h = np.zeros((1,8))
_X = X[0]

W = model.get_weights()[0]
U = model.get_weights()[1]
b = model.get_weights()[2]
dense_W = model.get_weights()[3]
dense_b = model.get_weights()[4]

W_i, W_f, W_g, W_o = W[:,:8], W[:,8:16], W[:,16:24], W[:,24:]
U_i, U_f, U_g, U_o = U[:,:8], U[:,8:16], U[:,16:24], U[:,24:]
b_i, b_f, b_g, b_o = b[:8], b[8:16], b[16:24], b[24:]

for i in range(len(_X)):
    x_t = _X[i]
    i_t = sigmoid(np.dot(x_t,W_i) + np.dot(h,U_i) + b_i)
    f_t = sigmoid(np.dot(x_t,W_f) + np.dot(h,U_f) + b_f)
    tC = relu(np.dot(x_t,W_g) + np.dot(h,U_g) + b_g) # relu!
    o_t = sigmoid(np.dot(x_t,W_o) + np.dot(h,U_o) + b_o)
    C = f_t*C + i_t*tC
    h = relu(C) * o_t # relu!

output = np.dot(h,dense_W) + dense_b

E = []
Esum = 0
for i in range(2):
    E.append(np.exp(output[0,i]))
    Esum += np.exp(output[0,i])
result = []
for i in range(2):
    result.append(E[i]/Esum)

# 出力は以下のようになります。
print(result)
# => [0.5606417941538421, 0.4393582058461578]

# model.predict()の出力を確認します。
print(model.predict(_X.reshape(1,3,2)))
# => [[0.5606418 0.4393582]]
  • 次にrecurrent_activationについても見ていきます。
# activationをreluに変えたモデルを作成します。
tf.reset_default_graph()
backend.clear_session()

inputs = Input(shape=[3,2])
x = LSTM(8, activation='tanh', recurrent_activation='relu')(inputs) # relu!
outputs = Dense(2, activation='softmax')(x)
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam',loss='categorical_crossentropy')
history = model.fit(X, y, epochs=2, verbose=1)


# numpyを用いた予測の出力を計算します。
C = np.zeros((1,8))
h = np.zeros((1,8))
_X = X[0]

W = model.get_weights()[0]
U = model.get_weights()[1]
b = model.get_weights()[2]
dense_W = model.get_weights()[3]
dense_b = model.get_weights()[4]

W_i, W_f, W_g, W_o = W[:,:8], W[:,8:16], W[:,16:24], W[:,24:]
U_i, U_f, U_g, U_o = U[:,:8], U[:,8:16], U[:,16:24], U[:,24:]
b_i, b_f, b_g, b_o = b[:8], b[8:16], b[16:24], b[24:]

for i in range(len(_X)):
    x_t = _X[i]
    i_t = relu(np.dot(x_t,W_i) + np.dot(h,U_i) + b_i) # relu!
    f_t = relu(np.dot(x_t,W_f) + np.dot(h,U_f) + b_f) # relu!
    tC = np.tanh(np.dot(x_t,W_g) + np.dot(h,U_g) + b_g)
    o_t = relu(np.dot(x_t,W_o) + np.dot(h,U_o) + b_o) # relu!
    C = f_t*C + i_t*tC
    h = np.tanh(C) * o_t

output = np.dot(h,dense_W) + dense_b

E = []
Esum = 0
for i in range(2):
    E.append(np.exp(output[0,i]))
    Esum += np.exp(output[0,i])
result = []
for i in range(2):
    result.append(E[i]/Esum)

# 出力は以下のようになります。
print(result)
# => [0.5115599582737976, 0.4884400417262024]

# model.predict()の出力を確認します。
print(model.predict(_X.reshape(1,3,2)))
# => [[0.51155996 0.48844004]]
  • activationを$f_a$、recurrent_activationを$f_{r}$とすると、
    $\quad i_t = f_a(x_t W^i + h_{t-1} U^i + b^i)$
    $\quad f_t = f_a(x_t W^f + h_{t-1} U^f + b^f)$
    $\quad \tilde{C}_t = f_r (x_t W^g + h_{t-1} U^g + b^g)$
    $\quad o_t = f_a(x_t W^o + h_{t-1} U^o + b^o)$
    $\quad C_t = f_t C_{t-1} + i_t \tilde{C}_t$
    $\quad h_t = f_r(C_t)o_t$
    となることが分かりました。
  • 上でも使ったのLSTMセルの図でいう**$\sigma$のところがactivationで指定した関数、$\rm tanh$となっているところがrecurrent_activationで指定した関数**となっているようです。
Screen Shot 2020-01-18 at 19.23.40.png
3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?