Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
143
Help us understand the problem. What is going on with this article?
@uramonk

TensorFlow MNIST For ML Beginners チュートリアルの実施

More than 3 years have passed since last update.

はじめに

前回ではTensorFlowのチュートリアルであるMNIST For ML Beginnersの翻訳をしたので、今回は実際にTensorFlowを用いてチュートリアルの内容を実装してみました。
と言いましても、チュートリアルの中ですでにコードがぱらぱらとではありますがすべて出てきていますので、そのコードが何を意味しているのかを理解しながらの実装となります。

環境セットアップ

私自身、Python使うのが初めてですので、同じような初心者の方でも実行できるようにセットアップについても記述します。
基本的にはTensorFlowのサイトにある内容と同じです。
対象はmacとします。

Pythonのインストール

まずはPython自体のインストールです。

brew install python

そしてPythonのパッケージ管理システムであるpipもインストールします。

sudo easy_install pip

TensorFlowではvirtualenvというPythonの仮想環境上での実行が推奨されているようですので、pipを使ってインストールします。

sudo pip install --upgrade virtualenv

Python関連のインストールは以上です。

TensorFlowのインストール

TensorFlowをインストールする前に、virtualenvの環境設定を行います。
--system-site-packageの意味がよくわかりませんが、とりあえず公式サイトに従って実行します。

virtualenv --system-site-packages ./tensorflow

環境設定が完了したらvirtualenvを実行します。

cd tensorflow
source bin/activate

TensorFlowのインストールをします。
これでセットアップは完了です。

sudo pip install --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.8.0-py2-none-any.whl

MNIST For ML Beginnersの実行

さて、ここからが本番の訓練の実装および実行となります。

MNIST For ML Beginnersの実装コード

コード全文は下記のとおりです。
それぞれの説明についてはコード内のコメントにほぼすべて記述してしまっていますので、そちらを見てみてください。

Gistにもコードをあげておきました。
TensorFlow MNIST For ML Beginners チュートリアルのコード

mnist_for_ml_beginners.py
# -*- coding: utf-8 -*-

# TensowFlowのインポート
import tensorflow as tf
# MNISTを読み込むためinput_data.pyを同じディレクトリに置きインポートする
# input_data.pyはチュートリアル内にリンクがあるのでそこから取得する
# https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/tutorials/mnist/input_data.py
import input_data

import time

# 開始時刻
start_time = time.time()
print "開始時刻: " + str(start_time)

# MNISTデータの読み込み
# 60000点の訓練データ(mnist.train)と10000点のテストデータ(mnist.test)がある
# 訓練データとテストデータにはそれぞれ0-9の画像とそれに対応するラベル(0-9)がある
# 画像は28x28px(=784)のサイズ
# mnist.train.imagesは[60000, 784]の配列であり、mnist.train.lablesは[60000, 10]の配列
# lablesの配列は、対応するimagesの画像が3の数字であるならば、[0,0,0,1,0,0,0,0,0,0]となっている
# mnist.test.imagesは[10000, 784]の配列であり、mnist.test.lablesは[10000, 10]の配列
print "--- MNISTデータの読み込み開始 ---"
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
print "--- MNISTデータの読み込み完了 ---"

# 訓練画像を入れる変数
# 訓練画像は28x28pxであり、これらを1行784列のベクトルに並び替え格納する
# Noneとなっているのは訓練画像がいくつでも入れられるようにするため
x = tf.placeholder(tf.float32, [None, 784])

# 重み
# 訓練画像のpx数の行、ラベル(0-9の数字の個数)数の列の行列
# 初期値として0を入れておく
W = tf.Variable(tf.zeros([784, 10]))

# バイアス
# ラベル数の列の行列
# 初期値として0を入れておく
b = tf.Variable(tf.zeros([10]))

# ソフトマックス回帰を実行
# yは入力x(画像)に対しそれがある数字である確率の分布
# matmul関数で行列xとWの掛け算を行った後、bを加算する。
# yは[1, 10]の行列
y = tf.nn.softmax(tf.matmul(x, W) + b)

