Help us understand the problem. What is going on with this article?

PythonでMNISTをダウンロードして前処理する

More than 1 year has passed since last update.

何をするのか

手書き数字の認識のデータセットとして有名なMNISTのダウンロード、簡単な前処理の方法を紹介します。ダウンロードしたデータは、pickle形式で保存して、すぐにロードできるようにします。

ここで紹介するコードは「ゼロから作る Deep Learning」を参考にしています。ここで、紹介するものは、「ゼロから作るDeep Learning」の本筋ではなくて、付属のコードに関するものです。しかし、読んでいて今後応用できる可能性が大きく、また得るものも多かったのでここで紹介します。

実行環境

  • Windows10
  • Python 3.6.5 (Anaconda)
  • Jupyter Notebook

Pickleとは

上で、pickel形式を太字にしました。これはプログラム実行中のオブジェクトをファイルとして保存するものです。なので、この形式で保存しておいて、後で再びロードすると、保存した時点でのオブジェクトの状態を再現することができます。

ちなみに「pickle」とは漬物の意味で、データを漬物のように長期保存に適した形で保存するということを表しているのだと思います(多分)。

データをダウンロードする

データをホームページからダウンロードします。アドレスは
http://yann.lecun.com/exdb/mnist/
です。ここからファイルをダウンロードするわけですが、ここからPythonを使って行います。

import urllib.request

url_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {
    'train_img':'train-images-idx3-ubyte.gz',
    'train_label':'train-labels-idx1-ubyte.gz',
    'test_img':'t10k-images-idx3-ubyte.gz',
    'test_label':'t10k-labels-idx1-ubyte.gz'
}

方法としては、urllibパッケージを使ってホームページからデータをとってきて保存します。上のページにアクセスするとわかりますが、4種類のフォルダーがダウンロードできるようになっています。それぞれを'train_img', 'train_label', 'test_img', 'test_label'というキーを付けて区別します。

では、実際にダウンロードします。

dataset_dir = 'C:/Users/usr/Documents'    #データを保存する場所

for v in key_file.values():
    file_path = dataset_dir + '/' + v
    urllib.request.urlretrieve(url_base + v, file_path)

少し待てばダウンロードが終了します。この時、フォルダは

Documents
|_t10k-images-idx3-ubyte.gz
|_t10k-labels-idx1-ubyte.gz
|_train-images-idx3-ubyte.gz
|_train-labels-idx1-ubyte.gz
|...

のようになっているはずです。

Pickle形式で保存する

上で、必要フォルダをダウンロードしましたが、これは見ればわかる通り、圧縮ファイルです。これを扱うには、gzipライブラリを使います。使い方は、ふつうのファイルを開くときと非常によく似ていて

gzip.open('ファイル名', '読み込み形式(r, wなど)')

詳しくは公式ドキュメントを読んでください。

では、ここからPickle形式での保存を行っていきます。中身がどのようなファイルなのかよく分からないので、Jupyter Notebookを使って対話的に進めていきました。

import gzip
import numpy as np

file_path = dataset_dir + key_file['train_img']    #試しにtrain_imgを見てみる
with gzip.open(file_path, 'rb') as f:
    data = np.frombuffer(f.read(), np.uint8)

len(data)    #->47040016

ここでnp.frombuffer()というのは、バッファーからnumpyの配列を生成するメソッドです。バッファーとは?という人はこのページがわかりやすいと思います。今回は画像ということで、np.uint8でデータを扱います(0-255を扱うことができRGBを表せるから)。

最終行でdataの長さを調べています。MNISTのページによると、数字の画像データは28x28のサイズをしているので、データ数は28x28の倍数であるはずです。

len(data) % (28**2)    #-> 16

16個ほどデータが余分にあることがわかります。これをカットするためには

with gzip.open(file_path, 'rb') as f:
    data = np.frombuffer(f.read(), np.uint8, offset=16)

とします。同様にlabelのほうも調べると、offset=8とすればよいこともわかります。

ではどんどん配列に変換していきましょう。

def load_img(file_name):
    file_path = dataset_dir + '/' + file_name
    with gzip.open(file_path, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=16)
    data = data.reshape(-1, 784)

    return data

def load_label(file_name):
    file_path = dataset_dir + '/' + file_name
    with gzip.open(file_path, 'rb') as f:
        labels = np.frombuffer(f.read(), np.uint8, offset=8)

    return labels

dataset = {}
dataset['train_img'] = load_img(key_file['train_img'])
dataset['train_label'] = load_label(key_file['train_label'])
dataset['test_img'] = load_img(key_file['test_img'])
dataset['test_label'] = load_label(key_file['test_label'])

最後にPickle形式で保存します。これにはpickleライブラリを使います。このライブラリもいろいろなことができますが、ここでは次のメソッドを使います。

メソッド 動作
dump オブジェクトをpickle化する
load pickleオブジェクト表現を再構成

詳しい説明は公式ドキュメントを参照してください。

import pickle

save_file = dataset_dir + '/mnist.pkl'    #拡張子は.pkl
with open(save_file, 'wb') as f:
    pickle.dump(dataset, f, -1)    #-1は最も高いプロトコルバージョンで保存する
                                   #ことを指定している

保存したデータを読み込むときには、

with open(save_data, 'rb') as f:
    dataset = pickle.load(f)

とします。

簡単な前処理

データも保存できたので、簡単な前処理までやってしまいます。データはdatasetという名前で、上のコードによって読み込まれているとします。

まずはデータの確認です。

dataset['train_img'].shape    #-> (60000, 784)
dataset['train_label].shape   #-> (60000,)

28x28の画像ファイルが60000個入っていることが確認できます。試しに一番最初のデータを見てみましょう。

import matplotlib.pyplot as plt

example = dataset['train_img'][0].reshape((28, 28))

plt.imshow(example)
plt.show()

多分これは5ですね。labelを見てみましょう。

dataset['train_label'][0]    #-> 5

はい、5でした。ラベルは、正解の数字がそのまま入っているようですね。ニューラルネットワークで学習させるときには、このような形式よりもone-hot形式のほうが良いので、変換してみましょう。one-hot形式というのは、正解ラベルの場所だけが1でほかは0である形式のことを言います。今回の場合には下のような対応関係があります。

#正解の数字   #one-hot表現
1            [1,0,0,0,0,0,0,0,0,0]
2            [0,1,0,0,0,0,0,0,0,0]
3            [0,0,1,0,0,0,0,0,0,0]
4            [0,0,0,1,0,0,0,0,0,0]
...
9            [0,0,0,0,0,0,0,0,0,1]

コードは次のようになります。

def to_one_hot(label):
    T = np.zeros((label.size, 10))
    for i in range(label.size):
        T[i][label[i]] = 1
    return T

dataset['train_label'] = to_one_hot(dataset['train_label'])
dataset['train_label'].shape    #-> (60000, 10)

次に、正規化をしてみます。正規化とは、データをある一定の範囲の中に納まるようにする処理のことです。今回は各ピクセルの値は0-255なので、全体を255で割って0-1の間に納まるようにします。

def normalize(key):
    dataset[key] = dataset[key].astype(np.float32)
    dataset[key] /= 255

    return dataset[key]

dataset['train_img'] = normalize('train_img')

似たような処理には、データ全体の平均を0に、標準偏差を1にする標準化や、データ全体の分布の形状を一様にする白色化などがあります。

参考資料

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away