TensorFlow MNISTをRNNでやってみる

MNISTをRNN(Recurrent Neural Networks)でやってみました。




LSTM(Long short-term memory)は、RNN(Recurrent Neural Network)の拡張として1995年に登場した、時系列データ(sequential data)に対するモデル、あるいは構造(architecture)の1種。

最も大きな特長は、従来のRNNでは学習できなかった長期依存(long-term dependencies)を学習できる点です。




Jupyter Notebookを使用しました。

import tensorflow as tf
from tensorflow.contrib import rnn

from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST dataset from /tmp/data/ with one-hot encoded labels.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Input images: float32, one 28x28 grid per example; batch size left open.
x = tf.placeholder(tf.float32, shape=[None, 28, 28])
# Target labels: float32, 10 values per example; batch size left open.
y = tf.placeholder(tf.float32, shape=[None, 10])

入力を時刻ステップごとのテンソルに分割します。tf.unstack により、形状 [バッチサイズ, 28, 28] のテンソルを、形状 [バッチサイズ, 28] のテンソル28個からなるPythonのlistに変換します。

def RNN(x, n_steps=28, n_hidden=128, n_classes=10):
    """Build a single-layer LSTM classifier and return its logits.

    The input is split along axis 1 into ``n_steps`` per-step tensors, fed
    through a ``BasicLSTMCell`` via ``static_rnn``, and the output of the
    last time step is projected to ``n_classes`` scores with a dense layer.

    Note that every call creates new weight/bias variables.

    Args:
        x: Input tensor whose axis 1 has length ``n_steps`` (with the
            defaults, a [batch, 28, 28] float tensor).
        n_steps: Number of time steps, i.e. the size of axis 1 of ``x``.
        n_hidden: Number of units in the LSTM cell.
        n_classes: Number of output classes.

    Returns:
        The result of ``outputs[-1] @ weight + bias``: one row of
        ``n_classes`` unnormalized scores (logits) per example.
    """
    # Turn the input into a Python list of n_steps tensors, one per time step,
    # which is the input format static_rnn expects.
    steps = tf.unstack(x, n_steps, 1)

    # LSTM cell configuration.
    lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

    # Unroll the network; returns the per-step outputs and the final state.
    outputs, _ = rnn.static_rnn(lstm_cell, steps, dtype=tf.float32)

    # Output projection parameters, randomly initialized.
    weight = tf.Variable(tf.random_normal([n_hidden, n_classes]))
    bias = tf.Variable(tf.random_normal([n_classes]))

    # Classify from the output of the last time step only.
    return tf.matmul(outputs[-1], weight) + bias

コスト関数を定義。今回は、学習させるためにクロスエントロピー誤差関数とAdam Optimizerを使用。

# Build the network output (logits) for the input placeholder. `preds` must
# exist before the loss and accuracy ops below can reference it.
preds = RNN(x)

# Loss: mean softmax cross-entropy between the logits and the labels.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=y))
# Training op: one Adam update step that minimizes the loss.
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)

# Evaluation: fraction of examples whose arg-max prediction matches the
# arg-max of the label.
correct_pred = tf.equal(tf.argmax(preds, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))


batch_size = 128
n_training_iters = 100000
with tf.Session() as sess:
    # All graph variables must be initialized before any op that reads them
    # is run; without this the first sess.run raises FailedPreconditionError.
    sess.run(tf.global_variables_initializer())

    step = 1
    # Keep training until the number of examples seen reaches the limit.
    while step * batch_size < n_training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # next_batch returns batch_x as a flat [batch_size, 784] array, so
        # reshape it to [batch_size, 28, 28] to match the placeholder.
        batch_x = batch_x.reshape((-1, 28, 28))
        feed = {x: batch_x, y: batch_y}
        sess.run(optimizer, feed_dict=feed)
        if step % 10 == 0:
            # Fetch both metrics in a single run instead of two forward passes.
            acc, loss = sess.run([accuracy, cost], feed_dict=feed)
            print('step: {} / loss: {:.6f} / acc: {:.5f}'.format(step, loss, acc))
        step += 1

    # Evaluate on the first test_len test images only.
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, 28, 28))
    test_label = mnist.test.labels[:test_len]
    test_acc = sess.run(accuracy, feed_dict={x: test_data, y: test_label})
    print("Test Accuracy: {}".format(test_acc))

