3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

TensorFlowのチュートリアル(MNIST)を読む

Last updated at Posted at 2016-10-04

Googleが公開している機械学習のライブラリ「TensorFlow」を使って、何かできないかと思い勉強を開始。
まずは一番オーソドックスなチュートリアル「MNIST」から見ていきましょう。
まだ理解しきれていないところがあるので、ツッコミ歓迎です

対象

TensorFlow r0.10
機械学習種別 教師あり機械学習
ディープラーニングのアルゴリズム FFNN(Feedforward Neural Network)
ソースファイル fully_connected_feed.py
mnist.py
input_data.py
mnist.py(input_data.pyからimportされているデータセット展開用)

MNISTとは

MNISTとは28x28ピクセルの手書き数字のデータセットを指している。
それを読み込み機械学習した後に別の画像に記載されている数字が何であるかを判別させるのが、このチュートリアルの目的となる。

コードの説明

main

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
def main(_):
  run_training()


if __name__ == '__main__':
tf.app.run()

↑いわゆるmain関数です。ここからスタート。

run_training

これが実際の処理関数。
処理の流れとしては以下のような感じになる。

  1. MNISTのダウンロードとメモリ上への展開
  2. 機械学習を実施するためのop、tensorの作成
  3. opの実行
  4. 学習時に使用したデータとは別の画像(ただし、データ自体はMNISTで提供されている)を機械学習によって作成された学習器で解析

TensorFlowはGPUを使用することも設計思想に入っているため、処理はop(operationの略)、それの算出結果をop実行後に保持するための変数をtensorとして作成し、Session.runメソッドにて処理をGPU(GPUが存在しない場合はそのままCPU)へ処理を依頼しないと動作しません(この辺は同じくGPUを使用するOpenGLと同じような設計思想になっているようです)。

run_training - MNISTのダウンロードとメモリ上への展開

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
def run_training():
  """Train MNIST for a number of steps."""
  # データセットMNISTをDLしてメモリ上に展開
  # data_setsの中身は大きく分けて3種類に分かれる
  #  data_sets.train: 機械学習用の画像データ配列
  #  data_sets.valifation: 
  #  data_sets.test: data_sets.trainを使用して学習した結果(学習器)で実際にどれぐらい判別可能かを試験するための画像データ
  data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)

このinput_data.read_data_setsにてMNISTのデータファイルのダウンロードとdata_setsの構築を実施。
=> data_setsの中に学習用、検証用のデータを格納したDataSetインスタンスが設定されてくる。
また、それぞれのDataSetインスタンスに、28x28の手書き数字画像の3次元配列とそれらがどの数字なのかを示すラベルの1次元配列が設定されている。

run_training - 機械学習を実施するためのopの作成

この後から、機械学習を実施するためのopの構築が行われる。この辺からopやplaceholderといった実際に処理を実施(Session.run)するまで動作しないものを作成していく。なのでデバッガで値を見ても特に得るものがない感じです。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Generate placeholders for the images and labels.
    images_placeholder, labels_placeholder = placeholder_inputs(
        FLAGS.batch_size)

一回に処理するデータ数"batch_size(100)"を格納できるplaceholderを手書き数字画像とラベル分用意。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(images_placeholder,
                             FLAGS.hidden1,
FLAGS.hidden2)

mnist.inferenceにより、batch_size数分の画像データがそれぞれ、どの数字である確率が高いかをtensorで返却される。
このtensor"logits"は名称のとおり、logit(p)の複数の値が配列として返却されている。
Wikipediaに説明されていることから"logit(p)"のpが確率を示しているとのこと。
実行されると以下のようなデータが格納される。

出力結果例:
logits = {
    // 1番目の画像は「8」である確率が高い
    {-3.57313514 -2.2600534  2.13897038  0.33677426 -1.37029171 0.64949733  1.49645376 -6.81540442  4.00773335 -3.18380189},

    // 2番目の画像は「1」である確率が高い
    {-2.4933331   5.71597528  1.00439227  0.79327136 -3.49611211 -0.2769528 -1.74219    -0.13075426  0.51126349 -1.12805915},
                       :
                       : batch_size分、存在
                       :
}

