何をするのか
手書き数字の認識のデータセットとして有名なMNISTのダウンロード、簡単な前処理の方法を紹介します。ダウンロードしたデータは、pickle形式で保存して、すぐにロードできるようにします。
ここで紹介するコードは「ゼロから作る Deep Learning」を参考にしています。ここで、紹介するものは、「ゼロから作るDeep Learning」の本筋ではなくて、付属のコードに関するものです。しかし、読んでいて今後応用できる可能性が大きく、また得るものも多かったのでここで紹介します。
実行環境
- Windows10
- Python 3.6.5 (Anaconda)
- Jupyter Notebook
Pickleとは
上で、pickel形式を太字にしました。これはプログラム実行中のオブジェクトをファイルとして保存するものです。なので、この形式で保存しておいて、後で再びロードすると、保存した時点でのオブジェクトの状態を再現することができます。
ちなみに「pickle」とは漬物の意味で、データを漬物のように長期保存に適した形で保存するということを表しているのだと思います(多分)。
データをダウンロードする
データをホームページからダウンロードします。アドレスは
http://yann.lecun.com/exdb/mnist/
です。ここからファイルをダウンロードするわけですが、ここからPythonを使って行います。
import urllib.request
url_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {
'train_img':'train-images-idx3-ubyte.gz',
'train_label':'train-labels-idx1-ubyte.gz',
'test_img':'t10k-images-idx3-ubyte.gz',
'test_label':'t10k-labels-idx1-ubyte.gz'
}
方法としては、urllib
パッケージを使ってホームページからデータをとってきて保存します。上のページにアクセスするとわかりますが、4種類のフォルダーがダウンロードできるようになっています。それぞれを'train_img', 'train_label', 'test_img', 'test_label'
というキーを付けて区別します。
では、実際にダウンロードします。
dataset_dir = 'C:/Users/usr/Documents' #データを保存する場所
for v in key_file.values():
file_path = dataset_dir + '/' + v
urllib.request.urlretrieve(url_base + v, file_path)
少し待てばダウンロードが終了します。この時、フォルダは
Documents
|_t10k-images-idx3-ubyte.gz
|_t10k-labels-idx1-ubyte.gz
|_train-images-idx3-ubyte.gz
|_train-labels-idx1-ubyte.gz
|...
のようになっているはずです。
Pickle形式で保存する
上で、必要フォルダをダウンロードしましたが、これは見ればわかる通り、圧縮ファイルです。これを扱うには、gzip
ライブラリを使います。使い方は、ふつうのファイルを開くときと非常によく似ていて
gzip.open('ファイル名', '読み込み形式(r, wなど)')
詳しくは公式ドキュメントを読んでください。
では、ここからPickle形式での保存を行っていきます。中身がどのようなファイルなのかよく分からないので、Jupyter Notebookを使って対話的に進めていきました。
import gzip
import numpy as np
file_path = dataset_dir + key_file['train_img'] #試しにtrain_imgを見てみる
with gzip.open(file_path, 'rb') as f:
data = np.frombuffer(f.read(), np.uint8)
len(data) #->47040016
ここでnp.frombuffer()
というのは、バッファーからnumpyの配列を生成するメソッドです。バッファーとは?という人はこのページがわかりやすいと思います。今回は画像ということで、np.uint8
でデータを扱います(0-255を扱うことができRGBを表せるから)。
最終行でdata
の長さを調べています。MNISTのページによると、数字の画像データは28x28のサイズをしているので、データ数は28x28の倍数であるはずです。
len(data) % (28**2) #-> 16
16個ほどデータが余分にあることがわかります。これをカットするためには
with gzip.open(file_path, 'rb') as f:
data = np.frombuffer(f.read(), np.uint8, offset=16)
とします。同様にlabelのほうも調べると、offset=8
とすればよいこともわかります。
ではどんどん配列に変換していきましょう。
def load_img(file_name):
file_path = dataset_dir + '/' + file_name
with gzip.open(file_path, 'rb') as f:
data = np.frombuffer(f.read(), np.uint8, offset=16)
data = data.reshape(-1, 784)
return data
def load_label(file_name):
file_path = dataset_dir + '/' + file_name
with gzip.open(file_path, 'rb') as f:
labels = np.frombuffer(f.read(), np.uint8, offset=8)
return labels
dataset = {}
dataset['train_img'] = load_img(key_file['train_img'])
dataset['train_label'] = load_label(key_file['train_label'])
dataset['test_img'] = load_img(key_file['test_img'])
dataset['test_label'] = load_label(key_file['test_label'])
最後にPickle形式で保存します。これにはpickle
ライブラリを使います。このライブラリもいろいろなことができますが、ここでは次のメソッドを使います。
メソッド | 動作 |
---|---|
dump | オブジェクトをpickle化する |
load | pickleオブジェクト表現を再構成 |
詳しい説明は公式ドキュメントを参照してください。
import pickle
save_file = dataset_dir + '/mnist.pkl' #拡張子は.pkl
with open(save_file, 'wb') as f:
pickle.dump(dataset, f, -1) #-1は最も高いプロトコルバージョンで保存する
#ことを指定している
保存したデータを読み込むときには、
with open(save_data, 'rb') as f:
dataset = pickle.load(f)
とします。
簡単な前処理
データも保存できたので、簡単な前処理までやってしまいます。データはdataset
という名前で、上のコードによって読み込まれているとします。
まずはデータの確認です。
dataset['train_img'].shape #-> (60000, 784)
dataset['train_label].shape #-> (60000,)
28x28の画像ファイルが60000個入っていることが確認できます。試しに一番最初のデータを見てみましょう。
import matplotlib.pyplot as plt
example = dataset['train_img'][0].reshape((28, 28))
plt.imshow(example)
plt.show()
多分これは5ですね。labelを見てみましょう。
dataset['train_label'][0] #-> 5
はい、5でした。ラベルは、正解の数字がそのまま入っているようですね。ニューラルネットワークで学習させるときには、このような形式よりもone-hot形式のほうが良いので、変換してみましょう。one-hot形式というのは、正解ラベルの場所だけが1でほかは0である形式のことを言います。今回の場合には下のような対応関係があります。
#正解の数字 #one-hot表現
1 [1,0,0,0,0,0,0,0,0,0]
2 [0,1,0,0,0,0,0,0,0,0]
3 [0,0,1,0,0,0,0,0,0,0]
4 [0,0,0,1,0,0,0,0,0,0]
...
9 [0,0,0,0,0,0,0,0,0,1]
コードは次のようになります。
def to_one_hot(label):
T = np.zeros((label.size, 10))
for i in range(label.size):
T[i][label[i]] = 1
return T
dataset['train_label'] = to_one_hot(dataset['train_label'])
dataset['train_label'].shape #-> (60000, 10)
次に、正規化をしてみます。正規化とは、データをある一定の範囲の中に納まるようにする処理のことです。今回は各ピクセルの値は0-255なので、全体を255で割って0-1の間に納まるようにします。
def normalize(key):
dataset[key] = dataset[key].astype(np.float32)
dataset[key] /= 255
return dataset[key]
dataset['train_img'] = normalize('train_img')
似たような処理には、データ全体の平均を0に、標準偏差を1にする標準化や、データ全体の分布の形状を一様にする白色化などがあります。