LoginSignup
55
59

More than 5 years have passed since last update.

機械学習ガチ素人の俺が無謀にもTensorFlow MNIST チュートリアルの内容を可視化して解説

Last updated at Posted at 2016-06-16

はじめに

この記事は、機械学習の素人であり、数式を見ると目が泳いでしまうような自分が、TensorFlowのチュートリアルが一体なにをやっているのかを、試行錯誤しつつ解明しようとするものです。

はっきり言って記事中で語られる知識のレベルはだいぶ低いので、ディープラーニングやTensorFlowについての有益な情報を得ることはできないと思います :fearful:

ただ自分と同じように、TensorFlowを触ってみたはいいけどなにがなにやら分からないと感じた方たちが、この記事を見て「なるほど、そういうことだったのか!」などと思ってもらえれば幸いです。

なお、素人なりの解釈に基いて書いている部分が多いので、誤って理解している箇所があるかもしれません。その場合は(優しく)指摘してもらえると助かります! :pray:

記事を書いた経緯

  1. TensorFlowを使ったらなにか面白いことができるんじゃないかと期待してやってみた
  2. 数字認識を行うMNISTチュートリアルを動かしてはみたものの、何をやっているのか、なぜこれでうまくいくのか、さっぱり分からない :scream:
  3. のっけからつまづいてしまう自分の知識のなさに失望
  4. 悔しかったので頑張って内容を理解しようとしてみる
  5. こういうことかな?と自分なりに解釈したので記事にしてみる

とりあえずプログラム本体

よく見かけるやつです。
階層を深くしたバージョンのほうが精度は高くなるのですが、自分が理解できないので説明のため簡単なほうを利用します。

mnist_tutorial.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import tempfile

from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import cm
import sys

mnist = read_data_sets("MNIST_data/", one_hot=True)

x = tf.placeholder(tf.float32, [None, 784], name="x")
y_ = tf.placeholder(tf.float32, [None, 10], name="y_")

W = tf.Variable(tf.zeros([784, 10]), name="weights")
b = tf.Variable(tf.zeros([10]), name="bias")
y = tf.nn.softmax(tf.matmul(x, W) + b)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

sess = tf.Session()
sess.run(tf.initialize_all_variables())
steps = int(sys.argv[1]) if len(sys.argv) >= 2 else 1000
batch_size = int(sys.argv[2]) if len(sys.argv) >= 3 else 100
for i in range(steps):
    batch_xs, batch_ys = mnist.train.next_batch(batch_size)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    if i % 100 == 0:
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        print("step=%d, accuracy=%g" % (i, sess.run(accuracy, feed_dict={ x: mnist.test.images, y_: mnist.test.labels })))

動かしてみる

$ python -i mnist_tutorial.py 1000 100
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
step=0, accuracy=0.4075
step=100, accuracy=0.8948
step=200, accuracy=0.9031
step=300, accuracy=0.9074
step=400, accuracy=0.9037
step=500, accuracy=0.9125
step=600, accuracy=0.914
step=700, accuracy=0.9151
step=800, accuracy=0.9194
step=900, accuracy=0.9189
>>> 

解説の前に

もしあなたが上記コードをざっと読んでみて、「ははあなるほど、こういうふうに計算しているのね」と理解できてしまうような方であれば、ここにあなたの求める情報はありません。すいませんその程度の記事なんです。

解説

そもそもどういうデータを扱おうとしているのか

画像を読み込んで数字を認識するわけなので、当然画像データが含まれます。また個々の画像について正解の数字が付与されている必要があるので、そのデータも含まれます。

で、結局以下のようなデータを扱うことになります。

>>> mnist.test.images
array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       ..., 
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.]], dtype=float32)
>>> mnist.test.labels
array([[ 0.,  0.,  0., ...,  1.,  0.,  0.],
       [ 0.,  0.,  1., ...,  0.,  0.,  0.],
       [ 0.,  1.,  0., ...,  0.,  0.,  0.],
       ..., 
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.]])

これだけ見てもさっぱりですよね。もう少し分かりやすくしてみます。

画像データを可視化

  • mnist.test.imagesは10000x784のマトリックスになっています
  • 10000というのはデータ数のことで、結局10000件の検証データが含まれているということになります
  • 784というのは、28x28からきており、これは画像の画素数を意味しています
    • つまり個々の画像は28x28ピクセルで構成されます
    • 784の要素は0以上1未満の少数値であり、各ピクセルの色の濃さを表現しています
    • 画像はグレースケールです

というわけで、特定の1つの画像に着目してみます。

>>> mnist.test.images[0]                                                                                                                                                                  
array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
-snip-
        0.47450984,  0.99607849,  0.99607849,  0.8588236 ,  0.15686275,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.47450984,  0.99607849,
        0.81176478,  0.07058824,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
