30
30

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

初めてのDeep Learning 〜奮闘編〜

Last updated at Posted at 2015-12-22

皆さんこんにちは。@best_not_bestです。
この記事は初めてのDeep Learning 〜準備編〜の続きになります。未読の方は是非そちらを先にお読みください。

注意

本記事は個人的欲望から生み出されたものであり、所属する組織の公式見解ではございません。

実践

3. 学習用画像(好きな芸能人)を集める

Google画像検索APIを使おうと思ったのですが、以下の記事によると使用制限が厳しいようです。
Web画像検索で画像を収集する - のどあめ

Bing Search APIを使用します。ソースコードも上記記事を参考にさせていただきました。

get_images.py
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

import sys
import os
import requests
import urllib
import urllib2
import json

BING_URL = 'https://api.datamarket.azure.com/Bing/Search/Image?'
MS_ACCTKEY = 'hogehoge'
QUERY = '好きな芸能人の名前'
OUTPUT_DIR_PATH = './talent/'

opener = urllib2.build_opener()
urllib2.install_opener(opener)

def download_urllist(urllist):
    for url in urllist:
        try:
            print 'url: ' + url
            with open(OUTPUT_DIR_PATH + '/' + os.path.basename(url), 'wb') as f:
                img = urllib2.urlopen(url, timeout = 5).read()
                f.write(img)
        except urllib2.URLError:
            print('URLError')
        except IOError:
            print('IOError')
        except UnicodeEncodeError:
            print('EncodeError')
        except OSError:
            print('OSError')

if __name__ == "__main__":
    query = urllib2.quote(QUERY)
    step = 20
    num = 50

    url_param_dict = {
        'Query': "'"+QUERY+"'",
        'Market': "'ja-JP'",
    }
    url_param_base = urllib.urlencode(url_param_dict)
    url_param_base = url_param_base + "&$format=json&$top=%d&$skip="%(num)

    for skip in range(0, num*step, num):
        url_param = url_param_base + str(skip)
        url = BING_URL + url_param
        print url

        response = requests.get(url,
                                auth = (MS_ACCTKEY, MS_ACCTKEY),
                                headers = {'User-Agent': 'My API Robot'})
        response = response.json()

        urllist = [item['MediaUrl'] for item in response['d']['results']]
        download_urllist(urllist)

以下のコマンドで実行します。

$ python get_images.py

MS_ACCTKEYがAzure Market Placeのプライマリーアカウントキー、QUERYが検索したい文字列、OUTPUT_DIR_PATHが切り出した取得したファイルの格納先ディレクトリになりますので、適宜選択ください。
これで50件×20ページ分、約1,000件の画像が取得できます。やったね!

4. 学習用画像の顔部分を切り抜く

前回のcutout_face.pyをそのまま使おうと思ったのですが、ウェブから集めて来た画像にはイレギュラーなものも含まれているため、切り抜き時に例外処理を追加しました。ついでに、切り出した画像ファイル名に日本語が含まれているのが何か嫌だったので、連番がファイル名になるようにしています。

cutout_talent_face.py
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

import numpy
import os
import cv2

CASCADE_PATH = '/usr/local/opt/opencv/share/OpenCV/haarcascades/haarcascade_frontalface_alt.xml'
INPUT_DIR_PATH = './talent/'
OUTPUT_DIR_PATH = './talent_cutout/'
OUTPUT_FILE_FMT = '%s%d_%d%s'

count = 1
files = os.listdir(INPUT_DIR_PATH)
for file in files:
    input_image_path = INPUT_DIR_PATH + file

    # ファイル読み込み
    image = cv2.imread(input_image_path)
    # グレースケール変換
    try:
        image_gray = cv2.cvtColor(image, cv2.cv.CV_BGR2GRAY)
    except cv2.error:
        continue

    # カスケード分類器の特徴量を取得する
    cascade = cv2.CascadeClassifier(CASCADE_PATH)

    # 物体認識(顔認識)の実行
    facerect = cascade.detectMultiScale(image_gray, scaleFactor=1.1, minNeighbors=1, minSize=(1, 1))

    if len(facerect) > 0:
        # 認識結果の保存
        i = 1
        for rect in facerect:
            print rect
            x = rect[0]
            y = rect[1]
            w = rect[2]
            h = rect[3]

            path, ext = os.path.splitext(os.path.basename(file))
            output_image_path = OUTPUT_FILE_FMT % (OUTPUT_DIR_PATH, count, i, ext)
            i += 1
            try:
                cv2.imwrite(output_image_path, image[y:y+h, x:x+w])
            except cv2.error:
                print file
                continue

    count += 1

使い方はcutout_face.pyと同様です。

5. Python + Chainerで4.を学習させ、判別器を作成

以下を参考にさせていただきました。
ChainerのNINで自分の画像セットを深層学習させて認識させる - shi3zの長文日記
ディープラーニングでおそ松さんの六つ子は見分けられるのか 〜実施編〜 - bohemia日記

ソースをcloneします。

[work_dir]$ git clone https://github.com/shi3z/chainer_imagenet_tools.git

ImageNetのソースもcloneしてきます。

[work_dir]$ git clone https://github.com/pfnet/chainer.git
[work_dir]$ cd chainer
[chainer]$ git checkout -b 1.4.1 refs/tags/v1.4.1
[chainer]$ cd ..

Caltech 101をダウンロードして展開します

[work_dir]$ cd chainer_imagenet_tools
[chainer_imagenet_tools]$ wget http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz
[chainer_imagenet_tools]$ tar xzvf 101_ObjectCategories.tar.gz

101_ObjectCategoriesディレクトリに、4.で切り抜いた画像が入っているディレクトリを入れます。また、ディレクトリ数が多いと学習に時間が掛かるのでいくつかディレクトリを削除しています。101_ObjectCategoriesディレクトリの構成は以下の通りです。
パンダかわいい。

drwxr-xr-x  472 talent_cutout
drwxr-xr-x  438 Faces_easy
drwxr-xr-x  438 Faces
drwxr-xr-x   41 panda

train.txt、test.txt、label.txtを作成します。

[chainer_imagenet_tools]$ python make_train_data.py 101_ObjectCategories

画像サイズを256×256にします。

[chainer_imagenet_tools]$ python crop.py images/ images/

mean.npyを生成します。

[chainer_imagenet_tools]$ python ../chainer/examples/imagenet/compute_mean.py train.txt

学習スタート!

[chainer_imagenet_tools]$ python ../chainer/examples/imagenet/train_imagenet.py -g -1 -E 20000 train.txt test.txt 2>&1 | tee log

・・・。

train 3315 updates (106080 samples) time: 14:22:12.501801 (2.05054842326 images/sec)epoch 110

14時間経過しても終わらない(´;ω;`)

次回

今回はここまでです。(勉強して出直してきます。)
最終話(年内掲載予定)に続きます!!

追記

解決しました!
初めてのDeep Learning 〜解決編〜 - Qiita

30
30
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
30
30

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?