Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

This article is a Private article. Only a writer and users who know the URL can access it.
Please change open range to public in publish setting if you want to share this article with other users.

More than 5 years have passed since last update.

TensorFlowのチュートリアル(MNIST)を読む - read_data_sets

Last updated at Posted at 2016-09-19

read_data_sets

MNISTのデータをダウンロード+メモリへの展開を実施する関数。

tensorflow/contrib/learn/python/learn/datasets/mnist.py
def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True):
  if fake_data:

    def fake():
      return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)

    train = fake()
    validation = fake()
    test = fake()
    return base.Datasets(train=train, validation=validation, test=test)

ここまでは実際のMNISTのデータを用いずに、擬似データを使用する場合の処理。なので、割愛。

tensorflow/contrib/learn/python/learn/datasets/mnist.py
  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
  VALIDATION_SIZE = 5000

  local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                   SOURCE_URL + TRAIN_IMAGES)
  train_images = extract_images(local_file)

  local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                   SOURCE_URL + TRAIN_LABELS)
  train_labels = extract_labels(local_file, one_hot=one_hot)

  local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                   SOURCE_URL + TEST_IMAGES)
  test_images = extract_images(local_file)

  local_file = base.maybe_download(TEST_LABELS, train_dir,
                                   SOURCE_URL + TEST_LABELS)
  test_labels = extract_labels(local_file, one_hot=one_hot)

base.maybe_downloadにてデータをダウンロードし、引数"train_dir"で指定されたディレクトリに、それぞれ受信したいファイルの名称でファイルを保存。
受信したファイルが手書き文字画像の場合は、extract_imagesで展開。
また受信したファイルがラベルデータの場合は、extract_labelsで展開。

tensorflow/contrib/learn/python/learn/datasets/mnist.py
  validation_images = train_images[:VALIDATION_SIZE]
  validation_labels = train_labels[:VALIDATION_SIZE]
  train_images = train_images[VALIDATION_SIZE:]
  train_labels = train_labels[VALIDATION_SIZE:]

読み込んだデータを学習用(60,000 - 5,000 = 5,5000)と、学習結果の検証用(5,000)に分ける。

tensorflow/contrib/learn/python/learn/datasets/mnist.py
  train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape)
  validation = DataSet(validation_images,
                       validation_labels,
                       dtype=dtype,
                       reshape=reshape)
  test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape)

  return base.Datasets(train=train, validation=validation, test=test)

学習用と検証用、それぞれの手書き数字画像の配列とそこに何が記載されているかを示すラベルの配列をそれぞれのDataSetインスタンスとして纏める。
DataSetインスタンスをさらに纏めたDatasetsインスタンスを構築し、返却。

extract_images

tensorflow/contrib/learn/python/learn/datasets/mnist.py
def extract_images(filename):
  """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
  print('Extracting', filename)
  with gfile.Open(filename, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream:

受信したgzipで圧縮されているファイルを解凍。

tensorflow/contrib/learn/python/learn/datasets/mnist.py
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError('Invalid magic number %d in MNIST image file: %s' %
                       (magic, filename))

データが不正ではないか確認するために、仕込まれている先頭の32ビットがマジックナンバー"2051(0x00000803)"と一致するか確認。

tensorflow/contrib/learn/python/learn/datasets/mnist.py
    num_images = _read32(bytestream)
    rows = _read32(bytestream)
    cols = _read32(bytestream)

ファイルに格納されているイメージ数と、それぞれのイメージの縦横のピクセル数を取得。
参考までに、自分が取得したデータはnum_images = 60000、rows = cols = 28となっていました。これにより28x28の手書き数字画像が60,000個格納されていることを示します。

tensorflow/contrib/learn/python/learn/datasets/mnist.py
    buf = bytestream.read(rows * cols * num_images)
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data

一旦「buf = bytestream.read(rows * cols * num_images)」で画像データ全体(28x28のデータを60,000個分)を読み込みます。この段階では1次元配列になっているので、reshapeにて60,000個の28x28の3次元配列に変換し、それを返却。

extract_labels

extract_imagesで、手書き数字画像データを取得し、メモリ上に展開していますが、あのデータは画像データしかないため、どの数字が書かれているか分かりません。何が書かれているかを示してくれるのが、extract_labelsで展開するデータになります。
なので、label == 教師データとなります。

tensorflow/contrib/learn/python/learn/datasets/mnist.py
def extract_labels(filename, one_hot=False, num_classes=10):
  """Extract the labels into a 1D uint8 numpy array [index]."""
  print('Extracting', filename)
  with gfile.Open(filename, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream:
    magic = _read32(bytestream)
    if magic != 2049:
      raise ValueError('Invalid magic number %d in MNIST label file: %s' %
                       (magic, filename))

extract_imagesとマジックナンバーの確認までは流れは一緒です。

tensorflow/contrib/learn/python/learn/datasets/mnist.py
    num_items = _read32(bytestream)

labelが含まれている数を取得。

tensorflow/contrib/learn/python/learn/datasets/mnist.py
    buf = bytestream.read(num_items)
    labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
      return dense_to_one_hot(labels, num_classes)
    return labels

先ほど取得したlabelが含まれている数分のbyteを読み込み、uint8の配列に置き換えて返却(one_hotはfalseでコールされているのでdense_to_one_hotは割愛)

0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?