-snip-
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ], dtype=float32)

やっぱりよく分かりませんね。
ではこれを28x28のグレースケール画像として表示してみます。

>>> plt.imshow(mnist.test.images[0].reshape(28, 28), cmap=cm.gray_r)
<matplotlib.image.AxesImage object at 0x12177f780>
>>> plt.show()

tf_image_digit.png

ようやく理解できる形式になりました。この画像データを784次元のベクトルとして扱っているわけですね。

ラベルデータの可視化

  • mnist.test.labelsは、10000x10のマトリックスになっています
  • 10000というのはデータ数のことで、結局10000件の検証データが含まれているということになります
  • 10というのは、画像が0~9の範囲のどれであるか、を意味しています

これも可視化したほうが分かりやすいです。

>>> plt.plot(mnist.test.labels[0])
[<matplotlib.lines.Line2D object at 0x12a952ef0>]
>>> plt.show()

tf_image_label.png

つまり、この画像データの正しい数字は7なので、10次元ベクトルの7番目が1.0、それ以外は0.0になっているデータとして表現している、ということです。
ちなみに、こういうデータのことをOne-Hotベクトルというみたいです。

Weights&Bias?

結局のところ、どんな計算をすることで正しい画像認識をしているのでしょうか?
おそらく、キモとなる計算は以下のあたりです。
(というか、それ以外に計算式がほとんど出てきません :sweat_smile:

y = tf.nn.softmax(tf.matmul(x, W) + b)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

tf.matmulというのは、要するに行列同士の積を計算しているのでしょう。つまりx * W + bをやっていると。

xというのは複数個の画像データを表現しているので、100x784次元のベクトルです。100というのは単に、訓練用データを100件ずつ読み込んでいるというだけです。

Wというのはweightsで、784x10次元のベクトルです。784ってことは、例の28x28画像と関係ありそうです。

これだけじゃなにがなにやらなので、とりあえず学習済のWを可視化してみましょう。同様に28x28の画像として表示してみます。

_W = sess.run(W)
for i in range(10):
    plt.subplot(2, 5, i+1)
    plt.title("W_%d" % i)
    plt.axis("off")
    plt.imshow(_W.transpose()[i].reshape(28, 28), cmap=None)
plt.show()


tf_image_heat_map.png

んん?

これを見た時、直感的にWというのは0~9の各画像における「らしさ」を表現しているんじゃないか?と感じました。
赤い部分がその数字における「ポジティブならしさ」、青い部分がその数字における「ネガティブならしさ」ということです。


ということは、認識対象の画像データに対してWを演算して、10個の数字それぞれについて、「その数字である確からしさ」を求めているのではないか? と考えました。

数式の整理と可視化

計算式部分を再掲します。
なお、xというのは画像データ、y_というのは正解ラベルのデータで、訓練時や検証時に与えられるものです。

y = tf.nn.softmax(tf.matmul(x, W) + b)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

上記は以下の式から構成されています。

  • x * W + b
  • y = softmax(x * W + b)
  • y_ * log(y)
  • reduce_sum(y_ * log(y))
  • reduce_mean(reduce_sum(y_ * log(y)))

なにをやっているのかさっぱりですが、例によってとりあえず可視化。
なお分かりやすくするため、検証用データは1件のみにしています。(test_images, test_labels)

test_images = mnist.test.images[0:1]
test_labels = mnist.test.labels[0:1]
feed_dict = { x: test_images, y_: test_labels }

_b = sess.run(b)
_W = sess.run(W)
_x_W_b = sess.run(tf.matmul(x, _W) + _b, feed_dict=feed_dict)
_y = sess.run(tf.nn.softmax(_x_W_b), feed_dict=feed_dict)
_log_y = sess.run(tf.log(_y), feed_dict=feed_dict)
_y__log_y = sess.run(y_ * _log_y, feed_dict=feed_dict)
_sum_y__log_y = sess.run(-tf.reduce_sum(_y__log_y, reduction_indices=[1]), feed_dict=feed_dict)

plt.subplot(1, 6, 1)
plt.title("x")
plt.imshow(test_images[0].reshape(28, 28), cmap=cm.gray_r)

plt.subplot(1, 6, 2)
plt.title("y'")
plt.plot(test_labels[0])

plt.subplot(1, 6, 3)
plt.title("x*W+b")
plt.plot(_x_W_b[0])

plt.subplot(1, 6, 4)
plt.title("y=softmax(x*W+b)")
plt.plot(_y[0])

plt.subplot(1, 6, 5)
plt.title("log(y)")
plt.plot(_log_y[0])

plt.subplot(1, 6, 6)
plt.title("y'*log(y)")
plt.plot(_y__log_y[0])

plt.show()

tf_image_results.png

xy'は言うまでもなく検証用の画像データとラベルデータで、数字の「4」であることが分かります。

x*W+bを見ると、正解である4が一番高くなっているデータになっています。次点で9が高くなっているのは、4と9は似ているからだろうと思います。
つまり、4の画像データにweightsとbiasを加味すると、4である可能性が高いことを示唆していることになります。どうやら先ほどの予想は正しかったようです。

y = softmax(x*W+b)ってのは正直よく分かってないのですが、ソフトマックス関数というものを適用して、1番数字の大きいものは1.0に近く、それ以外は0.0に近くする、ということをやっているようです。メリハリをつけているとでもゆーか :smile: 参考

log(y)をすることで、ソフトマックス関数を適用したものを元のx*W+bに近いものに戻しています。グラフの形自体は同じですが、y軸の値が変わっていますね。正解である4が0に近い負数で、それ以外は大きな負数になっています。

y'*log(y)は、正解である4の数値のみを残して、あとはゼロにしています。

う〜ん、やっていることは分かりましたが、なぜこれで正しく学習することができるのかは未だに分かりません・・・

途中経過の可視化

ソフトマックス関数を適用したり、log関数を適用したりしているのがなぜかよく分かりませんでした。
得られるグラフの形があんまり変わっていないからです。

で思ったのが、ちゃんと学習したデータだけではなく、学習途中のデータも同じように可視化してみたらどうなるだろうか? ということ。なのでやってみました。

for i in range(5):
    idx = 0 + i
    test_images = mnist.test.images[idx:idx+1]
    test_labels = mnist.test.labels[idx:idx+1]
    feed_dict = { x: test_images, y_: test_labels }

    plt.figure(1, figsize=(16, 16), dpi=100)

    plt.subplot(5, 5, i+1)
    plt.axis("off")
    plt.imshow(test_images[0].reshape(28, 28), cmap=cm.gray_r)

    plt.subplot(5, 5, i+6)
    plt.title("y")
    plt.plot(sess.run(y, feed_dict=feed_dict)[0])

    plt.subplot(5, 5, i+11)
    plt.title("log(y)")
    plt.axis([0, 9, -30.0, 0])
    plt.plot(sess.run(tf.log(y), feed_dict=feed_dict)[0])

    plt.subplot(5, 5, i+16)
    plt.title("y'")
    plt.plot(test_labels[0])

    plt.subplot(5, 5, i+21)
    plt.title("y'*log(y)")
    plt.axis([0, 9, -2.0, 0])
    plt.plot(sess.run(y_ * tf.log(y), feed_dict=feed_dict)[0])
plt.show()

まずは正しく学習したのデータに対しての可視化。

$ python mnist_tutorial.py 1000 100


tf_image_tests_success.png

いずれも正しく認識していますね。
y'*log(y)についても、特に違いは出ていません。

では、学習が十分でないデータに対しての可視化。

$ python mnist_tutorial.py 10 10


tf_image_tests_failure.png

んん?

学習が不十分だけあって4に対する認識が間違っていますが、その際y'*log(y)の値が、他に比べてとても大きな負数になっています。
どうやら、間違った学習をした場合には大きな負数が得られるような計算をしていたようです。

コスト関数はなにをやっているか

以下のコードで学習をしているわけですが、

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

cross_entropyというのはコスト関数が算出した値で、正解した場合には小さく、間違った場合には大きくなります。これを小さくしていく方向に、自動的にweightsとbiasを調整していくわけですね。

先ほどの例でいうと、y'*log(y)の値について、正解していたケースではいずれも小さかったですが、間違っていたケースでは1つだけ大きなものが含まれていました。
これらをreduce_meanで平均化しているので、結果として間違ったほうが大きな値が得られることになります。

まとめ

ようやく、これらの式を記述するとなぜ学習ができるのかが分かりました。
とは言え、以下のようなことは未だに把握できておらず、自分の理解度の低さを露呈しています :skull:

  • 訓練中、WeightsとBiasはどのようにして調整されていくのか?
  • なぜ層を深くすると学習精度が上がるのか?

たぶん魔法です

そもそも、動作自体はどうにかこうにか把握できたものの、これを自分で考えてやってみろと言われてできる気がしません。
自分のやりたいことを数式に置き換えて考えるという土台がないので、まずそこからなんだろうなあと思いました。

可視化のメリット

ただ、データ内容を可視化することで分かってくることがあるということを改めて認識しました。

おそらく上記で解説しているような内容はチュートリアル内のテキストにも書いてあるんでしょうし、自分も読みはしましたが、それだけではよく分からなかったのが正直なところでした。

そこでWeightsを可視化した時に初めて、「おや? これならなんとなく分かるかもしれないぞ・・・」と思ったのでした。

というわけで、
可視化は重要 :exclamation: (俺レベルだと特に)

55
59
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
55
59