このままだと対数なので、いまいち分かり難いので対数を解除(*)すると、以下のような値になる。

  • 解除は「p = 10^X / (1+10^X)」(Xはlogitの値)を実施
出力結果例を対数解除した値:
p = {
    // 1番目の画像は「8」である確率が100%、それ以外に「6」である確率が96.9%
    {0.000, 0.005, 0.993, 0.685, 0.041, 0.817, 0.969, 0.000, 1.000, 0.001},

    // 2番目の画像は「1」である確率が100%、それ以外に「2」である確率が91%
    {0.003, 1.000, 0.910, 0.861, 0.000, 0.346, 0.018, 0.425, 0.764, 0.069},
                       :
                       : batch_size分、存在
                       :
}

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
    # Add to the Graph the Ops for loss calculation.
    loss = mnist.loss(logits, labels_placeholder)

mnist.lossにて先のmnist.inferenceにて算出されるlogistsと教師データを格納するlabels_placeholderの差を算出する。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
    # Add to the Graph the Ops that calculate and apply gradients.
    train_op = mnist.training(loss, FLAGS.learning_rate)

mnist.trainingにて先に算出したlossと学習率から学習用のopを作成。これを実行することが学習が実施される。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
    # Add the Op to compare the logits to the labels during evaluation.
    eval_correct = mnist.evaluation(logits, labels_placeholder)

logitsに格納されている各画像データがどの数字であるかを示す確率とそれぞれの教師データを比較して、実際の正解した画像イメージ数が格納されるtensorが返却される。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

TensorBoard向けのSummaryをマージするopを作成。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
    # Add the variable initializer Op.
    init = tf.initialize_all_variables()

これまでに用意したVariableを全て初期化するためのopの作成。
opは作成したときには実行されず、Session.runを実施したときに、実行されるのでここで作成しても、これまでにVariableに設定された値が初期化されることはない。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver()

学習結果を保存するopを作成。

run_training - opの実行

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
    # Create a session for running Ops on the Graph.
    sess = tf.Session()

ここでようやくopを実行するためのSessionを作成。ここからこれまでに作成したopを使用することでGPUにopの内容を渡してGPUで処理させていく(GPU非搭載マシンの場合はそのままCPUで実行される)。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
    # Instantiate a SummaryWriter to output summaries and the Graph.
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

    # And then after everything is built:

    # Run the Op to initialize the variables.
    sess.run(init)

Session開始後、最初に"init"opを実行することで、これから使用されるVariableを全て一旦初期化。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
    # Start the training loop.
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()

      # Fill a feed dictionary with the actual set of images and labels
      # for this particular training step.
      feed_dict = fill_feed_dict(data_sets.train,
                                 images_placeholder,
                                 labels_placeholder)

fill_feed_dictにて、最初にメモリに展開したdata_setsの内、学習用の手書き数字画像データと教師データ用のラベルのbatch_size(100)分のデータをそれぞれimages_placeholder、labels_placeholderをキーとしたdictionary化しfeed_dictを設定。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
      # Run one step of the model.  The return values are the activations
      # from the `train_op` (which is discarded) and the `loss` Op.  To
      # inspect the values of your Ops or variables, you may include them
      # in the list passed to sess.run() and the value tensors will be
      # returned in the tuple from the call.
      _, loss_value = sess.run([train_op, loss],
                               feed_dict=feed_dict)

train_op(mnist.inference、mnist.loss、mnist.trainingを実行するop)とmnist.lossで算出されるloss値が格納されるtensorを実行し、その結果を返却(train_opは戻り値がないため、"_"にて対応)。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
      duration = time.time() - start_time

      # Write the summaries and print an overview fairly often.
      if step % 100 == 0:
        # Print status to stdout.
        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
        # Update the events file.
        summary_str = sess.run(summary_op, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)
        summary_writer.flush()

学習しながら、その結果をTensorBoardで表示できるようにループ100回中1回のみ先に作成したsummary_opを実行して文字列を作成し、それをTensorBoardに書き込む。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
      # Save a checkpoint and evaluate the model periodically.
      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:

