LoginSignup
0
0

More than 5 years have passed since last update.

Keras+Tensorflow で猫の毛色分類(サビ猫は雑誌の表紙を飾ったのか?)3

Last updated at Posted at 2018-09-04

もくじ

  1. 学習用画像を集める
  2. 猫のご尊顔を検出
  3. 学習(ココ)
  4. 猫雑誌の表紙画像を集めて判定

学習

上記の記事に倣い、集めた画像をカテゴリ付きでnumpyの配列にして保存しておく。

data.py
import glob
import math
import os
import random
import sys

import numpy as np
from PIL import Image, ImageFile

def create(input_dir) :
    """Build a shuffled image dataset from per-category folders and save it.

    Reads every ``*.jpg`` under each sub-directory of *input_dir* (one
    sub-directory per category), resizes each image to 160x160 RGB, shuffles,
    and saves an 80/20 train/test split to ``cat_color.npy`` as the tuple
    ``(x_train, x_test, y_train, y_test)``.

    Args:
        input_dir: Parent directory whose sub-directories name the categories.
    """
    # Tolerate partially-downloaded/truncated JPEGs instead of raising.
    # NOTE: the original set Image.LOAD_TRUNCATED_IMAGES, which Pillow
    # ignores — the effective flag lives on the ImageFile module.
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    # Each sub-directory (minus macOS junk files) is one category label.
    categories = [d for d in os.listdir(input_dir) if d != '.DS_Store']
    image_size = 160
    train_data = []
    for idx, category in enumerate(categories):
        print("---", category)
        image_dir = os.path.join(input_dir, category)
        for f in glob.glob(os.path.join(image_dir, "*.jpg")):
            try:
                img = Image.open(f)
                img = img.convert("RGB")
                img = img.resize((image_size, image_size))
                train_data.append([np.asarray(img), idx])
            except OSError:
                # Skip only the unreadable file; the original bare ``except``
                # silently aborted the remainder of the whole category.
                print("SKIP : " + f)

    random.shuffle(train_data)
    X = [data[0] for data in train_data]
    Y = [data[1] for data in train_data]
    print(len(X))
    # First 80% of the shuffled data -> training set, last 20% -> test set.
    split_idx = math.floor(len(X) * 0.8)
    print(split_idx)
    xy = (np.array(X[0:split_idx]), np.array(X[split_idx:]),
          np.array(Y[0:split_idx]), np.array(Y[split_idx:]))
    # The tuple is stored as a pickled object array; load it back with
    # np.load(..., allow_pickle=True) on NumPy >= 1.16.3.
    np.save("cat_color", xy)

if __name__ == "__main__":
  # argv[1]: parent directory holding one sub-directory of images per category
  create(sys.argv[1])

学習部分は、ImageDataGeneratorで水増しなど採用した。
顔ギリで切ってるので、あまりいろいろ動かさず(0.1)
vertical_flip を付けたら感度ダダ下がり、顔には向いてない

training.py
import sys
import os
import numpy as np
from keras.models import Sequential, model_from_json
from keras.callbacks import ModelCheckpoint
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras.utils import np_utils
from keras.preprocessing.image import ImageDataGenerator

def train(input_dir,npy_file) :
    """Train a small CNN classifier on the dataset saved by data.py.

    Builds a 4-conv-layer network, fits it with on-the-fly augmentation,
    checkpoints the best model to ``snapshot/cat-bestmodel.hdf5``, then
    reloads that checkpoint and prints its loss/accuracy on the test set.

    Args:
        input_dir: Parent directory of per-category image folders; the number
            of entries (minus .DS_Store) gives the class count.
        npy_file: Path to ``cat_color.npy`` produced by data.py, holding
            ``(x_train, x_test, y_train, y_test)``.
    """
    nb_classes = len([name for name in os.listdir(input_dir) if name != ".DS_Store"])
    # allow_pickle=True is required on NumPy >= 1.16.3: the file stores a
    # pickled tuple of arrays, which np.load refuses to read by default.
    x_train, x_test, y_train, y_test = np.load(npy_file, allow_pickle=True)
    # Scale pixel values to [0, 1].
    x_train = x_train.astype("float") / 255.
    x_test = x_test.astype("float") / 255.
    y_train = np_utils.to_categorical(y_train, nb_classes)
    y_test = np_utils.to_categorical(y_test, nb_classes)
    # Augment training images with small rotations/shifts/zooms and
    # horizontal flips (kept small — the crops are tight around the face).
    train_datagen = ImageDataGenerator(
        rotation_range=10,
        width_shift_range=0.1,
        height_shift_range=0.1,
        shear_range=0.1,
        zoom_range=0.1,
        horizontal_flip=True,
        fill_mode='nearest')

    model = Sequential()
    model.add(Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]))
    model.add(Activation('relu'))

    model.add(Conv2D(32, (3, 3), padding='same'))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    model.add(Conv2D(64, (3, 3), padding='same'))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(64, (3, 3), padding='same'))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(512))

    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

    # Fit and keep only the checkpoint with the best validation score.
    checkpoint_cb = ModelCheckpoint("snapshot/cat-bestmodel.hdf5", save_best_only=True)
    result = model.fit_generator(train_datagen.flow(x_train,  y_train,
                    batch_size=64, save_to_dir=None),
                    epochs=50,
                    steps_per_epoch=None,
                    validation_data=(x_test, y_test),
                    validation_steps=None,
                    workers=4,
                    callbacks=[checkpoint_cb])

    # Evaluate the best checkpoint (not the last epoch's weights).
    model.load_weights("snapshot/cat-bestmodel.hdf5")
    score = model.evaluate(x_test, y_test)
    print('loss=', score[0])
    print('accuracy=', score[1])

if __name__ == "__main__":
  # argv[1]: category image root directory; argv[2]: cat_color.npy from data.py
  train(sys.argv[1], sys.argv[2])

loss= 0.47945548266899296
accuracy= 0.885569990659696
まあまあだね!(デリアンの口調で。2出ないかな)
 テストデータの数がちょうど3,700になったので、batch_size=50にして、1エポック74ステップ、Generator てどうやって動いてるんだろ??
 フィルタの数やら、プーリング、ドロップアウトの回数やら、関数何を使おうか? とかは、全然比較検討してないです...

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0