LoginSignup
4
6

More than 3 years have passed since last update.

ゼロから作る keras.datasets.cifar10 の load_data()

Posted at

TL;DR

この記事の半分は

from keras.datasets import cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

とすれば終わることです。

つまり、この Keras で実装されているメソッドを
ゼロから作ってみようというところが今回の記事の目的になります。

対象としていない読者の方

  • とりあえず cifar10 を用いた画像分類をしたい、という方
  • 既にメソッドがあるならそれを使えばいいではないか、という真っ当な考えをお持ちの方。
  • 機械学習やpythonに精通している方

勉強になるかもしれない部分

  • ゼロから作っている
  • pythonでの関数定義の仕方
  • (pythonを用いて)指定したURL先にある、名前の知っているファイルのDownloadの仕方
  • (pythonを用いた)tar.gzファイルの解凍の仕方
  • pickleファイルのunpickle化
  • np.ndarray の reshape method について
  • np.rollaxisについて
  • 画像学習データの正規化の仕方
  • NumPyを用いた one hot labelの作り方

(これらは自分が勉強になったことをつらつら並べただけです..)

本題

まずはプログラムから

先頭のコメントにある通り、このプログラムのお仕事は

  • cifar10 のファイルを (公式から) Download して
  • gzip(tar.gz) を解凍した後
  • データが pickle ファイルになっているので、unpickle して
  • reshape, nomalize, one_hot_label化 (選択可能) です。

# download => unzip => load(unpickle) => prep cifar10

# ----- import -----
import os, sys
import tarfile
import numpy as np

try:
    import urllib.request
except ImportError:
    raise ImportError('please use python 3.x ...')
try:
    import cPickle as pickle
except ImportError:
    import pickle

# ----- dir -----
cwd = os.path.abspath(__file__)
cdir = os.path.dirname(cwd)

# ----- url -----
url_base = "https://www.cs.toronto.edu/~kriz/"
file_name = "cifar-10-python.tar.gz"
unziped_file_name = "cifar-10-batches-py"

# ----- download -----
def download_cifar10(file_name=file_name):
    file_path = os.path.join(cdir, file_name)

    if os.path.exists(file_path):
        pass
    else:
        print("Downloading ", file_name, " ... ")
        urllib.request.urlretrieve(url_base + file_name, file_path)
        print("Done.")

# ----- unzip .tar.gz file -----
def unzip(zip_file=file_name):
    if os.path.exists(unziped_file_name):
        pass
    else:
        print("unzip " + zip_file + " ...")
        with tarfile.open(zip_file, 'r:gz') as tarf:
            tarf.extractall(path=cdir)
        print("Done.")

# ----- init (download & unzip) -----
def init_cifar10():
    print("initializing cifar10 ...")
    download_cifar10()
    unzip()
    print("Done.")


# ---- def unpickle function before load -----
def unpickle(target_file):
    with open(target_file, 'rb') as f:
        content = pickle.load(f, encoding="latin-1")
    return content

# ----- load -----
def load_cifar10(target_dir=unziped_file_name, reshape=True, nomalize=True, one_hot=True):
    if not os.path.exists(unziped_file_name):
        init_cifar10()

    # --- load train data & labels ---
    for i in range(1, 6):
        data_batch = os.path.join(target_dir, "data_batch_{0}".format(i))
        up_data_batch = unpickle(data_batch)
        if i == 1:
            x_train = up_data_batch['data']
            y_train = up_data_batch['labels']
        else:
            x_train = np.vstack((x_train, up_data_batch['data']))
            y_train = np.hstack((y_train, up_data_batch['labels']))

    # --- load test data & labels ---
    up_test_batch = unpickle(os.path.join(target_dir, "test_batch"))
    x_test = up_test_batch['data']
    y_test = np.array(up_test_batch['labels'])

    # --- reshape ---
    train_dim = x_train.shape[0]
    test_dim = x_test.shape[0]
    if reshape:
        x_train = x_train.reshape((train_dim, 3, 32, 32))
        x_train = np.rollaxis(x_train, 1, 4)
        x_test = x_test.reshape((test_dim, 3, 32, 32))
        x_test = np.rollaxis(x_test, 1, 4)

    # --- nomalize ---
    if nomalize:
        x_train = x_train.astype('float32')
        x_train /= 255.0
        x_test = x_test.astype('float32')
        x_test /= 255.0

    # --- one_hot ---
    if one_hot:
        class_num = 10
        y_train = np.identity(class_num)[y_train]
        y_test = np.identity(class_num)[y_test]


    return (x_train, y_train), (x_test, y_test)

