はじめに
MNISTをつかったアルゴリズムの検討とか、論文のキャッチアップとか、いろいろと皆さんやると思います。kerasを使ったり、pytorchを使ったり、scikit-learnを使ったりと、検討方法は色々とあると思いますが、その度にmnistをダウンロードしたりしている人も多いのではないでしょうか。
少なくとも少し前の僕はそうでした。無駄でした。
当たり前の話なのですが、ちゃんと管理して使うとなにかと効率がいいよ、っていう素人的なお話し。MNISTをとりあえず保存し、それを読み込む関数を用意したいと思います。
Torchを使った保存
MNISTはまぁどんな形でダウンロードしてもいいのですが、今回はpytorchを使った例を載せます。
import os
import numpy as np
from torchvision import datasets, transforms
mnist_train = datasets.MNIST(
root='/Users/user_name/Downloads/Pytorch_data/MNIST',
download=True,
train=True,
transform=transforms.Compose([
transforms.ToTensor()
])
)
mnist_test = datasets.MNIST(
root='/Users/user_name/Downloads/Pytorch_data/MNIST',
download=True,
train=False,
transform=transforms.Compose([
transforms.ToTensor()
])
)
とりあえずこれで、root
に指定したディレクトリにダウンロードされたMNISTが保存されます。ここにあれば毎回pytorchで使う分には再ダウンロードされないんだけど、keras使うと再ダウンロードされたり、自分でnumpyで中身を見て色々やりたいと思うときに、いちいちtorchで読み込むのもアレなので、これを自分のデータセットとして保存しましょう。
x_train = mnist_train.train_data.numpy()
y_train = mnist_train.train_labels.numpy()
x_test = mnist_test.test_data.numpy()
y_test = mnist_test.test_labels.numpy()
save_root = '/Users/user_name/workspace/dataset/MNIST'
if not os.path.exists(save_root):
os.mkdir(save_root)
np.save(os.path.join(save_root, 'mnist_train_imgs'), x_train)
np.save(os.path.join(save_root, 'mnist_train_labels'), y_train)
np.save(os.path.join(save_root, 'mnist_test_imgs'), x_test)
np.save(os.path.join(save_root, 'mnist_test_labels'), y_test)
当たり前の話だけど、まぁこれだけ。
とりあえずこれでデータセットをnumpy形式で保存できました。
じゃあこれを使いやすくLoadする関数でも作ってあげましょう。
保存したMNISTの効率的な読み込み
せっかく保存したので、人の子ならうまく使いたいと思うでしょう。
単なるロードだけではなく、いくつかの機能を持たせて便利に読み込みましょう。
読み込み
まずは保存したMNISTの読み込み部分。
これはシンプルにdata_root
を与えてファイルを読み込みます。
_x_train = np.load(os.path.join(data_root, 'mnist_train_imgs.npy'))
_y_train = np.load(os.path.join(data_root, 'mnist_train_labels.npy'))
_x_test = np.load(os.path.join(data_root, 'mnist_test_imgs.npy'))
_y_test = np.load(os.path.join(data_root, 'mnist_test_labels.npy'))
サイズ変更
追加していきたいのが、MNISTの画像サイズの変更。
MNISTは28x28のモノクロ画像だけど、random forestのアルゴにかけるときとかは一次元で入力する必要があったりしますので、簡単なReshapeを用意してあげます。
# Reshape
print('-----Shape of image.-----')
if img_shape == 0:
print('Original size.')
elif img_shape == 1:
print('Reshape to 1 dimension')
_x_train = np.reshape(_x_train, (-1, 28 * 28))
_x_test = np.reshape(_x_test, (-1, 28 * 28))
print('shape(x_train)=', _x_train.shape)
print('shape(x_test)=', _x_test.shape)
ついでにラベルのほうも、One-Hot encodingにしたいとか、したくないとか、そんな要望があると思うので、一緒に変換しましょう。
# One-Hot Encoding
print('\n-----Label format-----')
if one_hot == 0:
print('Original data. labels are integer value.')
elif one_hot == 1:
print('Labels are One-Hot Encoding.')
n_labels = len(np.unique(_y_train))
_y_train = np.eye(n_labels)[_y_train]
_y_test = np.eye(n_labels)[_y_test]
print('shape(y_train)=', _y_train.shape)
print('shape(y_test)=', _y_test.shape)
正規化など
よく機械学習を使うときに行われるデータの入力範囲の変更。
MNISTでもいろんなパターンを見ることがあります。
例えば255で割るだけ、平均を0に分散を1にといういわゆる正規化など。
そういうものも用意しておくと毎度の検討が楽になると思います。
# Normalize
print('\n-----Normalize-----')
if normalize == 0:
print('Original data. 8bit gray scale. (0 to 255)')
elif normalize == 1:
print('Normalize 0 to 1 value')
_x_train /= 255
_x_test /= 255
elif normalize == 2:
print('Normalize mean to 0 and variance to 1.')
mean_train = np.mean(_x_train)
var_train = np.var(_x_train)
_x_train = (_x_train - mean_train) / var_train
_x_test = (_x_test - mean_train) / var_train
おまけ
あとは人によるかもしれませんが、よく使うフレームワークに合わせて、フォーマットの変更が必要な人は、最後にその変換を用意しておくといいかもしれません。僕はpytorchをよく使うので用意してあります。
# Change data format
if data_format.lower() == 'numpy':
print('Return umpy format.')
elif data_format.lower() == 'torch':
print('Return torch format.')
_x_train = torch.from_numpy(_x_train)
_y_train = torch.from_numpy(_y_train)
_x_test = torch.from_numpy(_x_test)
_y_test = torch.from_numpy(_y_test)
おまけ その2
必要があれば簡単な画像表示機能も持たせておきましょう。
# Preview
if preview:
fig = plt.figure()
for i in range(9):
ax = fig.add_subplot(3, 3, i + 1)
ax.imshow(_x_train[100 * i], 'gray')
plt.tight_layout()
plt.show()
こういう処理を寄せ集めたMNISTのロード関数を用意しておくと、結構便利です。また、その人なりに必要なものを色々と追加しておくとより便利かもしれません。
全体のコード
以下に全体のコードを載せておきます。
まずは保存。
import os
import numpy as np
from torchvision import datasets, transforms
def save_mnist():
mnist_train = datasets.MNIST(
root='/Users/user_name/Downloads/Pytorch_data/MNIST',
download=True,
train=True,
transform=transforms.Compose([
transforms.ToTensor()
])
)
mnist_test = datasets.MNIST(
root='/Users/user_name/Downloads/Pytorch_data/MNIST',
download=True,
train=False,
transform=transforms.Compose([
transforms.ToTensor()
])
)
x_train = mnist_train.train_data.numpy()
y_train = mnist_train.train_labels.numpy()
x_test = mnist_test.test_data.numpy()
y_test = mnist_test.test_labels.numpy()
save_root = '/Users/user_name/workspace/dataset/MNIST'
if not os.path.exists(save_root):
os.mkdir(save_root)
np.save(os.path.join(save_root, 'mnist_train_imgs'), x_train)
np.save(os.path.join(save_root, 'mnist_train_labels'), y_train)
np.save(os.path.join(save_root, 'mnist_test_imgs'), x_test)
np.save(os.path.join(save_root, 'mnist_test_labels'), y_test)
次は読み込み。
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
def load_mnist(data_root, preview=False, img_shape=0, normalize=0, one_hot=0, data_format='numpy'):
_x_train = np.load(os.path.join(data_root, 'mnist_train_imgs.npy'))
_y_train = np.load(os.path.join(data_root, 'mnist_train_labels.npy'))
_x_test = np.load(os.path.join(data_root, 'mnist_test_imgs.npy'))
_y_test = np.load(os.path.join(data_root, 'mnist_test_labels.npy'))
# Reshape
print('-----Shape of image.-----')
if img_shape == 0:
print('Original size.')
elif img_shape == 1:
print('Reshape to 1 dimension')
_x_train = np.reshape(_x_train, (-1, 28 * 28))
_x_test = np.reshape(_x_test, (-1, 28 * 28))
print('shape(x_train)=', _x_train.shape)
print('shape(x_test)=', _x_test.shape)
# One-Hot Encoding
print('\n-----Label format-----')
if one_hot == 0:
print('Original data. labels are integer value.')
elif one_hot == 1:
print('Labels are One-Hot Encoding.')
n_labels = len(np.unique(_y_train))
_y_train = np.eye(n_labels)[_y_train]
_y_test = np.eye(n_labels)[_y_test]
print('shape(y_train)=', _y_train.shape)
print('shape(y_test)=', _y_test.shape)
# Normalize
print('\n-----Normalize-----')
if normalize == 0:
print('Original data. 8bit gray scale. (0 to 255)')
elif normalize == 1:
print('Normalize 0 to 1 value')
_x_train /= 255
_x_test /= 255
elif normalize == 2:
print('Normalize mean to 0 and variance to 1.')
mean_train = np.mean(_x_train)
var_train = np.var(_x_train)
_x_train = (_x_train - mean_train) / var_train
_x_test = (_x_test - mean_train) / var_train
# Change data format
if data_format.lower() == 'numpy':
print('Return numpy format.')
elif data_format.lower() == 'torch':
print('Return torch format.')
_x_train = torch.from_numpy(_x_train)
_y_train = torch.from_numpy(_y_train)
_x_test = torch.from_numpy(_x_test)
_y_test = torch.from_numpy(_y_test)
# Preview
if preview:
fig = plt.figure()
for i in range(9):
ax = fig.add_subplot(3, 3, i + 1)
ax.imshow(_x_train[100 * i], 'gray')
plt.tight_layout()
plt.show()
return _x_train, _y_train, _x_test, _y_test
if __name__ == '__main__':
data_path = '/Users/user_name/workspace/dataset/MNIST'
x_train, y_train, x_test, y_test = load_mnist(
data_root=data_path,
img_shape=1,
one_hot=1,
normalize=2,
data_format='torch',
preview=False
)
print(x_train)
exit()
まとめ
プログラム的には全く大したことはやってないけど、こういうのって個人的には意外と大事で抑えておきたいところだと思ってます。ちなみにMNIST以外でも、データ読み出しはこんなふうにしておくといいと感じています。ただ、気に入らないのは、if文の分岐ですかね・・こういった例で効率のいい条件分岐の仕方を知っている人がいたら教えてください。