LoginSignup
3
3

More than 5 years have passed since last update.

はじめに

MNISTをつかったアルゴリズムの検討とか、論文のキャッチアップとか、いろいろと皆さんやると思います。kerasを使ったり、pytorchを使ったり、scikit-learnを使ったりと、検討方法は色々とあると思いますが、その度にmnistをダウンロードしたりしている人も多いのではないでしょうか。

少なくとも少し前の僕はそうでした。無駄でした。

当たり前の話なのですが、ちゃんと管理して使うとなにかと効率がいいよ、っていう素人的なお話し。MNISTをとりあえず保存し、それを読み込む関数を用意したいと思います。

Torchを使った保存

MNISTはまぁどんな形でダウンロードしてもいいのですが、今回はpytorchを使った例を載せます。

import os
import numpy as np
from torchvision import datasets, transforms
mnist_train = datasets.MNIST(
    root='/Users/user_name/Downloads/Pytorch_data/MNIST',
    download=True,
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)
mnist_test = datasets.MNIST(
    root='/Users/user_name/Downloads/Pytorch_data/MNIST',
    download=True,
    train=False,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

とりあえずこれで、rootに指定したディレクトリにダウンロードされたMNISTが保存されます。ここにあれば毎回pytorchで使う分には再ダウンロードされないんだけど、keras使うと再ダウンロードされたり、自分でnumpyで中身を見て色々やりたいと思うときに、いちいちtorchで読み込むのもアレなので、これを自分のデータセットとして保存しましょう。

x_train = mnist_train.train_data.numpy()
y_train = mnist_train.train_labels.numpy()
x_test = mnist_test.test_data.numpy()
y_test = mnist_test.test_labels.numpy()

save_root = '/Users/user_name/workspace/dataset/MNIST'
if not os.path.exists(save_root):
    os.mkdir(save_root)

np.save(os.path.join(save_root, 'mnist_train_imgs'), x_train)
np.save(os.path.join(save_root, 'mnist_train_labels'), y_train)
np.save(os.path.join(save_root, 'mnist_test_imgs'), x_test)
np.save(os.path.join(save_root, 'mnist_test_labels'), y_test)

当たり前の話だけど、まぁこれだけ。
とりあえずこれでデータセットをnumpy形式で保存できました。
じゃあこれを使いやすくLoadする関数でも作ってあげましょう。

保存したMNISTの効率的な読み込み

せっかく保存したので、人の子ならうまく使いたいと思うでしょう。
単なるロードだけではなく、いくつかの機能を持たせて便利に読み込みましょう。

読み込み

まずは保存したMNISTの読み込み部分。
これはシンプルにdata_rootを与えてファイルを読み込みます。

_x_train = np.load(os.path.join(data_root, 'mnist_train_imgs.npy'))
_y_train = np.load(os.path.join(data_root, 'mnist_train_labels.npy'))
_x_test = np.load(os.path.join(data_root, 'mnist_test_imgs.npy'))
_y_test = np.load(os.path.join(data_root, 'mnist_test_labels.npy'))

サイズ変更

追加していきたいのが、MNISTの画像サイズの変更。
MNISTは28x28のモノクロ画像だけど、random forestのアルゴにかけるときとかは一次元で入力する必要があったりしますので、簡単なReshapeを用意してあげます。

# Reshape
print('-----Shape of image.-----')
if img_shape == 0:
    print('Original size.')
elif img_shape == 1:
    print('Reshape to 1 dimension')
    _x_train = np.reshape(_x_train, (-1, 28 * 28))
    _x_test = np.reshape(_x_test, (-1, 28 * 28))

print('shape(x_train)=', _x_train.shape)
print('shape(x_test)=', _x_test.shape)

ついでにラベルのほうも、One-Hot encodingにしたいとか、したくないとか、そんな要望があると思うので、一緒に変換しましょう。

# One-Hot Encoding
print('\n-----Label format-----')
if one_hot == 0:
    print('Original data. labels are integer value.')
elif one_hot == 1:
    print('Labels are One-Hot Encoding.')
    n_labels = len(np.unique(_y_train))
    _y_train = np.eye(n_labels)[_y_train]
    _y_test = np.eye(n_labels)[_y_test]

print('shape(y_train)=', _y_train.shape)
print('shape(y_test)=', _y_test.shape)

正規化など

よく機械学習を使うときに行われるデータの入力範囲の変更。
MNISTでもいろんなパターンを見ることがあります。
例えば255で割るだけ、平均を0に分散を1にといういわゆる正規化など。
そういうものも用意しておくと毎度の検討が楽になると思います。

# Normalize
print('\n-----Normalize-----')
if normalize == 0:
    print('Original data. 8bit gray scale. (0 to 255)')
elif normalize == 1:
    print('Normalize 0 to 1 value')
    _x_train /= 255
    _x_test /= 255
elif normalize == 2:
    print('Normalize mean to 0 and variance to 1.')
    mean_train = np.mean(_x_train)
    var_train = np.var(_x_train)
    _x_train = (_x_train - mean_train) / var_train
    _x_test = (_x_test - mean_train) / var_train

おまけ

あとは人によるかもしれませんが、よく使うフレームワークに合わせて、フォーマットの変更が必要な人は、最後にその変換を用意しておくといいかもしれません。僕はpytorchをよく使うので用意してあります。

# Change data format
if data_format.lower() == 'numpy':
    print('Return umpy format.')
elif data_format.lower() == 'torch':
    print('Return torch format.')
    _x_train = torch.from_numpy(_x_train)
    _y_train = torch.from_numpy(_y_train)
    _x_test = torch.from_numpy(_x_test)
    _y_test = torch.from_numpy(_y_test)

おまけ その2

必要があれば簡単な画像表示機能も持たせておきましょう。

# Preview
if preview:
    fig = plt.figure()
    for i in range(9):
        ax = fig.add_subplot(3, 3, i + 1)
        ax.imshow(_x_train[100 * i], 'gray')
    plt.tight_layout()
    plt.show()

こういう処理を寄せ集めたMNISTのロード関数を用意しておくと、結構便利です。また、その人なりに必要なものを色々と追加しておくとより便利かもしれません。

全体のコード

以下に全体のコードを載せておきます。
まずは保存。

save_mnist.py
import os
import numpy as np
from torchvision import datasets, transforms


def save_mnist():
    mnist_train = datasets.MNIST(
        root='/Users/user_name/Downloads/Pytorch_data/MNIST',
        download=True,
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor()
        ])
    )
    mnist_test = datasets.MNIST(
        root='/Users/user_name/Downloads/Pytorch_data/MNIST',
        download=True,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor()
        ])
    )

    x_train = mnist_train.train_data.numpy()
    y_train = mnist_train.train_labels.numpy()
    x_test = mnist_test.test_data.numpy()
    y_test = mnist_test.test_labels.numpy()

    save_root = '/Users/user_name/workspace/dataset/MNIST'
    if not os.path.exists(save_root):
        os.mkdir(save_root)

    np.save(os.path.join(save_root, 'mnist_train_imgs'), x_train)
    np.save(os.path.join(save_root, 'mnist_train_labels'), y_train)
    np.save(os.path.join(save_root, 'mnist_test_imgs'), x_test)
    np.save(os.path.join(save_root, 'mnist_test_labels'), y_test)