def load_labelnames(target_dir=unziped_file_name):
    if not os.path.exists(target_dir):
        init_cifar10()

    batches_meta = unpickle(os.path.join(target_dir, "batches.meta"))
    label_names = batches_meta['label_names']

    return label_names

確認用 main部分


if __name__ == "__main__":
    (x_train, y_train), (x_test, y_test) = load_cifar10()

    print("x_train.shape : ", x_train.shape)
    print("y_train.shape : ", y_train.shape)
    print("x_test.shape : ", x_test.shape)
    print("y_test.shape : ", y_test.shape)

    print()

    print("type(x_train) : ", type(x_train))
    print("type(y_train) : ", type(y_train))
    print("type(x_test) : ", type(x_test))
    print("type(y_test) : ", type(y_test))

    print()

    print("x_train[0] : ", x_train[0])
    print("y_train[0] : ", y_train[0])

    print()

    label_names = load_labelnames()
    print("label_names : ", label_names)

    (x_train_plt, y_train_plt), (_, _) = load_cifar10(nomalize=False, one_hot=False)
    import matplotlib.pylab as plt
    image_index = input("さて、何番にいたしましょう (plot 中止 : z): ")
    if image_index == "z":
        pass
    else:
        idx = int(image_index)
        plt.imshow(x_train_plt[idx])
        label_num = y_train_plt[idx]
        label_name = label_names[label_num]
        plt.title("label {0} : {1}".format(label_num, label_name))
        plt.show()


解説

import


# ----- import -----
import os, sys
import tarfile
import numpy as np

try:
    import urllib.request
except ImportError:
    raise ImportError('please use python 3.x ...')
try:
    import cPickle as pickle
except ImportError:
    import pickle

必要なライブラリをimportします。
try-exceptしているところはライブラリがなかったり、
pythonのバージョンが低い場合にエラーを上げてくれるように書いてあります。

