Help us understand the problem. What is going on with this article?

エロゲの立ち絵っぽいかの判定機を作った話

More than 1 year has passed since last update.

1 作ったもの

ノベルゲームの立ち絵かどうか、判定するものを作りました。
(ソシャゲのカードは含みません)

立ち絵として判定されたもの

tatie.PNG

立ち絵じゃないと判定されたもの

nottatie.PNG

実際にここで試せます
立ち絵かどうか判定するマン

2 前置き

文系が機械学習を試してみた体験談です。
技術的な内容というよりは、何を試したらダメで、どうしたら上手くいったか
何を勉強してどういう風にチューニングしていったかをまとめていきたいと思います

適当にググってみたけど……

・行列メッチャ使うらしい。習ってないし困った
・サンプル落としたけどまったく分からん。epochってなんだろう。
・ベクトル?矢印で書く奴だよね。対数?聞いたことはあるなぁ
・モデルとやらが必要らしい。どうやって作るんだろう

3 勉強する

amazonにある中でも、一番簡単そうな本を買って2週しました。

やさしく学ぶ 機械学習を理解するための数学のきほん アヤノ&ミオと一緒に学ぶ 機械学習の理論と数学、実装まで
かなり良著でした。
1割くらいは分かりませんでしたが、そういうところは概念だけ理解しました。
数学だけでなく、パーセプトロンとか機械学習の基本的なイメージも教えてくれるので
非常に助かりました

Chainer v2による実践深層学習
ChainerのVersionがちょっと古いですが、version3すら見つけられなかったので購入。
ディープラーニングのライブラリはいくつかありますが、Chainerにしました。
二次元関連の資料が多く、「PaintsChainer」等、実績もあるからです
画像による分類まで読み、判定機制作に着手しました。

4 事前準備 画像を集める

機械学習には学習用画像が数千枚必要です
手でやるのはあまりにもしんどいのでpinterestからスクレイピングして集めました。
以前作ったスクレイピングツール
横幅の方が大きかったり、背景色があまりないものはこの時点でnot立ち絵と判定して取得しないようにしています

pinterestにした理由ですが
・pixivは立ち絵がほとんどなく、クオリティにも差がある
・yahoo画像検索は量がいまいちでノイズも多い

pinterestなら、ユーザーが自らピン止め(ふぁぼみたいな)した画像のみなため
クオリティが高く、そこそこ枚数もあるからです

最適な画像が集まる用、検索条件をいろいろ変えて試しました。
ノベルゲームに出てくる立ち絵を想定しているため、ソシャゲのカード等はnot立ち絵にしています
男の立ち絵についてはなんも考えてないです

立ち絵として収集したもの

立ち絵 女,anime girl,ソシャゲ  女,トイズドライブ,グリムノーツ,スクエニ 女,英雄伝説 女,ファルコム 女,ガスト 女,SEGA 女,のアトリエ,ブレイブソード×ブレイズソウル,魔界戦記ディスガイア,テイルズ 女,レジェンヌ,ラスピリ,白猫プロジェクト,#白猫シェアハウス,エンドブレイカー,ファントムオブキル,ガールフレンド仮,エンドライド 女,Cygames 女

not立ち絵として収集したもの

キャラ,pixiv,トランプ,萌え,車,ドラゴン,ラブライブ,デレステ,刀剣乱舞,

全て100×100にリサイズし、正解画像については左右判定させて2倍に増やします
このあたり全部pythonで書いてます。
この時点で大体4000枚。

5 実際に書いてみる前に

いきなり4000枚で試すと大変なので400枚に減らします。
最初はwindowsでやってましたがGPUがなく時間のかかってしまいます
AWSも検討したけどくっそ高い。(GPU1時間800円くらい)

google colaboratoryを使います。
notebook形式で書けるGPUが使えるubuntu実行環境といえばいいのでしょうか。
一応python実行環境ですがapt-get使えるからなんでも入ります(PHPとかも)
spreadsheetみたく、googleDriveから簡単に利用出来ます。
少し手順が必要ですが、googleDriveをマウントできるので共有や実行もメッチャ楽。

無料なのが意味わかんないですが、せっかくなので酷使させていきます

一定時間でインスタンスが破壊されるなど、制約も多少あるので
こちらの記事を読んでおくことをおすすめします
【秒速で無料GPUを使う】深層学習実践Tips on Colaboratory

jupyter notebookは初めて触りましたが、30分くらい触れば大体わかるので
chainerとかの前に適当に遊びました

6 書いてみる

モデルの学習コードはこんな感じです
https://colab.research.google.com/drive/1kCg2nR49AMCtHCzPviiQrbCMc0QT0rDL