以下に続く処理をループ1000回中1回、または最後のループ時に実施。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
        checkpoint_file = os.path.join(FLAGS.train_dir, 'checkpoint')
        saver.save(sess, checkpoint_file, global_step=step)

チェックポイントとして、これまでに学習した結果をファイルに保存。

run_training - 学習時に使用したデータとは別の画像(ただし、データ自体はMNISTで提供されている)を機械学習によって作成された学習器で解析。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
        # Evaluate against the training set.
        print('Training Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.train)

do_evalにて学習用の手書き数字画像データ全てをこの時点での学習結果で認識させて、どれぐらいの正解率になるかをチェック。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
        # Evaluate against the validation set.
        print('Validation Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.validation)

do_evalにて検証用の手書き数字画像データ全てをこの時点での学習結果で認識させて、どれぐらいの正解率になるかをチェック。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
        # Evaluate against the test set.
        print('Test Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.test)

do_evalにて試験用の手書き数字画像データ全てをこの時点での学習結果で認識させて、どれぐらいの正解率になるかをチェック。

fill_feed_dict

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
def fill_feed_dict(data_set, images_pl, labels_pl):
  """Fills the feed_dict for training the given step.
  A feed_dict takes the form of:
  feed_dict = {
      <placeholder>: <tensor of values to be passed for placeholder>,
      ....
  }
  Args:
    data_set: The set of images and labels, from input_data.read_data_sets()
    images_pl: The images placeholder, from placeholder_inputs().
    labels_pl: The labels placeholder, from placeholder_inputs().
  Returns:
    feed_dict: The feed dictionary mapping from placeholders to values.
  """
  # Create the feed_dict for the placeholders filled with the next
  # `batch size` examples.
  images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
                                                 FLAGS.fake_data)

DataSet.next_batchでDataSet._index_in_epochをbatch_size分インクリメントしているので、DataSet.next_batchをコールするたびに開始位置が更新され、次のデータが返却される。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
  feed_dict = {
      images_pl: images_feed,
      labels_pl: labels_feed,
  }
  return feed_dict

返却用のデータをdictionary化して返却。

do_eval

機械学習によって成長(?)した学習器でどれぐらい手書き数字画像を正しく認識するかを試験する関数(eval = evaluation: 評価)。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
  """Runs one evaluation against the full epoch of data.
  Args:
    sess: The session in which the model has been trained.
    eval_correct: The Tensor that returns the number of correct predictions.
    images_placeholder: The images placeholder.
    labels_placeholder: The labels placeholder.
    data_set: The set of images and labels to evaluate, from
      input_data.read_data_sets().
  """
  # And run one epoch of eval.
  true_count = 0  # Counts the number of correct predictions.
  steps_per_epoch = data_set.num_examples // FLAGS.batch_size
  num_examples = steps_per_epoch * FLAGS.batch_size

steps_per_epochにdata_set.num_examples(学習用: 55,000、検証用: 5,000、試験用: 10,000)をbatch_size(100)で切り捨て除算した結果を格納(C言語、Javaエンジニア的には"//"はコメント開始なので間違えそうですが)
=>
steps_per_epochには、この後に実行するループの上限回数が設定される。

またnum_examplesには実際に使用したデータ数が設定される。
steps_per_epochを算出するとき、batch_sizeにて切り捨てしているのでnum_examples <= data_set.examplesになる

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
  for step in xrange(steps_per_epoch):
    feed_dict = fill_feed_dict(data_set,
                               images_placeholder,
                               labels_placeholder)

検証に使用するためのデータをdictionary化。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
    true_count += sess.run(eval_correct, feed_dict=feed_dict)

"eval_correct"opを実行することで、実質渡されたfeed_dictを用いてmnist_inferenceが実行され、その結果が格納されているlogitsに対してmnist.evaluationが実行される。
そのmnist.evaluationでlogitsと教師データの比較を行い、正解したデータ数が返却されるので、それを変数"true_count"に加算していく。

tensorflow/examples/tutorials/mnist/fully_connected_feed.py
  precision = true_count / num_examples
  print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
        (num_examples, true_count, precision))

true_countをnum_examples(実際に試験で使用したデータ数)で除算することで正解率が算出されるので、それを表示。

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?