# 交差エントロピー
# y_は正解データのラベル
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))

# 勾配硬化法を用い交差エントロピーが最小となるようyを最適化する
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# 用意した変数Veriableの初期化を実行する
init = tf.initialize_all_variables()

# Sessionを開始する
# runすることで初めて実行開始される(run(init)しないとinitが実行されない)

sess = tf.Session()
sess.run(init)

# 1000回の訓練(train_step)を実行する
# next_batch(100)で100つのランダムな訓練セット(画像と対応するラベル)を選択する
# 訓練データは60000点あるので全て使いたいところだが費用つまり時間がかかるのでランダムな100つを使う
# 100つでも同じような結果を得ることができる
# feed_dictでplaceholderに値を入力することができる
print "--- 訓練開始 ---"
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_:batch_ys})
print "--- 訓練終了 ---"

# 正しいかの予測
# 計算された画像がどの数字であるかの予測yと正解ラベルy_を比較する
# 同じ値であればTrueが返される
# argmaxは配列の中で一番値の大きい箇所のindexが返される
# 一番値が大きいindexということは、それがその数字である確率が一番大きいということ
# Trueが返ってくるということは訓練した結果と回答が同じということ
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

# 精度の計算
# correct_predictionはbooleanなのでfloatにキャストし、平均値を計算する
# Trueならば1、Falseならば0に変換される
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# 精度の実行と表示
# テストデータの画像とラベルで精度を確認する
# ソフトマックス回帰によってWとbの値が計算されているので、xを入力することでyが計算できる
print "精度"
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

# 終了時刻
end_time = time.time()
print "終了時刻: " + str(end_time)
print "かかった時間: " + str(end_time - start_time)

tf.argmaxの2つ目の引数は次元数を指定するようですが、yやy_の配列は1行10列でその中から1つを取ってくるので、1(次元)を指定しているのでしょうか。

実行結果

実行自体は非常に早く3秒弱で訓練から確認までできてしまいました。
最初に実行した時はMNISTデータをダウンロードしてくるのに時間がかかり5分ほどかかりました。

実行方法および出力結果は下記です。

$ python mnist_for_ml_beginners.py
開始時刻: 1449994007.63
--- MNISTデータの読み込み開始 ---
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
--- MNISTデータの読み込み完了 ---
can't determine number of CPU cores: assuming 4
I tensorflow/core/common_runtime/local_device.cc:25] Local device intra op parallelism threads: 4
can't determine number of CPU cores: assuming 4
I tensorflow/core/common_runtime/local_session.cc:45] Local session inter op parallelism threads: 4
--- 訓練開始 ---
--- 訓練終了 ---
精度
0.9212
終了時刻: 1449994010.09
かかった時間: 2.45791196823

精度として92%強となっていることがわかります。
入力した画像で結果は何が出力されたなど出していないので、あくまで数値だけの結果ではあります。
この精度は実行するごとに異なります。訓練時に使っている訓練データが異なるからですね。

終了後はvirtualenvをdeactivateして終了します。

deactivate

おわりに

前回チュートリアルを翻訳するときに内容もある程度確認していましたので、今回実装するときもこのコードがこうなっているんだ、など理解しながらすすめることができました。
まだまだソフトマックス回帰や交差エントロピーなどのところは理解していないので、この辺りも勉強して言ったほうが良さそうですね。
ソフトマックス回帰は名前の通り回帰分析でもしているのだろうなとは思います。

今回はチュートリアルをそのまま実行してみたので、画像も用意されているものをそのまま使ったのみですが、別の画像などでも試してみたいですね。
しかしその場合は画像サイズをあわせたり、正規化したり、input_data.pyの内容を見て理解したりなどが必要だと思いますので、もうしばらく掛かりそうなので次はエキスパート向けのチュートリアルをやってみたいです。
こちらが本番?であるDeep Learningのようです。

セットアップに関する参考サイト

TensorFlow Download and Setup
TensorFlowで Hello Worldを動かしてみた&その解説

143
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away

Comments

No comments
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
143
Help us understand the problem. What is going on with this article?