必要なものをapt-getし、
googleDriveを使えるようにマウント
そのあとにpythonのコードが続きます。
epoch毎にmodelのスナップショットを保存し、別のnotebookからテストして正解率を図っていきます。
boolingを変えたり、Iteratorの数、filterのサイズを変えてひたすら実験です。
50パターンくらいは実験したと思います。
この作業に5日くらいかかりました。

filterサイズを1変えるだけで正答率が10%変わったり、
80%をたたき出した次のepochからずっと50%代になったり
予測とは全然違う結果が出て終始驚いてました

テストも結構大変でした
最初は単純な正解率だけ計っていましたが、正解率70%でお?と思ったモデルが
全部を「立ち絵でない」と判定していたり。
先ほど紹介した書籍にもそういうことは書いてあったのですが、失念していました。

正解率80%を出した時点で、画像を400枚から4000枚にして全裸待機。
これは勝つる!と思い実行してみると正解率70%。何かがおかしい?

原因はノイズ画像でした。
先ほど使った400枚4000枚のうち、1番から200番までをそれぞれ選んだもの。
前半の方がスクレイピングの結果が良く、ラベル通り「立ち絵」「立ち絵ではない」になってましたが
それ以降は「立ち絵」とされているものがそうでなかったり。
20%くらいは間違っていたので、画像を全部見て手動で仕分け。
4000枚なので、まあ2時間くらいで終わりました。

スクレイピング結果が常に正しいとは限らない。
結局目で見ての仕分けは必要ですね。

ついでに3時間かけてノベルゲーム制作会社のHPを漁って立ち絵を取得
完全手作業。かなり虚無を感じました。
大体150枚ゲット。

この時点で
A pinterestから取得した立ち絵 2000枚
B pinterestから取得した立ち絵じゃないもの 2000枚
C 手動スクレイピングしたギャルゲ画像 150枚
が手元にあるので
AとCを左右反転させて倍増させ、Cはさらに2倍にしました。
リサイズ&左右反転機のコードはこんな感じです

import os
import sys
import glob
import re
from PIL import Image

#リサイズ元ディレクトリ名
in_folder_name = "origin"

#出力先ディレクトリ名
out_folder_name = "resize"

#リサイズ画像サイズ
save_size = 100
def returnFormat(format):
    if format == ".bmp":
        return "BMP"
    elif format == ".jpg":
        return "JPEG"
    elif format == ".jpeg":
        return "JPEG"
    elif format == ".JPG":
        return "JPEG"
    elif format == ".png":
        return "PNG"
    elif format == ".gif":
        return "GIF"
    else:
        print(format + " は対応していません。")
        sys.exit()


files = glob.glob(in_folder_name+'/*')

for f in files:
    origin = Image.open(f)

    origin_width, origin_height = origin.size
    #中央揃えするための変数
    x_pos = 0 
    # 縮小比率と中央揃えの計算
    if origin_height > save_size:
        width = origin_width / (origin_height / save_size)
        height = save_size
        x_pos = int((origin_height - origin_width) / 2)
    else:
        width = origin_width
        height = origin_height

    # パレットモードの解除用に、判定用レイヤーに張り付け
    judge_layer = Image.new('RGBA', (origin_width, origin_height), (0, 0, 0, 0))
    judge_layer.paste(origin, (0, 0))

    # 中央に揃えるため新規レイヤーに張り付け
    layer = Image.new('RGBA', (origin_height, origin_height), (0, 0, 0, 0))
    layer.paste(origin, (x_pos, 0))

    # 透過を白に塗りつぶす用。
    canvas = Image.new('RGB', layer.size, (255, 255, 255))
    for x in range(origin_height):
        if x < x_pos:
            continue
        if x > x_pos+origin_width:
            continue
        for y in range(origin_height):
            pixel = layer.getpixel((x, y))
            # 透過なら白に塗りつぶし
            if pixel[3] == 0 :
                canvas.putpixel((x, y), (255, 255, 255))
            else:
            # 透過以外なら、用意した画像にピクセルを書き込み
                canvas.putpixel((x, y), (pixel[0], pixel[1], pixel[2]))

    # ファイル名の拡張子より前を取得し, フォーマット後のファイル名に変更
    ftitle, format = os.path.splitext(f)
    file_name  = os.path.split(ftitle)

    resize = canvas.resize((save_size, save_size), Image.BICUBIC)
    # 画像の保存
    with open(out_folder_name+".txt", "a") as file:
        print(out_folder_name+"/"+file_name[1]+format+" 0", file = file)
    resize.save(out_folder_name+"/"+file_name[1]+format, returnFormat(format), quality=100, optimize=True)

    左右反転
    canvas_mirror = resize.transpose(Image.FLIP_LEFT_RIGHT)
    with open(out_folder_name+".txt", "a") as file:
        print(out_folder_name+"/"+file_name[1]+"_mirror"+format+" 1", file = file)
    canvas_mirror.save(out_folder_name+"/"+file_name[1]+"_mirror"+format, returnFormat(format), quality=100, optimize=True)

    print(out_folder_name+"/"+file_name[1]+format)

