MNISTをRNN(Recurrent Neural Networks)でやってみました。
RNN
入力値と出力値に可変長のシーケンシャルデータを扱うことができるネットワーク構造をしている。RNNには状態があり、各時点tにおいて入力値と状態に基いて次の状態に遷移させることができる。RNNは内部に状態を持ち、入力から次の状態へと遷移させることで状態を保持する。
LSTM
LSTM(Long short-term memory)は、RNN(Recurrent Neural Network)の拡張として1995年に登場した、時系列データ(sequential data)に対するモデル、あるいは構造(architecture)の1種。
この記事によると、文章生成なら「今までの単語列を入力として、もっともらしい次の単語を予測する」ことを担う。正しい文章を繰り返しLSTMに覚えさせる(重みベクトルを更新する)ことで、このLSTMは"this"の後に"is"が来るようなルールを「事実上」学習する。的なことができるらしい。なるほど!すげぇ〜!
最も大きな特長は、従来のRNNでは学習できなかった長期依存(long-term dependencies)を学習可能であるところ
とりあえず、やってみよう
このサイトを見てみると以下の画像のように上部から下部に向けて遷移させながらLSTMで学習させるっぽい。
プログラムの流れ
juoyter notebookを使用しました。
必要なライブラリのインポートとMNISTデータを読み込む
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
次に入力と正解ラベルのプレースホルダーの定義
x = tf.placeholder("float", [None, 28, 28])
y = tf.placeholder("float", [None, 10])
RNNのモデルを定義する。
128個の隠れ層のユニットを持つLSTMモデル
各ステップ毎に分割されたテンソルに変換する。tf.unstack
で [バッチサイズ × 28] の28個のテンソルを持つPythonのlistに変換。
def RNN(x):
x = tf.unstack(x, 28, 1)
# LSTMの設定
lstm_cell = rnn.BasicLSTMCell(128, forget_bias=1.0)
# モデルの定義。各タイムステップの出力値と状態が返される
outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
# 重みとバイアスの設定
weight = tf.Variable(tf.random_normal([128, 10]))
bias = tf.Variable(tf.random_normal([10]))
return tf.matmul(outputs[-1], weight) + bias
コスト関数を定義。今回は、学習させるためにクロスエントロピー誤差関数とAdam Optimizerを使用。
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)
# 評価用
correct_pred = tf.equal(tf.argmax(preds, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
作成したモデルを用いて学習させる
batch_size = 128
n_training_iters = 100000
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step = 1
# Keep training until reach max iterations
while step * batch_size < n_training_iters:
batch_x, batch_y = mnist.train.next_batch(batch_size)
# next_batchで返されるbatch_xは[batch_size, 784]のテンソルなので、batch_size×28×28に変換します。
batch_x = batch_x.reshape((batch_size, 28, 28))
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
if step % 10 == 0:
acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
print('step: {} / loss: {:.6f} / acc: {:.5f}'.format(step, loss, acc))
step += 1
# テスト
test_len = 128
test_data = mnist.test.images[:test_len].reshape((-1, 28, 28))
test_label = mnist.test.labels[:test_len]
test_acc = sess.run(accuracy, feed_dict={x: test_data, y: test_label})
print("Test Accuracy: {}".format(test_acc))
step: 10 / loss: 1.751291 / acc: 0.42969
step: 20 / loss: 1.554639 / acc: 0.46875
step: 30 / loss: 1.365595 / acc: 0.57031
step: 40 / loss: 1.176470 / acc: 0.60156
step: 50 / loss: 0.787636 / acc: 0.75781
step: 60 / loss: 0.776735 / acc: 0.75781
step: 70 / loss: 0.586180 / acc: 0.79688
step: 80 / loss: 0.692503 / acc: 0.80469
step: 90 / loss: 0.550008 / acc: 0.82812
step: 100 / loss: 0.553710 / acc: 0.86719
step: 110 / loss: 0.423268 / acc: 0.86719
step: 120 / loss: 0.462931 / acc: 0.82812
step: 130 / loss: 0.365392 / acc: 0.85938
step: 140 / loss: 0.505170 / acc: 0.85938
step: 150 / loss: 0.273539 / acc: 0.91406
step: 160 / loss: 0.322731 / acc: 0.87500
step: 170 / loss: 0.531190 / acc: 0.85156
step: 180 / loss: 0.318869 / acc: 0.90625
step: 190 / loss: 0.351407 / acc: 0.86719
step: 200 / loss: 0.232232 / acc: 0.92188
step: 210 / loss: 0.245849 / acc: 0.92969
step: 220 / loss: 0.312085 / acc: 0.92188
step: 230 / loss: 0.276383 / acc: 0.89844
step: 240 / loss: 0.196890 / acc: 0.94531
step: 250 / loss: 0.221909 / acc: 0.91406
step: 260 / loss: 0.246551 / acc: 0.92969
step: 270 / loss: 0.242577 / acc: 0.92188
step: 280 / loss: 0.165623 / acc: 0.94531
step: 290 / loss: 0.232382 / acc: 0.94531
step: 300 / loss: 0.159169 / acc: 0.92969
step: 310 / loss: 0.229053 / acc: 0.92969
step: 320 / loss: 0.384319 / acc: 0.90625
step: 330 / loss: 0.151922 / acc: 0.93750
step: 340 / loss: 0.153512 / acc: 0.95312
step: 350 / loss: 0.113470 / acc: 0.96094
step: 360 / loss: 0.192841 / acc: 0.93750
step: 370 / loss: 0.169354 / acc: 0.96094
step: 380 / loss: 0.217942 / acc: 0.94531
step: 390 / loss: 0.151771 / acc: 0.95312
step: 400 / loss: 0.139619 / acc: 0.96094
step: 410 / loss: 0.236149 / acc: 0.92969
step: 420 / loss: 0.131790 / acc: 0.94531
step: 430 / loss: 0.172267 / acc: 0.96094
step: 440 / loss: 0.182242 / acc: 0.93750
step: 450 / loss: 0.131859 / acc: 0.94531
step: 460 / loss: 0.216793 / acc: 0.92969
step: 470 / loss: 0.082368 / acc: 0.96875
step: 480 / loss: 0.064672 / acc: 0.96094
step: 490 / loss: 0.119717 / acc: 0.96875
step: 500 / loss: 0.169831 / acc: 0.94531
step: 510 / loss: 0.106913 / acc: 0.98438
step: 520 / loss: 0.073209 / acc: 0.97656
step: 530 / loss: 0.131819 / acc: 0.96875
step: 540 / loss: 0.210754 / acc: 0.94531
step: 550 / loss: 0.141051 / acc: 0.93750
step: 560 / loss: 0.217726 / acc: 0.94531
step: 570 / loss: 0.121927 / acc: 0.96094
step: 580 / loss: 0.130969 / acc: 0.94531
step: 590 / loss: 0.125145 / acc: 0.95312
step: 600 / loss: 0.193178 / acc: 0.95312
step: 610 / loss: 0.114959 / acc: 0.95312
step: 620 / loss: 0.129038 / acc: 0.96094
step: 630 / loss: 0.151445 / acc: 0.95312
step: 640 / loss: 0.120206 / acc: 0.96094
step: 650 / loss: 0.107941 / acc: 0.96875
step: 660 / loss: 0.114320 / acc: 0.95312
step: 670 / loss: 0.094687 / acc: 0.94531
step: 680 / loss: 0.115308 / acc: 0.96875
step: 690 / loss: 0.125207 / acc: 0.96094
step: 700 / loss: 0.085296 / acc: 0.96875
step: 710 / loss: 0.119154 / acc: 0.94531
step: 720 / loss: 0.089058 / acc: 0.96875
step: 730 / loss: 0.054484 / acc: 0.97656
step: 740 / loss: 0.113646 / acc: 0.93750
step: 750 / loss: 0.051113 / acc: 0.99219
step: 760 / loss: 0.183365 / acc: 0.94531
step: 770 / loss: 0.112222 / acc: 0.95312
step: 780 / loss: 0.078913 / acc: 0.96094
Test Accuracy: 0.984375
98%の精度で分類!!