TensorFlowのチュートリアル(MNIST Data Download)
https://www.tensorflow.org/versions/master/tutorials/mnist/download/index.html#mnist-data-download
の翻訳です。
翻訳の誤りなどあればご指摘お待ちしております。
コード: tensorflow/examples/tutorials/mnist/
このチュートリアルの目的は、(古典的)MNISTデータセットを使用した手書き数字の分類に必要なデータセット・ファイルをダウンロードする方法を示すことです。
チュートリアル・ファイル
このチュートリアルでは、以下のファイルを参照します:
ファイル | 目的 |
---|---|
input_data.py | 訓練と評価のためにMNISTデータセットをダウンロードするコード。 |
データの準備
MNISTは、機械学習における古典的な問題です。この問題は、手書き数字のグレースケールの28x28ピクセルの画像を見て、画像が表す数字が0から9までのうちどの数字かを決定する、というものです。
詳細については、 Yann LeCun のMNISTページ または Chris Olah のMNISTの視覚化 を参照してください。
ダウンロード
Yann LeCun のMNISTのページ は、訓練データとテスト・データをダウンロードのためにホストしています。
ファイル | 目的 |
---|---|
train-images-idx3-ubyte.gz | 訓練セット画像 - 55000訓練画像、5000検証画像 |
train-labels-idx1-ubyte.gz | 画像に対応する訓練セット・ラベル |
t10k-images-idx3-ubyte.gz | テスト・セット画像 - 10000画像 |
t10k-labels-idx1-ubyte.gz | 画像に対応するテスト・セット・ラベル |
input_data.py ファイルの、 maybe_download() 関数は、これらのファイルが訓練のためのローカル・データ・フォルダにダウンロードされていることを保証します。
フォルダ名は fully_connected_feed.py ファイルの先頭にフラグ変数で指定されており、必要に応じて変更することができます。
解凍とリシェイプ
ファイル自体は、標準的な画像形式ではなく、 input_data.py で extract_images() と extract_labels() 関数により手動で(ウェブサイトの指示に従い)解凍されます。
画像データは、 [画像インデックス, ピクセル・インデックス] の2Dテンソルに抽出されます。ここで、各要素はインデックスが示す画像内のピクセルの輝度の値であり、 [0, 255] から [-0.5, 0.5] に再スケールされています。「画像インデックス」は、データセット内の画像に対応し、0からデータセットのサイズまでカウントアップされます。そして「ピクセル・インデックス」は、画像内の特定のピクセルに対応し、0から画像のピクセル数までの範囲を取ります。
train-* ファイル内の 60000 サンプルは、その後、訓練のための 55000 サンプルと検証のための 5000 サンプルに分割されます。データセット内の全ての 28x28 ピクセルのグレースケール画像の画像サイズは 784 ですので、訓練セット画像の出力テンソルの形状は [55000, 784] です。
ラベル・データは、各サンプルのクラス識別子を値とする、 [画像インデックス] の1Dテンソルに抽出されます。従って、訓練セット・ラベルの場合、テンソルの形状は [55000] です。
データセット・オブジェクト
基礎をなすコードは、以下のデータセットについて、画像とラベルをダウンロードし、解凍し、形状を変更します:
データセット | 目的 |
---|---|
data_sets.train | 主な訓練のための 55000 画像とラベル。 |
data_sets.validation | 訓練精度の反復的検証のための 5000 画像とラベル。 |
data_sets.test | 訓練済み精度の最終テストのための 10000 画像とラベル。 |
read_data_sets() 関数は、これら3つのデータ集合それぞれの DataSet インスタンスを保持する辞書を返します。 DataSet.next_batch() メソッドは、 batch_size サイズの画像およびラベルのリストからなるタプルをフェッチするために使用します。これらのリストは実行中のTensorFlowセッションにフィードされます。
images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size)