次は読み込み。

load_mnist.py
import os
import numpy as np
import matplotlib.pyplot as plt
import torch


def load_mnist(data_root, preview=False, img_shape=0, normalize=0, one_hot=0, data_format='numpy'):
    _x_train = np.load(os.path.join(data_root, 'mnist_train_imgs.npy'))
    _y_train = np.load(os.path.join(data_root, 'mnist_train_labels.npy'))
    _x_test = np.load(os.path.join(data_root, 'mnist_test_imgs.npy'))
    _y_test = np.load(os.path.join(data_root, 'mnist_test_labels.npy'))

    # Reshape
    print('-----Shape of image.-----')
    if img_shape == 0:
        print('Original size.')
    elif img_shape == 1:
        print('Reshape to 1 dimension')
        _x_train = np.reshape(_x_train, (-1, 28 * 28))
        _x_test = np.reshape(_x_test, (-1, 28 * 28))

    print('shape(x_train)=', _x_train.shape)
    print('shape(x_test)=', _x_test.shape)

    # One-Hot Encoding
    print('\n-----Label format-----')
    if one_hot == 0:
        print('Original data. labels are integer value.')
    elif one_hot == 1:
        print('Labels are One-Hot Encoding.')
        n_labels = len(np.unique(_y_train))
        _y_train = np.eye(n_labels)[_y_train]
        _y_test = np.eye(n_labels)[_y_test]

    print('shape(y_train)=', _y_train.shape)
    print('shape(y_test)=', _y_test.shape)

    # Normalize
    print('\n-----Normalize-----')
    if normalize == 0:
        print('Original data. 8bit gray scale. (0 to 255)')
    elif normalize == 1:
        print('Normalize 0 to 1 value')
        _x_train /= 255
        _x_test /= 255
    elif normalize == 2:
        print('Normalize mean to 0 and variance to 1.')
        mean_train = np.mean(_x_train)
        var_train = np.var(_x_train)
        _x_train = (_x_train - mean_train) / var_train
        _x_test = (_x_test - mean_train) / var_train

    # Change data format
    if data_format.lower() == 'numpy':
        print('Return numpy format.')
    elif data_format.lower() == 'torch':
        print('Return torch format.')
        _x_train = torch.from_numpy(_x_train)
        _y_train = torch.from_numpy(_y_train)
        _x_test = torch.from_numpy(_x_test)
        _y_test = torch.from_numpy(_y_test)

    # Preview
    if preview:
        fig = plt.figure()
        for i in range(9):
            ax = fig.add_subplot(3, 3, i + 1)
            ax.imshow(_x_train[100 * i], 'gray')
        plt.tight_layout()
        plt.show()

    return _x_train, _y_train, _x_test, _y_test


if __name__ == '__main__':
    data_path = '/Users/user_name/workspace/dataset/MNIST'
    x_train, y_train, x_test, y_test = load_mnist(
        data_root=data_path,
        img_shape=1,
        one_hot=1,
        normalize=2,
        data_format='torch',
        preview=False
    )

    print(x_train)
    exit()

まとめ

プログラム的には全く大したことはやってないけど、こういうのって個人的には意外と大事で抑えておきたいところだと思ってます。ちなみにMNIST以外でも、データ読み出しはこんなふうにしておくといいと感じています。ただ、気に入らないのは、if文の分岐ですかね・・こういった例で効率のいい条件分岐の仕方を知っている人がいたら教えてください。

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3