Help us understand the problem. What is going on with this article?

【ディープラーニング初心者向け】Kerasを使った全結合による簡単な二値分類の実装

概要

ディープラーニングの登場により、これまでの機械学習よりも良い精度でAIタスクをこなすことができるようになりました。
しかしながらディープラーニングはまだ発展段階ということもあり、こうすればよいと言った方法が確立されているわけではありません。また研究段階ということもあり、実装・制御が複雑なものも多い状況です。

今回は実装を出来る限り簡潔にし、ディープラーニングを動かしてみることを目的として、簡単な二値分類を紹介します。

今回は、リンゴとミカンの2つの画像をディープラーニングを使って分類します。DLフレームワークには、チューニングに難ありではあるものの簡単に使えるKerasを採用しました。

環境

  • gpu GeForce GTX 1070
  • os ubuntu16.04
  • CUDA 8.0
  • cudnn 6.0
  • Keras 2.0.8

ディレクトリ構成

実行モジュールとデータ格納用のディレクトリを準備します

├── data 
└── exe.py

dataディレクトリの中には学習用データtrainとテスト用データtestを準備し、リンゴとミカンの画像をそれぞれ格納します。

├── test
│   ├── 00_apple
│   └── 01_orange
└── train
    ├── 00_apple
    └── 01_orange

データ

各ディレクトリの中はこのようになっています。

  • リンゴ
    20171125235955.png

  • ミカン
    20171126000023.png

学習

from keras.utils.np_utils import to_categorical
from keras.optimizers import Adagrad
from keras.optimizers import Adam
import numpy as np
from PIL import Image
import os

# 教師データ読み込み

train_path="./data/train/"
test_path="./data/test/"

xsize=25
ysize=25

image_list = []
label_list = []


for dataset_name in os.listdir(train_path):

    dataset_path = train_path + dataset_name
    label = 0

    if dataset_name == "00_apple":
        label = 0
    elif dataset_name == "01_orange":
        label = 1

    for file_name in sorted(os.listdir(dataset_path)):
        label_list.append(label)
        file_path = dataset_path + "/" + file_name
        image = np.array(Image.open(file_path).resize((xsize, ysize)))
        print(file_path)

        # RGBの順に変換、[[Redの配列],[Greenの配列],[Blueの配列]]
        image = image.transpose(2, 0, 1)

        # 1次元配列に変換(25*25*3) Red,Green,Blueの要素が順番に並ぶ。
        image = image.reshape(1, image.shape[0] * image.shape[1] * image.shape[2]).astype("float32")[0]

        # 0〜1の範囲に変換
        image_list.append(image / 255.)


# numpy変換。
X = np.array(image_list)

# label=0 -> [1,0], label=1 -> [0,1] に変換
Y = to_categorical(label_list)

# モデル定義
model = Sequential()
model.add(Dense(200, input_dim=xsize*ysize*3))
model.add(Activation("relu"))
model.add(Dropout(0.2))

model.add(Dense(200))
model.add(Activation("relu"))
model.add(Dropout(0.2))

model.add(Dense(2))
model.add(Activation("softmax"))

model.compile(loss="categorical_crossentropy", optimizer=Adam(lr=0.001), metrics=["accuracy"])
model.summary()

# 学習
model.fit(X, Y, nb_epoch=1000, batch_size=100, validation_split=0.1)

画像分類などを行う際にはCNNを用いることが多いですが、今回は単純化のため、全結合のみを使用しました。また、画像から形の特徴量を抽出する際に、不要な情報が含まれないようにグレースケール化することが多いですが、リンゴとミカンの分類であることから色の情報が重要だと判断し、グレースケール化せずRGBの情報をすべてをニューラルネットのinputに渡しています。

推論

# 推論
total = 0.
ok_count = 0.

for testset_name in os.listdir(test_path):

    testset_path = test_path + testset_name

    label = -1

    if testset_name == "00_apple":
        label = 0
    elif testset_name == "01_orange":
        label = 1
    else:
        print("error : label not exist")
        exit()

    for file_name in os.listdir(testset_path):
        label_list.append(label)
        file_path = testset_path + "/" + file_name
        image = np.array(Image.open(file_path).resize((25, 25)))
        print(file_path)
        image = image.transpose(2, 0, 1)
        image = image.reshape(1, image.shape[0] * image.shape[1] * image.shape[2]).astype("float32")[0]
        result = model.predict_classes(np.array([image / 255.]))
        print("label:", label, "result:", result[0])

        total += 1.

        if label == result[0]:
            ok_count += 1



print("accuracy: ", ok_count / total * 100, "%")

結果

accuracy:  100.0 %

ディープラーニングが高精度と言えども正解率100%とまではなかなかなりません。しかしながら今回はリンゴとミカンの単純な分類であるため100%の正解率となりました。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした