[Python]CIFAR-10, CIFAR-100のデータを読み込む方法

  • 22
    いいね
  • 0
    コメント
この記事は最終更新日から1年以上が経過しています。

CIFAR-10, CIFAR-100はラベル付されたサイズが32x32のカラー画像8000万枚のデータセットです。

データ提供先よりデータをダウンロードする。
https://www.cs.toronto.edu/~kriz/cifar.html

「CIFAR-10 python version」、「CIFAR-100 python version」からダウンロードして、適当な場所に解凍する

Screenshot from 2015-12-05 01:56:04.png

Screenshot from 2015-12-05 01:56:39.png

input_cifar.py
import cPickle
import numpy as np
import os

def unpickle(file):
    fo = open(file, 'rb')
    dict = cPickle.load(fo)
    fo.close()
    return dict

def conv_data2image(data):
    return np.rollaxis(data.reshape((3,32,32)),0,3)

def get_cifar10(folder):
    tr_data = np.empty((0,32*32*3))
    tr_labels = np.empty(1)
    '''
    32x32x3
    '''
    for i in range(1,6):
        fname = os.path.join(folder, "%s%d" % ("data_batch_", i))
        data_dict = unpickle(fname)
        if i == 1:
            tr_data = data_dict['data']
            tr_labels = data_dict['labels']
        else:
            tr_data = np.vstack((tr_data, data_dict['data']))
            tr_labels = np.hstack((tr_labels, data_dict['labels']))

    data_dict = unpickle(os.path.join(folder, 'test_batch'))
    te_data = data_dict['data']
    te_labels = np.array(data_dict['labels'])

    bm = unpickle(os.path.join(folder, 'batches.meta'))
    label_names = bm['label_names']
    return tr_data, tr_labels, te_data, te_labels, label_names

def get_cifar100(folder):
    train_fname = os.path.join(folder,'train')
    test_fname  = os.path.join(folder,'test')
    data_dict = unpickle(train_fname)
    train_data = data_dict['data']
    train_fine_labels = data_dict['fine_labels']
    train_coarse_labels = data_dict['coarse_labels']

    data_dict = unpickle(test_fname)
    test_data = data_dict['data']
    test_fine_labels = data_dict['fine_labels']
    test_coarse_labels = data_dict['coarse_labels']

    bm = unpickle(os.path.join(folder, 'meta'))
    clabel_names = bm['coarse_label_names']
    flabel_names = bm['fine_label_names']

    return train_data, np.array(train_coarse_labels), np.array(train_fine_labels), test_data, np.array(test_coarse_labels), np.array(test_fine_labels), clabel_names, flabel_names

if __name__ == '__main__':
    datapath = "./data/cifar-10-batches-py"
    datapath2 = "./data/cifar-100-python"

    tr_data10, tr_labels10, te_data10, te_labels10, label_names10 = get_cifar10(datapath)
    tr_data100, tr_clabels100, tr_flabels100, te_data100, te_clabels100, te_flabels100, clabel_names100, flabel_names100 = get_cifar100(datapath2)

上のコードをinput_cifar.pyに貼り付けて、input_cifar.pyがあるフォルダーにdataフォルダーを作成して、Datasetをそこに置く
input_cifar.pyを実行すると下記のようになる。

CIFAR-10

ipython
In [1]: %run input_cifar.py
In [2]: tr_data10.shape
Out[2]: (50000, 3072)
In [3]: tr_labels10.shape
Out[3]: (50000,)
In [4]: te_data10.shape
Out[4]: (10000, 3072)
In [5]: te_labels10.shape
Out[5]: (10000,)
In [6]: label_names10
Out[6]: 
['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

CIFAR-10,CIFAR-100ではデータが、学習データ50000個とテストデータ10000個に分けられている。
学習データの0番目を取り出すときは下記のようにする。

ipython
In [7]: img0 = tr_data10[0]

画像はカラー画像でサイズが32x32になっている。
データはPlane形式でR,G,Bの順番に格納されている。
先頭から1024までがR Planeで、そこからまた1024までがG Planeで、そこから最後までがB Planeになっている。

画像を表示させるときは、データが1列になっているので32x32x3に並べ替えないといけない。
scikit-imageのimshowを使う場合は、R,G,B,R,G,Bの順番で並べればいいので下記のようにする。

ipython
In [8]: img0 = img0.reshape((3,32,32))
In [9]: img0.shape
Out[9]: (3, 32, 32)
In [10]: import numpy as np
In [11]: img1 = np.rollaxis(img0, 0, 3)
In [12]: img1.shape
Out[12]: (32, 32, 3)
In [13]: from skimage import io
In [14]: io.imshow(img1)
In [15]: io.show()

figure_1.png

0番目はラベルを見るとfrogということだが、32x32に縮小されているので見てもよくわからない。

CIFAR-100

CIFAR-100では画像が100個のclassのカテゴリに分けられていて、その100 classがさらに20個のsuperclassにグルーピングされている。
superclassとclassは下記の通り。
データの格納方法はCIFAR-10と同じ。

Superclass Classes
aquatic mammals beaver, dolphin, otter, seal, whale
fish aquarium fish, flatfish, ray, shark, trout
flowers orchids, poppies, roses, sunflowers, tulips
food containers bottles, bowls, cans, cups, plates
fruit and vegetables apples, mushrooms, oranges, pears, sweet peppers
household electrical devices clock, computer keyboard, lamp, telephone, television
household furniture bed, chair, couch, table, wardrobe
insects bee, beetle, butterfly, caterpillar, cockroach
large carnivores bear, leopard, lion, tiger, wolf
large man-made outdoor things bridge, castle, house, road, skyscraper
large natural outdoor scenes cloud, forest, mountain, plain, sea
large omnivores and herbivores camel, cattle, chimpanzee, elephant, kangaroo
medium-sized mammals fox, porcupine, possum, raccoon, skunk
non-insect invertebrates crab, lobster, snail, spider, worm
people baby, boy, girl, man, woman
reptiles crocodile, dinosaur, lizard, snake, turtle
small mammals hamster, mouse, rabbit, shrew, squirrel
trees maple, oak, palm, pine, willow
vehicles 1 bicycle, bus, motorcycle, pickup truck, train
vehicles 2 lawn-mower, rocket, streetcar, tank, tractor

Superclassのラベル名は、clabel_names100に、classのラベル名はflabel_names100に入っている。

ipython
In [6]: len(clabel_names100)
Out[6]: 20
In [7]: len(flabel_names100)
Out[7]: 100
In [8]: clabel_names100
Out[8]: 
['aquatic_mammals',
 'fish',
 'flowers',
 'food_containers',
 'fruit_and_vegetables',
 'household_electrical_devices',


 'reptiles',
 'small_mammals',
 'trees',
 'vehicles_1',
 'vehicles_2']
In [9]: flabel_names100
Out[9]: 
['apple',
 'aquarium_fish',
 'baby',
 'bear',
 'beaver',
 'bed',
 'bee',
 'beetle',
 'bicycle',
 'bottle',


 'willow_tree',
 'wolf',
 'woman',
 'worm']
In [10]: