LoginSignup
9
9

More than 5 years have passed since last update.

今更ながらKerasでCIFAR-10の画像認識をやってみた - 学習編 -

Last updated at Posted at 2017-08-29

はじめに

  • Tensorflowでは一度CNNの実装を行いましたが、今回はどうやら一番評価の高そうなKerasというフレームワークで実装してみました。
  • 使ってみた感覚だと、Tensorflowに比べて便利なライブラリが色々内包されていて、簡単に実装できた感じでした。
  • 理由としては、どうやら細かな計算処理がブラックボックス化されているため、見た目としては何も考えなくとも理解できる感じですが、細かなチューニングなどは行えなさそうな感じでした。
  • 上記を踏まえてとりあえず動かしてみたい初心者向き。
  • OpenCV3がやっと使えるようになったのにKerasが画像操作もやってくれているので出番なし。
  • 1エポックが7分ほどかかかるので、最初は10エポックくらいで試したほうがよいかもです。100エポックで約10〜12時間程度だと思われます。

参考

CIFAR-10とは

  • 32x32ピクセルのカラー画像のデータセット。
  • airplane,automobile,bird,cat,deer,dog,frog,horse,ship,truckの10種類で訓練用データ5万枚、テスト用データ1万枚で構成されています。
  • ライブラリとして簡単に呼び出せるみたいなので、お試しがてら使ってみました。

環境

  • 主なものだけ記載します。
  • 足りない部分はpipやHomebrewからインストールしてください。
# OS/ソフトウェア/ライブラリ バージョン
1 Mac OS X EI Capitan
2 Python 3.6系
3 NumPy 1.13系
4 Keras 1.2系

実装

train.py
#!/usr/local/bin/python3
#!-*- coding: utf-8 -*-

import os
import numpy as np
from scipy.misc import toimage
import matplotlib.pyplot as plt

from keras.datasets import cifar10
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D
from keras.utils.visualize_util import plot

def cnn_model(X_train, nb_classes):

    model = Sequential()

    model.add(Convolution2D(32, 3, 3, border_mode='same', input_shape=X_train.shape[1:]))
    model.add(Activation('relu'))
    model.add(Convolution2D(32, 3, 3))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    model.add(Convolution2D(64, 3, 3, border_mode='same'))
    model.add(Activation('relu'))
    model.add(Convolution2D(64, 3, 3))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    model.add(Flatten())
    model.add(Dense(512))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))

    model.add(Dense(nb_classes))
    model.add(Activation('softmax'))

    model.compile(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy']
    )

    return model

def plot_cifar10(X, y, result_dir):
    plt.figure()

    # 画像を描画
    nclasses = 10
    pos = 1

    for targetClass in range(nclasses):
        targetIdx = []

        # クラスclassIDの画像のインデックスリストを取得
        for i in range(len(y)):
            if y[i][0] == targetClass:
                targetIdx.append(i)

        # 各クラスからランダムに選んだ最初の10個の画像を描画
        np.random.shuffle(targetIdx)

        for idx in targetIdx[:10]:
            img = toimage(X[idx])
            plt.subplot(10, 10, pos)
            plt.imshow(img)
            plt.axis('off')
            pos += 1

    plt.savefig(os.path.join(result_dir, 'plot.png'))

def save_history(history, result_file):
    loss = history.history['loss']
    acc = history.history['acc']
    val_loss = history.history['val_loss']
    val_acc = history.history['val_acc']
    nb_epoch = len(acc)

    with open(result_file, "w") as fp:
        fp.write("epoch\tloss\tacc\tval_loss\tval_acc\n")
        for i in range(nb_epoch):
            fp.write("%d\t%f\t%f\t%f\t%f\n" % (i, loss[i], acc[i], val_loss[i], val_acc[i]))

if __name__ == '__main__':

    # データ学習訓練の試行回数
    nb_epoch = 100

    # 学習結果の保存ディレクトリ
    result_dir = '{保存ディレクトリパス}'

    # 1回の学習で何枚の画像を使うか
    batch_size = 128

    # 識別ラベル数
    nb_classes = 10

    # 入力画像の次元
    img_rows, img_cols = 32, 32

    # チャネル数(RGBなので3)
    img_channels = 3

    # CIFAR-10データをロード
    # (nb_samples, nb_rows, nb_cols, nb_channel) = tf
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()

    # ランダムに画像をプロット
    plot_cifar10(X_train, y_train, result_dir)

    # 画素値を0-1に変換
    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    X_train /= 255.0
    X_test /= 255.0

    # クラスラベル(0-9)をone-hotエンコーディング形式に変換
    Y_train = np_utils.to_categorical(y_train, nb_classes)
    Y_test = np_utils.to_categorical(y_test, nb_classes)

    # モデル
    model = cnn_model(X_train, nb_classes)

    # モデルのサマリを表示
    model.summary()
    plot(model, show_shapes=True, to_file=os.path.join(result_dir, 'model.png'))

    # 学習
    history = model.fit(
        X_train,
        Y_train,
        batch_size=batch_size,
        nb_epoch=nb_epoch,
        verbose=1,
        validation_data=(X_test, Y_test),
        shuffle=True
    )

    # 学習したモデルと重みと履歴の保存
    model_json = model.to_json()

    with open(os.path.join(result_dir, 'model.json'), 'w') as json_file:
        json_file.write(model_json)

    model.save_weights(os.path.join(result_dir, 'model.h5'))
    save_history(history, os.path.join(result_dir, 'history.txt'))

    # モデルの評価
    loss, acc = model.evaluate(X_test, Y_test, verbose=0)

    print('Test loss:', loss)
    print('Test acc:', acc)