今回はcifar10のデータを公式HPに取りに行くので、urllibが必要で、
そのため、3系を使ってくださいという旨を記述しています。
2系のみなさんごめんなさい..ほら、サポート切れるとか(殴

ディレクトリ

# ----- dir -----
cwd = os.path.abspath(__file__)
cdir = os.path.dirname(cwd)

上段の記述で、現在のファイルまでの絶対パス(フルパス)が取れます。
下段の記述で、実行したファイルが格納されているディレクトリ(フォルダ)の名前が取れます。
要は絶対パスの階層の一つ上まで、ということですね。

URL

# ----- url -----
url_base = "https://www.cs.toronto.edu/~kriz/"
file_name = "cifar-10-python.tar.gz"
unziped_file_name = "cifar-10-batches-py"

この公式HPにアクセスします。

ここのURLはhttps://www.cs.toronto.edu/~kriz/cifar.html なのですが、
cifar10のダウンロードリンクはhttps://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gzなので
.htmlまで指定してしまうと、合わないんですよね。

そのため、ここではベースとなる部分を定義しておきます。

また、上記のダウンロードリンクより、ファイル名がわかったので、
ダウンロード目標のファイル名も定義しておきます。

最下段のやつなんですが、展開後のファイルの名前です。
ちょっとずるいですが、一回展開して、名前をコピってきました。
後でexistのチェックとかで使うのでこれもまとめてここで定義します。

cifar10ファイルのダウンロード

# ----- download -----
def download_cifar10(file_name=file_name):
    file_path = os.path.join(cdir, file_name)

    if os.path.exists(file_path):
        pass
    else:
        print("Downloading ", file_name, " ... ")
        urllib.request.urlretrieve(url_base + file_name, file_path)
        print("Done.")

さっきのページにアクセスして
cifar10のデータセットのzipファイルをダウンロードします。

もし、既にダウンロードしてあるようだったらこの処理をスキップします。

終わったら 「だん。」 と言います。

ダウンロードしたgzipファイルのunzip


# ----- unzip .tar.gz file -----
def unzip(zip_file=file_name):
    if os.path.exists(unziped_file_name):
        pass
    else:
        print("unzip " + zip_file + " ...")
        with tarfile.open(zip_file, 'r:gz') as tarf:
            tarf.extractall(path=cdir)
        print("Done.")

上でダウンロードしたcifar10のファイルはgzipファイルなので、解凍します。
(もっと正確にいえば、tarファイルを gzipしたファイル)
(詳しい話は 僕はわからないので 本題から外れるので下にリンクだけ載せます。)

これも既に存在するようなら処理をパスします。

そして終わったら 「だん。」 と言います。

一旦ここまでの処理をまとめる


# ----- init (download & unzip) -----
def init_cifar10():
    print("initializing cifar10 ...")
    download_cifar10()
    unzip()
    print("Done.")

一旦ここまでの処理をまとめます。
目的としては、ダウンロードからzipの展開まで、localにcifar10を植え付けるだけなら一括でできた方が楽じゃね?
という ぬるい 思想によるものです。

自分のPCにcifar10のファイルがない場合は
このinit_cifar10()を実行すればダウンロードと解凍をしてくれて、
あとはunpickleして読むだけなので
以降はこの手間と時間を省くことができます。
便利でしょ。便利って言って。

後で使うunpickleしてくれる関数を定義


# ---- def unpickle function before load -----
def unpickle(target_file):
    with open(target_file, 'rb') as f:
        content = pickle.load(f, encoding="latin-1")
    return content

名前こそ異なりますが、枠組みは公式にあるものをそのまま持ってきた感じです。

unzipファイルの中には data_batch_n みたいなファイルがあって、
これが拡張子もなくてなんのファイルなのか戸惑ったのですが、
公式のダウンロードリンクの少し下の方でDataset layoutとして説明してくれていました。

The archive contains the files data_batch_1, data_batch_2, ..., data_batch_5, as well as test_batch. Each of these files is a Python "pickled" object produced with cPickle.

と、まあ data_batch_n ってファイルと、 test_batch ってファイルはpythonのpickleファイルですよ。って言ってくれています。
ちゃんと読まずにごめんなさい。英語苦手なんです。

unpickleしたら読み込む。 そして整形もしちゃう。

ここを分けるか迷ったのですが、長いので分けました。
インデントはこの記事の最初にあった、全体のプログラムを参照ください。

既存チェック


# ----- load -----
def load_cifar10(target_dir=unziped_file_name, reshape=True, nomalize=True, one_hot=True):
    if not os.path.exists(unziped_file_name):
        init_cifar10()

data_batch_n というファイルをunpickleしたら辞書型のデータが帰ってきて
dict['data']というkeyのなかに学習用画像データ
dict['labels']のなかにその画像データの正解ラベルが格納されているようです。
(5つに分けて)

ただし、いきなりこの関数をプログラム内で実行されて
localにファイルがないと読み込むものも読み込めないので、
ファイルがあるかどうかだけ先頭のif文でチェックします。

学習データ読み込み


    # --- load train data & labels ---
    for i in range(1, 6):
        data_batch = os.path.join(target_dir, "data_batch_{0}".format(i))
        up_data_batch = unpickle(data_batch)
        if i == 1:
            x_train = up_data_batch['data']
            y_train = up_data_batch['labels']
        else:
            x_train = np.vstack((x_train, up_data_batch['data']))
            y_train = np.hstack((y_train, up_data_batch['labels']))

今回は展開後に data_batch_ のファイルが 1 ~ 5まであることがわかっているので、
1番の最初 data_batch_1 だけ x_train に格納して、
2~5番目はこの1番目 x_train にスタックしていく感じにします。

基本的には画像もラベルも同じなのですが
画像の方は、行列が積み上がっていく np.vstack() を用い、
ラベルの方は、配列の長さを伸ばして積んでいく np.hstack() を用います。

間違えて np.vstack() を用いると y_train.shape したときに
(50000,) となってほしいところ (5, 10000)となります。(たぶん。)
これでなんとなくイメージがついたらあなたは数強です。(たぶん。)

テストデータも読み込む


    # --- load test data & labels ---
    up_test_batch = unpickle(os.path.join(target_dir, "test_batch"))
    x_test = up_test_batch['data']
    y_test = np.array(up_test_batch['labels'])

学習用データと同じようにテストデータも読み込みます。
特に注意点はありませんが、
ラベルがlist型なので、NumPy配列に変換して格納しておきます。

reshape

ここからはデータ整形のオプション処理です。

そのまま読み込んだだけだと
x_train.shape => (50000, 3072) となっています。

このまま使う場合は load_cifar10(reshape=False) としてください。

cifar10 は 32*32pixcel のRGB画像なのですが、
32*32*3 = 3072でフラットになってしまっている感じですね。

さらに、Keras で読み込むと (50000, 32, 32, 3) の順序なのですが
公式から拾ってきたこのデータでは (50000, 3, 32, 32) という並びらしいので、
この2点に関して整形をします。


    # --- reshape ---
    train_dim = x_train.shape[0]
    test_dim = x_test.shape[0]
    if reshape:
        x_train = x_train.reshape((train_dim, 3, 32, 32))
        x_train = np.rollaxis(x_train, 1, 4)
        x_test = x_test.reshape((test_dim, 3, 32, 32))
        x_test = np.rollaxis(x_test, 1, 4)

x_train, x_test を (データ数, 3, 32, 32) の形にreshapeします。

そして、このreshapeしたものの1次元目の「3」を
末尾(4次元目)に持って行きます。(np.rollaxis)

最初が0次元目であることに注意してください。

nomalize

0~255 の値なので、これを 0~1に収めるオプションです。

NumPyのastype() メソッドで型変換をしてfloat型に持っていき
255で割るだけです。


    # --- nomalize ---
    if nomalize:
        x_train = x_train.astype('float32')
        x_train /= 255.0
        x_test = x_test.astype('float32')
        x_test /= 255.0

one_hot化

正解ラベルをone_hot表現に変換します。


    # --- one_hot ---
    if one_hot:
        class_num = 10
        y_train = np.identity(class_num)[y_train]
        y_test = np.identity(class_num)[y_test]


    return (x_train, y_train), (x_test, y_test)

ついでにラベルに対するクラスの名前も拾ってきた。


def load_labelnames(target_dir=unziped_file_name):
    if not os.path.exists(target_dir):
        init_cifar10()

    batches_meta = unpickle(os.path.join(target_dir, "batches.meta"))
    label_names = batches_meta['label_names']

    return label_names

公式から拾ってきたzipファイル内には
batches.meta というpickleファイルがあるのですが、

この中の dict['label_names'] のなかに
クラス数字が配列のindex番号に対応したリストが入っています。

それを拾ってきます。
これで予測した際にそのラベルにより、分類class名のマッピングができそうですね。

まとめ

やはり、遠回りではありました。
すでにあるメソッドを使えば5秒で終わることに半日程度かけてしまったかなと思います。

ですが、ゼロから作ってみるとまた別のところで学びや気づきがあったり
自分で作ってみることですでにあるものに対して理解が深まったりするかと思います。

それに、自分で作れると、「それ、作ったことあります」とかいう自信に繋がったり、
似たようなことを頼まれたときに「できると思います」という言葉に説得力が生まれたりもしますよね。

これからも、何か気になったものがあれば使ってみたり、作ってみたりを続けて
基礎スキルアップに繋げられることができたらいいなと思います。

以上で今回は終わります。
ではまたいつか。

4
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
6