Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
1
Help us understand the problem. What is going on with this article?
@asakbiz

【ディープラーニング】Kerasを使ってサッカー選手の顔画像を分類してみる

More than 1 year has passed since last update.

はじめに

今更ではあるけども、ディープラーニングを使って簡単な顔画像の分類を行ってみた。前回4人のサッカー選手で顔画像の分類を行っており、今回は発展させて、10人で分類してみる。

環境

  • python3
  • ubuntu 16.04
  • GeForce GTX 1070
  • CUDA 8.0
  • cudnn 6.0
  • keras 2.0.8

対象画像

世界で活躍する以下のサッカー選手の画像を対象とした。
fbplayer.png

※上段左からmessi,james,neymar,suarez,ronaldo,
※下段左からroony,nodric,lewandowski,bruyne,persie

事前にbingAPIで画像を収集し、openCVを使って顔部分の切り抜きを行っておく。切り取った画像はdatasetというディレクトリ配下にtrain(学習用)50枚、test(推論用)10枚に分けて格納しておく。

dataset
├── test
│   ├── 00_messi
│   ├── 01_james
│   ├── 02_neymar
│   ├── 03_suarez
│   ├── 04_ronaldo
│   ├── 05_roony
│   ├── 06_modric
│   ├── 07_lewandowski
│   ├── 08_bruyne
│   ├── 09_persie
├── train
    ├── 00_messi
    ├── 01_james
    ├── 02_neymar
    ├── 03_suarez
    ├── 04_ronaldo
    ├── 05_roony
    ├── 06_modric
    ├── 07_lewandowski
    ├── 08_bruyne
    ├── 09_persie

各ディレクトリの中身は以下のイメージ。
faces.png

Data

上記で準備したサッカー選手の画像を以下のように読み込む。なお今回はデータオーグメンテーションは行わない。

class DataGen():

    def __init__(self):
        pass

    def generate(self):
        train_label_list = []
        train_image_list = []
        test_label_list =[]
        test_image_list =[]

        for traindir_name in  sorted(os.listdir(train_dirpath)):
            # 対象以外は除外
            try:
                label=label_dict[traindir_name]
            except:
                print("label not defined:"+traindir_name)
                continue

            # 対象以外は除外
            trainfile_path = train_dirpath + traindir_name
            for file_name in sorted(os.listdir(trainfile_path)):
                train_label_list.append(label)
                file_path = trainfile_path + "/" + file_name
                image = np.array(Image.open(file_path).resize((imgxsize, imgysize)))
                print("file_path:",file_path)

                # 0〜1の範囲に変換
                train_image_list.append(image / 255.)
        # 推論
        for testfile_name in  sorted(os.listdir(test_dirpath)):

            testfile_path = test_dirpath + testfile_name

            try:
                label=label_dict[testfile_name]
            except:
                print("label not defined:"+testfile_name)
                continue

            for file_name in  sorted(os.listdir(testfile_path)):
                test_label_list.append(label)
                file_path = testfile_path + "/" + file_name
                image = np.array(Image.open(file_path).resize((imgxsize, imgysize)))
                test_image_list.append(image)
                print("file_path:",file_path)

        return train_image_list,train_label_list,test_image_list,test_label_list

model

CNN(Convolutional Neural Network)を使用してConvNetを作成する。
convolutionでは、画像処理でよく利用される手法で、カーネル(またはフィルター)を画像に適用することで、その画像の特徴量を抽出する。例えば犬の画像であれば、カーネルを適用することで画像から'犬らしさ'を数字列(ベクトル、テンソル)で取り出す。数字列を取り出したら、全結合層に値を入力し、分類を行う。

class Model():

    def __init__(self):
        self.loss="categorical_crossentropy"
        self.lr=0.001
        pass

    def define(self,X,Y):
        model = Sequential()

        model.add(Conv2D(32, (3, 3), padding='same',input_shape=X.shape[1:]))
        model.add(Activation('relu'))
        model.add(Conv2D(32, (3, 3)))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))

        model.add(Conv2D(64, (3, 3), padding='same'))
        model.add(Activation('relu'))
        model.add(Conv2D(64, (3, 3)))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))

        model.add(Flatten())
        model.add(Dense(512))
        model.add(Activation('relu'))
        model.add(Dropout(0.5))
        model.add(Dense(len(label_dict)))
        model.add(Activation('softmax'))

        model.compile(loss=self.loss, optimizer=Adam(lr=self.lr), metrics=["accuracy"])
        model.summary()

        return model


train&test

以下を実行することで、学習および推論をおこなう。

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.utils.np_utils import to_categorical
from keras.optimizers import Adam
import numpy as np
from PIL import Image
import os

#定数
imgxsize=25
imgysize=25
epoch=200
validation_split=0.1
batch_size=50
#label_dict={"00_messi":0,"01_james":1,"02_neymar":2,"03_suarez":3}
#label_mean_dict={0:"messi",1:"james",2:"neymar",3:"suarez"}
label_dict={"00_messi":0,"01_james":1,"02_neymar":2,"03_suarez":3,
            "04_ronaldo":4,"05_roony":5,"06_modric":6,
            "07_lewandowski":7,"08_bruyne":8,"09_persie":9}
label_mean_dict={0:"messi",1:"james",2:"neymar",3:"suarez",
                4:"ronaldo",5:"roony",6:"modric",
                7:"lewandowski",8:"bruyne",9:"persie"}
train_dirpath="./dataset/train/"
test_dirpath="./dataset/test/"

if __name__=="__main__":
    # データ
    train_image_list,train_label_list,test_image_list,test_label_list=DataGen().generate()
    X = np.array(train_image_list)
    Y = to_categorical(train_label_list)

    # 学習
    model=Model().define(X,Y)
    model.fit(X, Y, nb_epoch=epoch, batch_size=batch_size, validation_split=validation_split)

    #推論
    total = 0.
    ok_count = 0.
    for image,label in zip(test_image_list,test_label_list):
        result = model.predict_classes(np.array([image / 255.]))
        print("correct:", label_mean_dict[label], "  result:", label_mean_dict[result[0]])

        total += 1.

        if label == result[0]:
            ok_count += 1.

    print("accuracy: ", ok_count / total * 100, "%")

結果

accuracy:  69.0 %

約70%の正解率。特別チューニングを行わないとcifar10で80%程度であるため、まずまずの精度かと思われる。

1
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
asakbiz
機械学習/ディープラーニングに関するコンサルやってます。好きな分野は自然言語処理やweb分析。 主な使用言語はPython。Javascript,特にVue.js勉強中。リブコード所属。

Comments

No comments
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
1
Help us understand the problem. What is going on with this article?