400枚で試したときに結果の良かったパターンを3つくらい抽出し、実験
96%の正解率まで行きました。

いい感じなのでさらに精度をあげます。
立ち絵じゃない画像として、pinterestから以下の4000枚を追加。(目視で立ち絵じゃないのは取り除きます)

「アニメ,フィギュア,コスプレ,ロボット,写真,仏像,スポーツ,ラーメン,ケーキ,フレンチ,モデル,家,AKB,声優,ニューヨーク,京都,ピカソ,浮世絵,肖像画,魚,犬,cat,シャドバ,イラスト屋,ボタン」

合計1万枚で再度モデルを作ります。
epoch86にして立ち絵は226/240、not立ち絵は736/741正解の
正解率98%のモデルが出来ました

webサービスとして公開する

判定機自体は出来ましたが、コマンドラインでpythonを叩いて実行する必要があります。
サービスとして公開したかったため、web化します。
具体的には画像をアップしたら判定、というよくあるアレです

サーバー自体は持っていますが、運用しているツールがいくつかあり
負荷的な意味で怖かったので
awsでサーバーを立ち上げ、API化して判定します。

フロント側でアップした画像をbase64にしてAPIに送信すると
立ち絵か立ち絵じゃないか返してくれる感じです。

pythonでwebアプリ作った事がなかったのでとりあえず簡単そうなFlaskを選択。
python3mod_wsgiなど必要なものを入れ、apache起動
(yummod_wsgi入れるとpytyon2用のが入るのでpipで入れましょう。ハマりました)

モデルの判定対象画像をPathによるファイル読み込みではなく、base64から生成するように修正。
判定機のコードはこんな感じです。

import sys
import os
from flask import Flask, make_response, jsonify, request
from flask_cors import CORS

import base64
from PIL import Image
from io import BytesIO


import numpy as np
import math
import chainer
from chainer import Function, \
                    report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.datasets import tuple_dataset
from chainer.training import extensions



app = Flask(__name__)
CORS(app)

base = os.path.dirname(os.path.abspath(__file__))
image_size = 100
filter1_size=11
filter2_size=11
filter2_num=16*4*2
booling=8
fc1_dimension=int(math.ceil((((image_size - filter1_size + 1)/booling) - filter2_size + 1)/booling) ** 2 * filter2_num)



class MyDataset(chainer.dataset.DatasetMixin):

    def __init__(self, image_paths, labels):
        self.image_paths = image_paths
        self.labels = labels

    def __len__(self):
        return len(self.image_paths)

    def get_example(self, i):
        #リサイズ
        origin = Image.open(self.image_paths[i])
        origin_width, origin_height = origin.size
        #中央揃えするための変数
        x_pos = 0 
        # 縮小比率と中央揃えの計算
        if origin_height > image_size:
            width = origin_width / (origin_height / image_size)
            height = image_size
            x_pos = int((origin_height - origin_width) / 2)
        else:
            width = origin_width
            height = origin_height

        # パレットモードの解除用に、判定用レイヤーに張り付け
        judge_layer = Image.new('RGBA', (origin_width, origin_height), (0, 0, 0, 0))
        judge_layer.paste(origin, (0, 0))

        # 中央に揃えるため新規レイヤーに張り付け
        layer = Image.new('RGBA', (origin_height, origin_height), (0, 0, 0, 0))
        layer.paste(origin, (x_pos, 0))

        # 透過を白に塗りつぶす用。
        canvas = Image.new('RGB', layer.size, (255, 255, 255))
        for x in range(origin_height):
            if x < x_pos:
                continue
            if x > x_pos+origin_width:
                continue
            for y in range(origin_height):
                pixel = layer.getpixel((x, y))
                # 透過なら白に塗りつぶし
                if pixel[3] == 0 :
                    canvas.putpixel((x, y), (255, 255, 255))
                else:
                # 透過以外なら、用意した画像にピクセルを書き込み
                    canvas.putpixel((x, y), (pixel[0], pixel[1], pixel[2]))

        resize = canvas.resize((image_size, image_size), Image.BICUBIC)

        resize = np.asarray(resize, dtype=np.float32)
        resize = resize.transpose(2, 0, 1)
        label = self.labels[i]
        return resize, label

