LoginSignup
20
22

More than 5 years have passed since last update.

Chainerのtrainerを使ってCIFAR-10の分類に挑戦したかった

Last updated at Posted at 2016-08-06

はじめに

先日こちらの記事からChainerがすごく簡潔に書けるようになったと知り、前から試してみたかったCIFAR-10の画像分類に挑戦してみました
...と、書きたかったのですが、貧弱CPU環境しか持っていないので、実行確認までできていません
1日中動かして2epochくらい進んだのでおそらく正しい...はず^^;
実装に関してはこちらのブログを参考にさせていただきました

実装

CIFAR-10の画像を読み込んでくる

ここからCIFAR-10のデータをダウンロードして読み込みます。pickleのようなので下記の関数で読み込みます。

def unpickle(file):
    fp = open(file, 'rb')
    if sys.version_info.major == 2:
        data = pickle.load(fp)
    elif sys.version_info.major == 3:
        data = pickle.load(fp, encoding='latin-1')                                                                                    
    fp.close()

    return data

ニューラルネットワーク

先ほど紹介したブログを参考にさせていただきました。
この辺に関しては未だにどうやって設計していいかよくわかってないです...

class Cifar10Model(chainer.Chain):

    def __init__(self):
        super(Cifar10Model,self).__init__(
                conv1 = F.Convolution2D(3, 32, 3, pad=1),
                conv2 = F.Convolution2D(32, 32, 3, pad=1),
                conv3 = F.Convolution2D(32, 32, 3, pad=1),
                conv4 = F.Convolution2D(32, 32, 3, pad=1),
                conv5 = F.Convolution2D(32, 32, 3, pad=1),
                conv6 = F.Convolution2D(32, 32, 3, pad=1),
                l1 = L.Linear(512, 512),
                l2 = L.Linear(512,10))

    def __call__(self, x, train=True):
        h = F.relu(self.conv1(x))
        h = F.max_pooling_2d(F.relu(self.conv2(h)), 2)
        h = F.relu(self.conv3(h))
        h = F.max_pooling_2d(F.relu(self.conv4(h)), 2)
        h = F.relu(self.conv5(h))
        h = F.max_pooling_2d(F.relu(self.conv6(h)), 2)
        h = F.dropout(F.relu(self.l1(h)), train=train)
        return self.l2(h)

データ読み込み

ここで少しつまりました。
Chainerの新機能trainerを使う際に学習したいデータをiteratorに渡すのですが、チュートリアルなどでは

train_iter = chainer.iterators.SerialIterator(train, 100)
test_iter = chainer.iterators.SerialIterator(test, 100,repeat=False, shuffle=False)

のようになっており、ラベルなどはどうやって渡すのかわかりませんでした。
いろいろ調べた結果、Tuple_datasetを使えばよいということがわかりました。

train = chainer.tuple_dataset.TupleDataset(train_data, train_label)

のようにするといいようです。

下記、読み込み部分全コードとなってます。

x_train = None
y_train = []
for i in range(1,6):
    data_dic = unpickle("cifar-10-batches-py/data_batch_{}".format(i))
    if i == 1:
        x_train = data_dic['data']
    else:
        x_train = np.vstack((x_train, data_dic['data']))
    y_train += data_dic['labels']

test_data_dic = unpickle("cifar-10-batches-py/test_batch")
x_test = test_data_dic['data']
x_test = x_test.reshape(len(x_test),3,32,32)
y_test = np.array(test_data_dic['labels'])
x_train = x_train.reshape((len(x_train),3, 32, 32))
y_train = np.array(y_train)
x_train = x_train.astype(np.float32)
x_test = x_test.astype(np.float32)
x_train /= 255
x_test/=255                                                                                                                     
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)

train = tuple_dataset.TupleDataset(x_train, y_train)
test = tuple_dataset.TupleDataset(x_test, y_test)

学習部分

先ほど定義したニューラルネットワークで学習をしています。コードはチュートリアルのMNISTを少しいじっただけです。
めちゃくちゃ簡潔に書けて驚きです。


model = L.Classifier(Cifar10Model())
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

train_iter = chainer.iterators.SerialIterator(train, 100)
test_iter = chainer.iterators.SerialIterator(test, 100,repeat=False, shuffle=False)

updater = training.StandardUpdater(train_iter, optimizer, device=-1)
trainer = training.Trainer(updater, (40, 'epoch'), out="logs")
trainer.extend(extensions.Evaluator(test_iter, model, device=-1))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport( ['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())                                                                                          
trainer.run()

結果

実行するとプログレスバーが出てきて、学習の進み具合を教えてくれます。

Screenshot from 2016-08-06 18:04:00.png

Estimated time to finish: 6 days

あきらめました
(2016.08.15 修正)
がんばりました

出力されたlogをdictionaryに読み込みmatplotlibでグラフ化しました

figure_1.png

figure_2.png

おわりに

きちんと結果が出たのは確認できませんでしたが、trainerの使い方の勉強にはなりました
やはりDeep Learningの勉強をするにはGPUは必須ですね

20
22
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
20
22