step: 10 / loss: 1.751291 / acc: 0.42969
step: 20 / loss: 1.554639 / acc: 0.46875
step: 30 / loss: 1.365595 / acc: 0.57031
step: 40 / loss: 1.176470 / acc: 0.60156
step: 50 / loss: 0.787636 / acc: 0.75781
step: 60 / loss: 0.776735 / acc: 0.75781
step: 70 / loss: 0.586180 / acc: 0.79688
step: 80 / loss: 0.692503 / acc: 0.80469
step: 90 / loss: 0.550008 / acc: 0.82812
step: 100 / loss: 0.553710 / acc: 0.86719
step: 110 / loss: 0.423268 / acc: 0.86719
step: 120 / loss: 0.462931 / acc: 0.82812
step: 130 / loss: 0.365392 / acc: 0.85938
step: 140 / loss: 0.505170 / acc: 0.85938
step: 150 / loss: 0.273539 / acc: 0.91406
step: 160 / loss: 0.322731 / acc: 0.87500
step: 170 / loss: 0.531190 / acc: 0.85156
step: 180 / loss: 0.318869 / acc: 0.90625
step: 190 / loss: 0.351407 / acc: 0.86719
step: 200 / loss: 0.232232 / acc: 0.92188
step: 210 / loss: 0.245849 / acc: 0.92969
step: 220 / loss: 0.312085 / acc: 0.92188
step: 230 / loss: 0.276383 / acc: 0.89844
step: 240 / loss: 0.196890 / acc: 0.94531
step: 250 / loss: 0.221909 / acc: 0.91406
step: 260 / loss: 0.246551 / acc: 0.92969
step: 270 / loss: 0.242577 / acc: 0.92188
step: 280 / loss: 0.165623 / acc: 0.94531
step: 290 / loss: 0.232382 / acc: 0.94531
step: 300 / loss: 0.159169 / acc: 0.92969
step: 310 / loss: 0.229053 / acc: 0.92969
step: 320 / loss: 0.384319 / acc: 0.90625
step: 330 / loss: 0.151922 / acc: 0.93750
step: 340 / loss: 0.153512 / acc: 0.95312
step: 350 / loss: 0.113470 / acc: 0.96094
step: 360 / loss: 0.192841 / acc: 0.93750
step: 370 / loss: 0.169354 / acc: 0.96094
step: 380 / loss: 0.217942 / acc: 0.94531
step: 390 / loss: 0.151771 / acc: 0.95312
step: 400 / loss: 0.139619 / acc: 0.96094
step: 410 / loss: 0.236149 / acc: 0.92969
step: 420 / loss: 0.131790 / acc: 0.94531
step: 430 / loss: 0.172267 / acc: 0.96094
step: 440 / loss: 0.182242 / acc: 0.93750
step: 450 / loss: 0.131859 / acc: 0.94531
step: 460 / loss: 0.216793 / acc: 0.92969
step: 470 / loss: 0.082368 / acc: 0.96875
step: 480 / loss: 0.064672 / acc: 0.96094
step: 490 / loss: 0.119717 / acc: 0.96875
step: 500 / loss: 0.169831 / acc: 0.94531
step: 510 / loss: 0.106913 / acc: 0.98438
step: 520 / loss: 0.073209 / acc: 0.97656
step: 530 / loss: 0.131819 / acc: 0.96875
step: 540 / loss: 0.210754 / acc: 0.94531
step: 550 / loss: 0.141051 / acc: 0.93750
step: 560 / loss: 0.217726 / acc: 0.94531
step: 570 / loss: 0.121927 / acc: 0.96094
step: 580 / loss: 0.130969 / acc: 0.94531
step: 590 / loss: 0.125145 / acc: 0.95312
step: 600 / loss: 0.193178 / acc: 0.95312
step: 610 / loss: 0.114959 / acc: 0.95312
step: 620 / loss: 0.129038 / acc: 0.96094
step: 630 / loss: 0.151445 / acc: 0.95312
step: 640 / loss: 0.120206 / acc: 0.96094
step: 650 / loss: 0.107941 / acc: 0.96875
step: 660 / loss: 0.114320 / acc: 0.95312
step: 670 / loss: 0.094687 / acc: 0.94531
step: 680 / loss: 0.115308 / acc: 0.96875
step: 690 / loss: 0.125207 / acc: 0.96094
step: 700 / loss: 0.085296 / acc: 0.96875
step: 710 / loss: 0.119154 / acc: 0.94531
step: 720 / loss: 0.089058 / acc: 0.96875
step: 730 / loss: 0.054484 / acc: 0.97656
step: 740 / loss: 0.113646 / acc: 0.93750
step: 750 / loss: 0.051113 / acc: 0.99219
step: 760 / loss: 0.183365 / acc: 0.94531
step: 770 / loss: 0.112222 / acc: 0.95312
step: 780 / loss: 0.078913 / acc: 0.96094
Test Accuracy: 0.984375