class MyModel(Chain):
    def __init__(self):
        super(MyModel, self).__init__(
            cn1=L.Convolution2D(3,16*4,filter1_size),
            cn2=L.Convolution2D(16*4,filter2_num,filter2_size),
            fc1=L.Linear(fc1_dimension,252),
            fc2=L.Linear(252,2),
        )

    def __call__(self, x,t):
        return F.softmax_cross_entropy(self.fwd(x),t)

    def fwd(self, x):
        h1 = F.max_pooling_2d(F.relu(self.cn1(x)),booling)
        h2 = F.max_pooling_2d(F.relu(self.cn2(h1)),booling)
        h3 = F.dropout(F.relu(self.fc1(h2)))
        return self.fc2(h3)



@app.route('/',methods=['GET','POST'])
def index():
    #POSTデータを読み込む
    enc_data  = request.form['file']
    dec_data = base64.b64decode(enc_data) 

    #画像をデータセット化
    image_files = [BytesIO(dec_data)]
    labels = [0]
    dataset = MyDataset(image_files, labels)

    model = MyModel()
    serializers.load_npz(os.path.join(base, 'result/model'), model)
    x = Variable(np.array([dataset[0][0]], dtype=np.float32))

    out = model.fwd(x)
    ans = np.argmax(out.data)

    result = {'result': str(ans)}
    return make_response(jsonify(result))


if __name__ == '__main__':
    app.run(debug=True)

フロントのコードも一応貼っておきます

    <script type="text/javascript">
    $(function() {
      var file = null;
      var base64_body = null;
      var img = new Image;
      $('input[type=file]').change(function() {
        $("#result").text("")
        file = $(this).prop('files')[0];
        if (file.type != 'image/jpeg' && file.type != 'image/png') {
          alert("jpgかpngのみです")
          return;
        }

        var reader = new FileReader();
        reader.onload = function(e) {
          $("#preview").attr("src",e.target.result)
          img.src = reader.result;
          base64_body = e.target.result.split('base64,')[1];
        }
        reader.readAsDataURL(file);
      });


      // 判定ボタン
      $('#upload').click(function(){
        if(!file || !base64_body) {
          return;
        }
        if (img.width > 600 || img.height > 1400) {
          alert("サイズが"+img.width+"×"+img.height+"です。縦600px横1400pxまでです")
          return;
        }
        $("#upload").prop("disabled", true);
        $("#loading").show()
        var name, fd = new FormData();
        fd.append('file', base64_body); // ファイルを添付する
        $.ajax({
          url: "http://54.178.179.94/",
          type: 'POST',
          dataType: 'json',
          data: fd,
          processData: false,
          contentType: false,
          crossDomain: true
        })
        .done(function( data, textStatus, jqXHR ) {
          // 送信成功
          if (data.result == 1){
            $("#result").text("立ち絵だと思う")
          } else {
            $("#result").text("多分違う")
          }
        })
        .fail(function( jqXHR, textStatus, errorThrown ) {
          // 送信失敗
          $("#result").text("サーバーが死んでるかも")
        })
        .always(() => {
            $("#loading").hide()
            $("#upload").prop("disabled", false);
        }); 
      });
    });
    </script>

アップされた画像を100×100にリサイズする部分、最初はjavascriptで書いていました。
フロントでどうにかしないとサーバーの負荷がやばいと思ったからです。
その結果、精度が下がってしまいました。

モデル学習用の画像を作るときのリサイズ方式と、
フロント側のリサイズ方式が異なっていたため精度が下がってしまったのだと思います。
全部javascriptとcanvasでリサイズすればよかったなと今更思いました

モデルについては自由につかってくださって構いませんが、
コード自体は本や他のサイト様の継ぎはぎなのであくまで参考程度でお願いします

うまくいかなかった点と対策

枚数増やしたら精度下がった
→削減後の画像は無作為に選ばれているか?目視でノイズ除外したら精度あがった

ローカルじゃ学習が終わらなかった
→google colabで解決

テスト結果がなんかおかしい
→正解率だけでなく、立ち絵とnot立ち絵それぞれの正解率を確認して判断する

API化したら精度が。
フロント側でリサイズしたせいで学習データのリサイズ方法と違う

その他

・「行列」や「なんたら関数」など機械学習は怖いイメージがあったがその辺は全部pythonがやってくれる
・今回出来たmodelはここに置いておきます
・このモデルを使って何するかは考え中
・APIサーバーt2smallなのでちょっと負荷がかかったら死ぬかも
・画像集めるのが一番大変だった。

最後まで読んでいただきありがとうございました。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away