More than 5 years have passed since last update.

Verifying the computations of recurrent neural networks


# Introduction

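  • I am studying RNNs with the book 「詳解 ディープラーニング ~TensorFlow・Kerasによる時系列データ処理~」 (Amazon).
  • To understand the basic RNN algorithms, I recomputed in plain Python (NumPy) the prediction accuracy and log loss that TensorFlow reports.
  • To keep the computation simple, every model uses BasicRNNCell.
  • All example models assume fixed-length sequences.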
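Every NumPy recomputation below rests on one BasicRNNCell step, h_t = tanh([x_t, h_{t-1}] W + b), where the kernel W stacks the input weights on top of the recurrent weights. The following minimal sketch (random stand-in weights; the helper names `rnn_step` and `rnn_step_split` are hypothetical) checks that the concatenated form used throughout this article equals the textbook form with separate input and recurrent matrices.

```python
import numpy as np

def rnn_step(x, h, kernel, bias):
    """One BasicRNNCell-style step: tanh([x, h] @ kernel + bias)."""
    return np.tanh(np.hstack((x, h)) @ kernel + bias)

def rnn_step_split(x, h, kernel, bias):
    """The same step with separate input (w_x) and recurrent (w_h) weights."""
    n_in = x.shape[1]
    w_x, w_h = kernel[:n_in], kernel[n_in:]  # first n_in rows act on the input
    return np.tanh(x @ w_x + h @ w_h + bias)

rng = np.random.default_rng(0)
n_batch, n_in, n_hidden = 5, 12, 8
x = rng.normal(size=(n_batch, n_in))
h = rng.normal(size=(n_batch, n_hidden))
kernel = rng.normal(size=(n_in + n_hidden, n_hidden))
bias = rng.normal(size=n_hidden)

print(np.allclose(rnn_step(x, h, kernel, bias),
                  rnn_step_split(x, h, kernel, bias)))  # True
```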

# 1. Sequence-to-sequence model

  • A sequence-to-sequence model built from BasicRNNCell.
  • The task is three-digit addition: given "24+654", the model answers "678".
  • X_validation[N_validation,7,12], Y_validation[N_validation,4,12], input_digits=7, output_digits=4, n_hidden=128
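The shapes [7, 12] and [4, 12] follow from the task: the longest question "999+999" has 7 characters, the largest answer 1998 has 4, and each character is a one-hot vector over 12 symbols. A minimal sketch of that encoding, assuming the vocabulary is the ten digits, '+', and a space used for right-padding (the names `CHARS`, `encode`, and `decode` are hypothetical):

```python
import numpy as np

CHARS = '0123456789+ '  # assumed vocabulary: 10 digits, '+', padding space
CHAR_INDICES = {c: i for i, c in enumerate(CHARS)}

def encode(text, length):
    """Right-pad `text` with spaces to `length`, then one-hot encode it."""
    padded = text + ' ' * (length - len(text))
    x = np.zeros((length, len(CHARS)))
    for t, c in enumerate(padded):
        x[t, CHAR_INDICES[c]] = 1.0
    return x

def decode(onehots):
    """Map a [length, 12] array back to a string (padding kept)."""
    return ''.join(CHARS[i] for i in np.argmax(onehots, axis=1))

x = encode('24+654', 7)   # one question, shape (7, 12)
y = encode('678', 4)      # its answer, shape (4, 12)
print(x.shape, y.shape)   # (7, 12) (4, 12)
print(repr(decode(y)))    # '678 '
```

Stacking N_validation such arrays gives X_validation and Y_validation.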

## Retrieving the trained parameters

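```python
# TensorFlow 1.x graph mode: fetch the current value of every global variable.
# `sess` is the trained tf.Session from the training script; the list comes
# back in variable-creation order, which the index-based unpacking below relies on.
param = []
for a in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    param.append(sess.run(a))
```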
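Because the NumPy code addresses parameters purely by position, a shape check catches a misordered list early. This sketch assumes the sequence-to-sequence layout used below (encoder cell, output layer, decoder cell) and BasicRNNCell's kernel shape [input_size + n_hidden, n_hidden]; `check_shapes` is a hypothetical helper, and stand-in zero arrays replace the real `param`. It cannot tell apart two variables of identical shape, so printing each variable's `name` alongside its value in the TensorFlow loop is a useful complement.

```python
import numpy as np

def check_shapes(param, expected):
    """Return (index, actual, expected) for every array whose shape differs."""
    if len(param) != len(expected):
        raise ValueError(f'expected {len(expected)} arrays, got {len(param)}')
    return [(i, p.shape, e) for i, (p, e) in enumerate(zip(param, expected))
            if p.shape != e]

n_in, n_hidden, n_out = 12, 128, 12
# Assumed layout for the sequence-to-sequence model
expected = [
    (n_in + n_hidden, n_hidden), (n_hidden,),   # w_en, c_en (encoder cell)
    (n_hidden, n_out), (n_out,),                # v, b (output layer)
    (n_out + n_hidden, n_hidden), (n_hidden,),  # w_de, c_de (decoder cell)
]

fake_param = [np.zeros(s) for s in expected]  # stand-in for the sess.run values
print(check_shapes(fake_param, expected))     # []
```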

## Computing the prediction accuracy (acc) and log loss (loss)

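```python
# my library (my.py): helpers shared by all three models
import numpy as np

def softmax(score):
    # Row-wise softmax; subtracting the row maximum avoids overflow in np.exp
    e = np.exp(score - np.max(score, axis=1, keepdims=True))
    return e / np.sum(e, axis=1, keepdims=True)

def one_hot(score):
    # One-hot vector of the highest-scoring class in each row
    return np.eye(score.shape[1])[np.argmax(score, axis=1)]
```

```python
import numpy as np
import my

# param, X_validation, Y_validation, N_validation, n_hidden, input_digits and
# output_digits come from the training script.
# Parameters in creation order: encoder cell, output layer, decoder cell
w_en = param[0]  # encoder kernel, shape [12 + n_hidden, n_hidden]
c_en = param[1]  # encoder bias
v = param[2]     # output-layer weights, shape [n_hidden, 12]
b = param[3]     # output-layer bias
w_de = param[4]  # decoder kernel, shape [12 + n_hidden, n_hidden]
c_de = param[5]  # decoder bias

# Encoder: feed the 7 input characters through BasicRNNCell
state = np.zeros([N_validation, n_hidden])
for i in range(input_digits):
    state = np.tanh(np.dot(np.hstack((X_validation[:, i, :], state)), w_en) + c_en)

# The first output digit comes directly from the final encoder state
score = np.dot(state, v) + b
output = [my.softmax(score)]
outhot = [my.one_hot(score)]

# Remaining digits: the decoder is fed the one-hot of its previous prediction
for i in range(1, output_digits):
    state = np.tanh(np.dot(np.hstack((outhot[-1], state)), w_de) + c_de)
    score = np.dot(state, v) + b
    output.append(my.softmax(score))
    outhot.append(my.one_hot(score))

# [output_digits, N, 12] -> [N, output_digits, 12]
output = np.transpose(output, [1, 0, 2])
# acc: fraction of correctly predicted digits; loss: mean cross-entropy per digit
acc = np.mean(np.equal(np.argmax(output, 2), np.argmax(Y_validation, 2)))
loss = -np.mean(np.sum(np.log(np.clip(output, 1e-10, 1.0)) * Y_validation, axis=2))
```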
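Note that `acc` above is a per-digit accuracy: an answer with three of four digits right still contributes 3/4. The toy example below (two samples, two positions, three classes; the data and the helper `seq_metrics` are made up for illustration) reproduces `acc` and `loss` and, for comparison, also computes a stricter whole-answer accuracy that is not part of the original code.

```python
import numpy as np

def seq_metrics(output, Y):
    """output, Y: [N, T, K] predicted probabilities and one-hot targets."""
    pred, true = np.argmax(output, 2), np.argmax(Y, 2)
    acc_digit = np.mean(pred == true)                # same as `acc` above
    acc_seq = np.mean(np.all(pred == true, axis=1))  # every position correct
    loss = -np.mean(np.sum(np.log(np.clip(output, 1e-10, 1.0)) * Y, axis=2))
    return acc_digit, acc_seq, loss

output = np.array([[[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]],   # both positions right
                   [[0.6, 0.3, 0.1], [0.2, 0.2, 0.6]]])  # first position wrong
Y = np.eye(3)[np.array([[0, 1], [1, 2]])]
acc_digit, acc_seq, loss = seq_metrics(output, Y)
print(acc_digit, acc_seq, round(loss, 3))  # 0.75 0.5 0.574
```

The loss is the mean of -log p over the probabilities assigned to the true classes: -(log 0.7 + log 0.8 + log 0.3 + log 0.6) / 4.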

# 2. Bidirectional model

  • static_bidirectional_rnn() with BasicRNNCell.
  • The task treats each MNIST image as a time series of length 28 and predicts the digit.
  • X_validation[N_validation,28,28], Y_validation[N_validation,10], n_time=28, n_hidden=128
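Treating an image as a time series simply means reading it one pixel row per time step. A minimal sketch, assuming the images arrive flattened to 784 pixels as in common MNIST loaders (stand-in data here):

```python
import numpy as np

n_time, n_in = 28, 28
N = 3
# Stand-in for flattened MNIST images: N samples of 784 pixels each
images = np.arange(N * 784, dtype=np.float32).reshape(N, 784)

# [N, 784] -> [N, 28 time steps, 28 features per step]
X = images.reshape(N, n_time, n_in)

# Time step i is pixel row i of every image
print(np.array_equal(X[:, 5, :], images[:, 5 * 28:6 * 28]))  # True
print(X.shape)  # (3, 28, 28)
```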

## Computing the prediction accuracy (acc) and log loss (loss)

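```python
import numpy as np
import my  # helpers from section 1

# Parameters in creation order: forward cell, backward cell, output layer
w_fw = param[0]
c_fw = param[1]
w_bw = param[2]
c_bw = param[3]
v = param[4]
b = param[5]

# Forward RNN over rows 0, 1, ..., n_time-1
output_fw = []
state = np.zeros([N_validation, n_hidden])
for i in range(n_time):
    state = np.tanh(np.dot(np.hstack((X_validation[:, i, :], state)), w_fw) + c_fw)
    output_fw.append(state)

# Backward RNN over rows n_time-1, ..., 0 (outputs stored in processing order)
output_bw = []
state = np.zeros([N_validation, n_hidden])
for i in reversed(range(n_time)):
    state = np.tanh(np.dot(np.hstack((X_validation[:, i, :], state)), w_bw) + c_bw)
    output_bw.append(state)

# Reverse the backward outputs so both directions line up by time step, then
# concatenate them as static_bidirectional_rnn() does: [N, 2 * n_hidden] per step
outputs = tuple(np.hstack((fw, bw)) for fw, bw in zip(output_fw, reversed(output_bw)))

# Classify from the output at the last time step
score = np.dot(outputs[-1], v) + b
acc = np.mean(np.equal(np.argmax(score, 1), np.argmax(Y_validation, 1)))
loss = -np.mean(np.sum(np.log(np.clip(my.softmax(score), 1e-10, 1.0)) * Y_validation, axis=1))
```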
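The reversal before concatenation is what aligns the two directions by time step. A consequence worth checking: in `outputs[-1]`, the forward half has read the whole image, while the backward half has read only the last row. The sketch below (hypothetical helpers `run_rnn` and `bidirectional`, random weights, zero biases) verifies both facts.

```python
import numpy as np

def run_rnn(X, kernel, bias, reverse=False):
    """BasicRNNCell-style RNN; returns hidden states in processing order."""
    N, T, _ = X.shape
    state = np.zeros((N, bias.shape[0]))
    steps = reversed(range(T)) if reverse else range(T)
    outputs = []
    for t in steps:
        state = np.tanh(np.hstack((X[:, t, :], state)) @ kernel + bias)
        outputs.append(state)
    return outputs

def bidirectional(X, fw, bw):
    out_fw = run_rnn(X, *fw)
    out_bw = run_rnn(X, *bw, reverse=True)
    # out_bw[0] belongs to time T-1, so reverse it to align with out_fw
    return [np.hstack((f, b)) for f, b in zip(out_fw, reversed(out_bw))]

rng = np.random.default_rng(1)
N, T, n_in, n_hidden = 4, 6, 3, 5
X = rng.normal(size=(N, T, n_in))
fw = (rng.normal(size=(n_in + n_hidden, n_hidden)), np.zeros(n_hidden))
bw = (rng.normal(size=(n_in + n_hidden, n_hidden)), np.zeros(n_hidden))
outputs = bidirectional(X, fw, bw)
print(len(outputs), outputs[-1].shape)  # 6 (4, 10)

# Backward half of the last output: one step from a zero state on X[:, T-1]
one_step = np.tanh(np.hstack((X[:, T - 1, :], np.zeros((N, n_hidden)))) @ bw[0] + bw[1])
print(np.allclose(outputs[-1][:, n_hidden:], one_step))  # True
```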

# 3. Attention model

  • A sequence-to-sequence model based on BasicRNNCell, with AttentionCellWrapper applied.
  • The task is three-digit addition: given "24+654", the model answers "678".
  • X_validation[N_validation,7,12], Y_validation[N_validation,4,12], input_digits=7, output_digits=4, n_hidden=128
  • The model has 22 parameter arrays.
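One way to account for the 22 arrays, assuming the breakdown used by the code below: each AttentionCellWrapper holds 10 (an input projection, the wrapped BasicRNNCell, four attention arrays, and an output projection), the encoder and decoder each have one wrapper, and the final output layer adds 2 more that both share. The group labels are the article's variable names without the _en/_de suffix.

```python
# Assumed per-wrapper breakdown (labels follow the variable names used below)
per_wrapper = [
    ('w1', 'c1'),                              # projects [input, attns] to input size
    ('w2', 'c2'),                              # wrapped BasicRNNCell
    ('attn_k', 'attn_v', 'attn_w', 'attn_c'),  # attention scoring
    ('w3', 'c3'),                              # projects [cell output, attns]
]
n_wrapper = sum(len(group) for group in per_wrapper)
n_output_layer = 2                         # w4_en, c4_en (shared)
n_total = 2 * n_wrapper + n_output_layer   # encoder + decoder + output layer
print(n_wrapper, n_total)  # 10 22
```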

## Computing the prediction accuracy (acc) and log loss (loss)

```python
import numpy as np
import my  # helpers from section 1

# Encoder: AttentionCellWrapper around BasicRNNCell, plus the output layer
w1_en = param[0]   # projects [input, attns] back to the input size
c1_en = param[1]
w2_en = param[2]   # BasicRNNCell kernel
c2_en = param[3]   # BasicRNNCell bias
attn_k = param[4]  # attention: 1x1 conv kernel applied to the stored states
attn_v = param[5]  # attention: scoring vector
attn_w = param[6]  # attention: kernel applied to the query (new cell state)
attn_c = param[7]  # attention: bias applied to the query
w3_en = param[8]   # projects [cell output, attns] to the wrapper output
c3_en = param[9]
w4_en = param[10]  # output layer, shared with the decoder
c4_en = param[11]

state = np.zeros([N_validation, n_hidden])
attns = np.zeros([N_validation, n_hidden])
attn_states = np.zeros([N_validation, input_digits, n_hidden])  # attention window

def my_attention(state, attn_states, attn_k, attn_v, attn_w, attn_c):
    # Additive attention over the stored states; attn_k[0, 0] is the conv matrix
    hidden_features = np.dot(attn_states, attn_k[0, 0])
    y = np.dot(state, attn_w) + attn_c
    y = np.expand_dims(y, 1)
    s = np.sum(attn_v * np.tanh(hidden_features + y), axis=2)
    a = my.softmax(s)
    attns = np.sum(np.expand_dims(a, 2) * attn_states, axis=1)
    return attns

for i in range(input_digits):
    inputs = np.dot(np.hstack((X_validation[:, i, :], attns)), w1_en) + c1_en
    output = np.tanh(np.dot(np.hstack((inputs, state)), w2_en) + c2_en)
    state = output
    attns = my_attention(state, attn_states, attn_k, attn_v, attn_w, attn_c)
    output = np.dot(np.hstack((output, attns)), w3_en) + c3_en
    # Drop the oldest stored state and append the new output
    attn_states = np.hstack((attn_states[:, 1:, :], np.expand_dims(output, 1)))

score = np.dot(output, w4_en) + c4_en
outputs = [my.softmax(score)]
outhots = [my.one_hot(score)]

# Decoder: its own AttentionCellWrapper; state, attns and attn_states carry over
w1_de = param[12]
c1_de = param[13]
w2_de = param[14]
c2_de = param[15]
attn_k = param[16]
attn_v = param[17]
attn_w = param[18]
attn_c = param[19]
w3_de = param[20]
c3_de = param[21]

for i in range(1, output_digits):
    inputs = np.dot(np.hstack((outhots[-1], attns)), w1_de) + c1_de
    output = np.tanh(np.dot(np.hstack((inputs, state)), w2_de) + c2_de)
    state = output
    attns = my_attention(state, attn_states, attn_k, attn_v, attn_w, attn_c)
    output = np.dot(np.hstack((output, attns)), w3_de) + c3_de
    attn_states = np.hstack((attn_states[:, 1:, :], np.expand_dims(output, 1)))
    score = np.dot(output, w4_en) + c4_en  # shared output layer
    outputs.append(my.softmax(score))
    outhots.append(my.one_hot(score))

# [output_digits, N, 12] -> [N, output_digits, 12]
outputs = np.transpose(outputs, [1, 0, 2])
acc = np.mean(np.equal(np.argmax(outputs, 2), np.argmax(Y_validation, 2)))
loss = -np.mean(np.sum(np.log(np.clip(outputs, 1e-10, 1.0)) * Y_validation, axis=2))
```
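The attention step can be tested in isolation. This sketch restates the scoring of `my_attention` as a hypothetical standalone function `attention`, taking the 2-D matrix `k` directly (that is, assuming the `[0, 0]` slice of the conv kernel has already been taken) and also returning the weights. Two properties it checks: the weights in each row sum to 1, and with a zero scoring vector every score is 0, so the weights are uniform and the context vector is the plain mean of the stored states.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - np.max(s, axis=1, keepdims=True))
    return e / np.sum(e, axis=1, keepdims=True)

def attention(query, attn_states, k, v, w, c):
    """Additive attention.

    query: [N, H], attn_states: [N, L, H], k, w: [H, A], v, c: [A]
    Returns the context vector [N, H] and the attention weights [N, L].
    """
    hidden_features = attn_states @ k                          # [N, L, A]
    y = (query @ w + c)[:, None, :]                            # [N, 1, A]
    scores = np.sum(v * np.tanh(hidden_features + y), axis=2)  # [N, L]
    a = softmax(scores)
    context = np.sum(a[:, :, None] * attn_states, axis=1)      # [N, H]
    return context, a

rng = np.random.default_rng(2)
N, L, H = 2, 7, 4
query = rng.normal(size=(N, H))
attn_states = rng.normal(size=(N, L, H))
k, w = rng.normal(size=(H, H)), rng.normal(size=(H, H))
v, c = rng.normal(size=H), np.zeros(H)

context, a = attention(query, attn_states, k, v, w, c)
print(context.shape, np.allclose(a.sum(axis=1), 1.0))  # (2, 4) True

# v = 0 gives uniform weights, so the context is the mean stored state
context0, a0 = attention(query, attn_states, k, np.zeros(H), w, c)
print(np.allclose(context0, attn_states.mean(axis=1)))  # True
```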