学習

  • 実にわかりやすいUI。
Using TensorFlow backend.
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
convolution2d_1 (Convolution2D)  (None, 32, 32, 32)    896         convolution2d_input_1[0][0]      
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 32, 32, 32)    0           convolution2d_1[0][0]            
____________________________________________________________________________________________________
convolution2d_2 (Convolution2D)  (None, 30, 30, 32)    9248        activation_1[0][0]               
____________________________________________________________________________________________________
activation_2 (Activation)        (None, 30, 30, 32)    0           convolution2d_2[0][0]            
____________________________________________________________________________________________________
maxpooling2d_1 (MaxPooling2D)    (None, 15, 15, 32)    0           activation_2[0][0]               
____________________________________________________________________________________________________
dropout_1 (Dropout)              (None, 15, 15, 32)    0           maxpooling2d_1[0][0]             
____________________________________________________________________________________________________
convolution2d_3 (Convolution2D)  (None, 15, 15, 64)    18496       dropout_1[0][0]                  
____________________________________________________________________________________________________
activation_3 (Activation)        (None, 15, 15, 64)    0           convolution2d_3[0][0]            
____________________________________________________________________________________________________
convolution2d_4 (Convolution2D)  (None, 13, 13, 64)    36928       activation_3[0][0]               
____________________________________________________________________________________________________
activation_4 (Activation)        (None, 13, 13, 64)    0           convolution2d_4[0][0]            
____________________________________________________________________________________________________
maxpooling2d_2 (MaxPooling2D)    (None, 6, 6, 64)      0           activation_4[0][0]               
____________________________________________________________________________________________________
dropout_2 (Dropout)              (None, 6, 6, 64)      0           maxpooling2d_2[0][0]             
____________________________________________________________________________________________________
flatten_1 (Flatten)              (None, 2304)          0           dropout_2[0][0]                  
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 512)           1180160     flatten_1[0][0]                  
____________________________________________________________________________________________________
activation_5 (Activation)        (None, 512)           0           dense_1[0][0]                    
____________________________________________________________________________________________________
dropout_3 (Dropout)              (None, 512)           0           activation_5[0][0]               
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 10)            5130        dropout_3[0][0]                  
____________________________________________________________________________________________________
activation_6 (Activation)        (None, 10)            0           dense_2[0][0]                    
====================================================================================================
Total params: 1,250,858
Trainable params: 1,250,858
Non-trainable params: 0
____________________________________________________________________________________________________
Train on 50000 samples, validate on 10000 samples
Epoch 1/100
2017-08-29 11:55:15.047457: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
2017-08-29 11:55:15.047489: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2017-08-29 11:55:15.047498: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
2017-08-29 11:55:15.047505: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
50000/50000 [==============================] - 400s - loss: 1.8624 - acc: 0.3123 - val_loss: 1.5157 - val_acc: 0.4537
Epoch 2/100
50000/50000 [==============================] - 455s - loss: 1.4213 - acc: 0.4825 - val_loss: 1.2090 - val_acc: 0.5645
Epoch 3/100
50000/50000 [==============================] - 384s - loss: 1.2453 - acc: 0.5549 - val_loss: 1.1120 - val_acc: 0.6004
Epoch 4/100
50000/50000 [==============================] - 414s - loss: 1.1602 - acc: 0.5869 - val_loss: 1.0376 - val_acc: 0.6304
Epoch 5/100
40960/50000 [=======================>......] - ETA: 64s - loss: 1.0852 - acc: 0.6156 

学習結果

  • 最終的には0.795500となかなかの数値
