Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
1
Help us understand the problem. What is going on with this article?
@asakbiz

【ディープラーニング初心者向け】Kerasを使った全結合による簡単な二値分類の実装

More than 1 year has passed since last update.

概要

ディープラーニングの登場により、これまでの機械学習よりも良い精度でAIタスクをこなすことができるようになりました。
しかしながらディープラーニングはまだ発展段階ということもあり、こうすればよいと言った方法が確立されているわけではありません。また研究段階ということもあり、実装・制御が複雑なものも多い状況です。

今回は実装を出来る限り簡潔にし、ディープラーニングを動かしてみることを目的として、簡単な二値分類を紹介します。

今回は、リンゴとミカンの2つの画像をディープラーニングを使って分類します。DLフレームワークには、チューニングに難ありではあるものの簡単に使えるKerasを採用しました。

環境

  • gpu GeForce GTX 1070
  • os ubuntu16.04
  • CUDA 8.0
  • cudnn 6.0
  • Keras 2.0.8

ディレクトリ構成

実行モジュールとデータ格納用のディレクトリを準備します

├── data 
└── exe.py

dataディレクトリの中には学習用データtrainとテスト用データtestを準備し、リンゴとミカンの画像をそれぞれ格納します。

├── test
│   ├── 00_apple
│   └── 01_orange
└── train
    ├── 00_apple
    └── 01_orange

データ

各ディレクトリの中はこのようになっています。

  • リンゴ
    20171125235955.png

  • ミカン
    20171126000023.png

学習

from keras.utils.np_utils import to_categorical
from keras.optimizers import Adagrad
from keras.optimizers import Adam
import numpy as np
from PIL import Image
import os

# 教師データ読み込み

train_path="./data/train/"
test_path="./data/test/"

xsize=25
ysize=25

image_list = []
label_list = []


for dataset_name in os.listdir(train_path):

    dataset_path = train_path + dataset_name
    label = 0

    if dataset_name == "00_apple":
        label = 0
    elif dataset_name == "01_orange":
        label = 1

    for file_name in sorted(os.listdir(dataset_path)):
        label_list.append(label)
        file_path = dataset_path + "/" + file_name
        image = np.array(Image.open(file_path).resize((xsize, ysize)))
        print(file_path)

        # RGBの順に変換、[[Redの配列],[Greenの配列],[Blueの配列]]
        image = image.transpose(2, 0, 1)

        # 1次元配列に変換(25*25*3) Red,Green,Blueの要素が順番に並ぶ。
        image = image.reshape(1, image.shape[0] * image.shape[1] * image.shape[2]).astype("float32")[0]

        # 0〜1の範囲に変換
        image_list.append(image / 255.)


# numpy変換。
X = np.array(image_list)

# label=0 -> [1,0], label=1 -> [0,1] に変換
Y = to_categorical(label_list)

# モデル定義
model = Sequential()
model.add(Dense(200, input_dim=xsize*ysize*3))
model.add(Activation("relu"))
model.add(Dropout(0.2))

model.add(Dense(200))
model.add(Activation("relu"))
model.add(Dropout(0.2))

model.add(Dense(2))
model.add(Activation("softmax"))

model.compile(loss="categorical_crossentropy", optimizer=Adam(lr=0.001), metrics=["accuracy"])
model.summary()

# 学習
model.fit(X, Y, nb_epoch=1000, batch_size=100, validation_split=0.1)

画像分類などを行う際にはCNNを用いることが多いですが、今回は単純化のため、全結合のみを使用しました。また、画像から形の特徴量を抽出する際に、不要な情報が含まれないようにグレースケール化することが多いですが、リンゴとミカンの分類であることから色の情報が重要だと判断し、グレースケール化せずRGBの情報をすべてをニューラルネットのinputに渡しています。

推論

# 推論
total = 0.
ok_count = 0.

for testset_name in os.listdir(test_path):

    testset_path = test_path + testset_name

    label = -1

    if testset_name == "00_apple":
        label = 0
    elif testset_name == "01_orange":
        label = 1
    else:
        print("error : label not exist")
        exit()

    for file_name in os.listdir(testset_path):
        label_list.append(label)
        file_path = testset_path + "/" + file_name
        image = np.array(Image.open(file_path).resize((25, 25)))
        print(file_path)
        image = image.transpose(2, 0, 1)
        image = image.reshape(1, image.shape[0] * image.shape[1] * image.shape[2]).astype("float32")[0]
        result = model.predict_classes(np.array([image / 255.]))
        print("label:", label, "result:", result[0])

        total += 1.

        if label == result[0]:
            ok_count += 1



print("accuracy: ", ok_count / total * 100, "%")

結果

accuracy:  100.0 %

ディープラーニングが高精度と言えども正解率100%とまではなかなかなりません。しかしながら今回はリンゴとミカンの単純な分類であるため100%の正解率となりました。

1
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
asakbiz
機械学習/ディープラーニングに関するコンサルやってます。好きな分野は自然言語処理やweb分析。 主な使用言語はPython。Javascript,特にVue.js勉強中。リブコード所属。

Comments

No comments
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
1
Help us understand the problem. What is going on with this article?