epoch   loss    acc val_loss    val_acc
0   1.862381    0.312260    1.515693    0.453700
1   1.421265    0.482520    1.209045    0.564500
2   1.245348    0.554880    1.111982    0.600400
3   1.160202    0.586880    1.037605    0.630400
4   1.082496    0.616840    1.002439    0.646900
5   1.014772    0.638600    0.921932    0.676100
6   0.968752    0.657040    0.890589    0.690000
7   0.933635    0.667640    0.854735    0.701200
8   0.893519    0.686920    0.841952    0.705100
9   0.866928    0.692880    0.832813    0.706700
10  0.835548    0.706200    0.808762    0.713600
11  0.811907    0.714380    0.790619    0.721300
12  0.787855    0.723800    0.769006    0.730800
13  0.769309    0.728920    0.757458    0.733200
14  0.749378    0.735600    0.731948    0.745700
15  0.729672    0.741200    0.707380    0.753900
16  0.711307    0.747860    0.720168    0.749600
17  0.694116    0.753920    0.690921    0.760300
18  0.685302    0.756300    0.688788    0.762900
19  0.670361    0.763320    0.699074    0.758100
20  0.655677    0.765700    0.683461    0.762800
21  0.650073    0.772340    0.675608    0.764000
22  0.637535    0.772880    0.705318    0.756400
23  0.627683    0.778300    0.659172    0.769500
24  0.608234    0.784560    0.649584    0.777400
25  0.605251    0.785540    0.670669    0.769400
26  0.604372    0.784440    0.657858    0.771600
27  0.591194    0.790200    0.648238    0.777000
28  0.583359    0.792300    0.644217    0.779500
29  0.579536    0.794680    0.654746    0.779200
30  0.577020    0.796780    0.650456    0.776600
31  0.558436    0.801840    0.654724    0.771200
32  0.551728    0.802580    0.660091    0.776500
33  0.551879    0.806080    0.691178    0.762100
34  0.551185    0.804320    0.643525    0.776900
35  0.537961    0.808680    0.648304    0.776100
36  0.538601    0.809580    0.663515    0.776100
37  0.526395    0.812180    0.667506    0.773200
38  0.520259    0.814020    0.637330    0.780900
39  0.520766    0.813460    0.626052    0.782500
40  0.513462    0.818240    0.637447    0.782200
41  0.508732    0.819540    0.628536    0.784700
42  0.512045    0.816880    0.648108    0.780600
43  0.502967    0.819200    0.681984    0.774500
44  0.502752    0.819600    0.626435    0.788800
45  0.506435    0.818640    0.641397    0.782800
46  0.492950    0.822220    0.643260    0.783700
47  0.490653    0.824700    0.652792    0.782700
48  0.479054    0.828640    0.641824    0.783700
49  0.483628    0.825200    0.636701    0.789600
50  0.483054    0.828940    0.635811    0.786600
51  0.472693    0.831800    0.628785    0.788600
52  0.473315    0.830320    0.630551    0.791700
53  0.466273    0.833520    0.636630    0.787100
54  0.466997    0.834320    0.627516    0.792000
55  0.467793    0.831500    0.641428    0.783800
56  0.456283    0.837320    0.635762    0.789500
57  0.452593    0.837640    0.636226    0.785800
58  0.459387    0.834740    0.628133    0.789200
59  0.450392    0.838600    0.651949    0.786700
60  0.446786    0.840840    0.628699    0.790300
61  0.443291    0.838820    0.635795    0.790300
62  0.445090    0.841860    0.630603    0.790400
63  0.438771    0.842060    0.618292    0.793700
64  0.444951    0.840980    0.631328    0.791900
65  0.443358    0.841860    0.631681    0.784200
66  0.436495    0.842280    0.640644    0.787600
67  0.434661    0.844800    0.625743    0.789400
68  0.431563    0.843140    0.633843    0.790500
69  0.430192    0.847000    0.642871    0.788800
70  0.421411    0.847940    0.624006    0.790100
71  0.423439    0.848720    0.650200    0.786900
72  0.420324    0.848240    0.654864    0.789300
73  0.421919    0.848480    0.627082    0.792700
74  0.410431    0.851180    0.642263    0.790100
75  0.419571    0.850860    0.636891    0.789100
76  0.413792    0.852300    0.653029    0.786400
77  0.415705    0.852520    0.644457    0.787400
78  0.423684    0.850840    0.631995    0.791600
79  0.410283    0.852720    0.652608    0.788600
80  0.411928    0.851480    0.634341    0.793000
81  0.405703    0.855640    0.630345    0.794300
82  0.408496    0.853320    0.652739    0.789500
83  0.402927    0.855320    0.642568    0.793900
84  0.404812    0.854160    0.637601    0.795100
85  0.400238    0.857520    0.633781    0.793700
86  0.403761    0.853880    0.625709    0.794700
87  0.404491    0.856420    0.635942    0.796500
88  0.392304    0.860800    0.636456    0.797500
89  0.398910    0.857180    0.655635    0.787900
90  0.392418    0.858680    0.634033    0.793500
91  0.390591    0.861520    0.627206    0.794700
92  0.385502    0.860740    0.657583    0.787700
93  0.391550    0.860260    0.664002    0.787900
94  0.388638    0.859500    0.648921    0.792000
95  0.385543    0.863360    0.644680    0.793800
96  0.386466    0.862720    0.648885    0.791200
97  0.384559    0.862300    0.651030    0.794400
98  0.384592    0.861200    0.629053    0.794900
99  0.388739    0.860320    0.634723    0.795500
9
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
9