Help us understand the problem. What is going on with this article?

深層学習で綺麗なウミウシを生成する

はじめに

Tensorflowを用いてGANの派生系であるDCGANを行った時の記録をメモ書きしたものです。あまり深いところまでは踏み込まず,ざっくり説明します。
先日もほぼ同じ記事を書いたのですが,ごちゃごちゃなってしまったので少し整理して上げ直します。

タイトルにはウミウシを生成したい!と書いてあるんですが,最初はDCGANでポケモンを生成しようと思っていました。なのでとりあえず,ポケモン生成の試みから簡単に書いていこうかなと思います。

ちなみにウミウシはこんな生き物です。カラフルな種類が多くて綺麗
image.pngimage.pngimage.png

GANやDCGANとは

簡単にGANについて書いておこうかと思います。
GANとは偽物を作る「Generator」判別する「Discriminator」の2つを学習させ,限りなく本物に近いデータを生成するぞってやつです。Generatorは本物データを参考に,ランダムノイズから新たな画像を生成します。DiscriminatorはGeneratorが生成した画像を「偽物or本物」で判別します。GeneratorとDiscriminatorは良きライバルです。これを何回も何回も繰り返す事でGeneratorとDiscriminatorがどんどん賢くなっていきます。その結果,本物データに近い画像を生成するようになります。

↓こんな感じです
image.png

↓さらに簡単に書いたもの
image.png

これがGANの基本的な仕組みです。
このGANにCNN(畳み込みニューラルネットワーク)を使用したものがDCGANになります。CNNは色々複雑なのですが簡単に言うと畳み込み層プーリング層という2つの層を用いてニューラルネットワークを多層構造にする事で,ニューラルネットワーク間で重みを共有することが可能になります。その結果,DCGANはGANよりも効率・精度の良い学習を行う事が可能になります。

このDCGANを使って,ポケモンやウミウシを生成していこうと思います。またGANやDCGANの解説は
今さら聞けないGAN(1) 基本構造の理解
今さら聞けないGAN (2) DCGANによる画像生成
がわかりやすいです。

ポケモン

ポケモンって種類も多いし身近なテーマで楽しそうかなと思い題材に選びました。今のポケモンってこんなに種類いるんですね。
ポケモン画像はここからダウンロードしました。

ちなみに今回,ポケモン画像の収集はChromeの拡張機能である「Image Downloader」を使用しました。コードを書かずにお手軽に使用できるのでおすすめです。あと流石にデータ数が少なすぎると思ったので下記のコードで回転と反転を加えて水増ししました。ちなみに読み込みやすいように.npy形式で保存しています。

import os,glob
import numpy as np
from tqdm import tqdm
from keras.preprocessing.image import load_img,img_to_array
from keras.utils import np_utils
from sklearn import model_selection
from PIL import Image

#クラスを配列に格納
classes = ["class1", "class2"]

num_classes = len(classes)
img_size = 128
color=False

#画像の読み込み
#最終的に画像、ラベルはリストに格納される

temp_img_array_list=[]
temp_index_array_list=[]
for index,classlabel in enumerate(classes):
    photos_dir = "./" + classlabel
    #globでそれぞれのクラスの画像一覧を取得
    img_list = glob.glob(photos_dir + "/*.jpg")
    for img in tqdm(img_list):
        temp_img=load_img(img,grayscale=color,target_size=(img_size, img_size))
        temp_img_array=img_to_array(temp_img)
        temp_img_array_list.append(temp_img_array)
        temp_index_array_list.append(index)
        # 回転の処理
        for angle in range(-20,20,5):
            # 回転
            img_r = temp_img.rotate(angle)
            data = np.asarray(img_r)
            temp_img_array_list.append(data)
            temp_index_array_list.append(index)
            # 反転
            img_trans = img_r.transpose(Image.FLIP_LEFT_RIGHT)
            data = np.asarray(img_trans)
            temp_img_array_list.append(data)
            temp_index_array_list.append(index)

            X=np.array(temp_img_array_list)
            Y=np.array(temp_index_array_list)

np.save("./img_128RGB.npy", X)
np.save("./index_128RGB.npy", Y)

DCGANでいっぱいポケモン混ぜてキメラチックなポケモンを作りたかった
image.pngimage.pngimage.pngimage.pngimage.png

character_kimera_chimaira.png

でも実際にできたのは
image.png

生成された画像とlossの両方から見てわかるように明らかに過学習していました。Discriminatorがめちゃめちゃ強い。そこで次に原因を考え解決していきました。

過学習の原因

ポケモンは難しい?

  • ポケモンは色も形もバラバラだからカオスなヤツが生成されやすい?
  • 形がある程度統一されているものを題材にしたい。ここでポケモンの生成からウミウシの生成に変更します。
  • と言っても,ウミウシって色はもちろん形もあんまり統一されてないから題材としては微妙な気はする。けど好きなもの作る事がモチベーション保つんやで,と言い聞かせてみます。

データ数が少ない

  • ウミウシの画像は,ポケモン画像より500枚+くらい集めました。回転(-20°~20°)と反転で多分16倍になるから「500 x 16 = 8000」だけデータ量が増えました。
  • 画像の収集はFlickricrawlerで収集しました。
  • ざっくりFlickrの使い方を説明しようと思います。Flickr APIのサイトAPI keyと書かれているところへ移動します。 名称未設定ファイル (1).png ここでYahooのアカウントを取得,ログインするとこの画面になるのでここからkeyを取得します。(黒塗りのところです) 名称未設定ファイル (2).png このkeyを利用して↓のコードで画像を取得します
from flickrapi import FlickrAPI
from urllib.request import urlretrieve
from pprint import pprint
import os, time, sys

# APキーIの情報
key = "********"
secret = "********"
wait_time = 1

# 保存フォルダの指定
savedir = "./gazou"

flickr = FlickrAPI(key, secret, format="parsed-json")
result = flickr.photos.search(
        per_page = 100,
        tags = "seaslug",
        media = "photos",
        sort = "relevance",
        safe_search = 1,
        extras = "url_q, licence"
)

photos = result["photos"]

# ループ処理でphotoに情報を格納する
for i, photo in enumerate(photos['photo']):
    url_q = photo["url_q"]
    filepath = savedir + "/" + photo["id"] + ".jpg"
    if os.path.exists(filepath): continue
    urlretrieve(url_q, filepath)
    time.sleep(wait_time)

これである程度データは集まるのですが,さらに欲しかったためicrawlerで画像を収集します。使い方はめちゃくちゃ簡単です。

$ pip install icrawler
from icrawler.builtin import GoogleImageCrawler

crawler = GoogleImageCrawler(storage={"root_dir": "gazou"})
crawler.crawl(keyword="ウミウシ", max_num=100)

これだけで指定したフォルダにウミウシ画像が保存されます。
f4b244b3be30f5fed4837d57fb64219c.jpg
この画像もポケモン同様に,回転と反転を加えて水増ししました。

ドロップアウトがない

  • ドロップアウトを簡単に説明すると,設定した割合のノードを無視する事で過学習を防ぐものです。
  • 詳しくはこの記事が良さそうです。
  • 下は実際にDiscriminatorにドロップアウトを適用したものです。
def discriminator(x, reuse=False, alpha=0.2):
    with tf.variable_scope("discriminator", reuse=reuse):
        x1 = tf.layers.conv2d(x, 32, 5, strides=2, padding="same")
        x1 = tf.maximum(alpha * x1, x1)
        x1_drop = tf.nn.dropout(x1, 0.5)

        x2 = tf.layers.conv2d(x1_drop, 64, 5, strides=2, padding="same")
        x2 = tf.layers.batch_normalization(x2, training=True)
        x2 = tf.maximum(alpha * x2, x2)
        x2_drop = tf.nn.dropout(x2, 0.5)

        x3 = tf.layers.conv2d(x2_drop, 128, 5, strides=2, padding="same")
        x3 = tf.layers.batch_normalization(x3, training=True)
        x3 = tf.maximum(alpha * x3, x3)
        x3_drop = tf.nn.dropout(x3, 0.5)

        x4 = tf.layers.conv2d(x3_drop, 256, 5, strides=2, padding="same")
        x4 = tf.layers.batch_normalization(x4, training=True)
        x4 = tf.maximum(alpha * x4, x4)
        x4_drop = tf.nn.dropout(x4, 0.5)

        x5 = tf.layers.conv2d(x4_drop, 512, 5, strides=2, padding="same")
        x5 = tf.layers.batch_normalization(x5, training=True)
        x5 = tf.maximum(alpha * x5, x5)
        x5_drop = tf.nn.dropout(x5, 0.5)

        flat = tf.reshape(x5_drop, (-1, 4*4*512))
        logits = tf.layers.dense(flat, 1)
        logits_drop = tf.nn.dropout(logits, 0.5)
        out = tf.sigmoid(logits_drop)

        return out, logits

学習率が高い?

  • 学習率が高いと訓練は早く進むのですが発散しやすく,学習が難しくなります。
  • 1e-2からはじめ,実際に様々な値で検証したところ,1e-4がちょうどいいかな?と言う感じでした。私の場合は1e-5になるとあまりにも学習が遅くなりました。
  • 様々な学習率の挙動についてはこの記事がわかりやすいです。

訓練データが多すぎる?

  • 当初は8:2程度でしたが,6:4に変更しました。あまり効果は実感できませんでした

ウミウシ(ポケモンからの改善結果)

100epoch
ダウンロード (4).png

200epoch
ダウンロード (8).png

300epoch
ダウンロード (10).png

400epoch
ダウンロード (6).png

500epoch
ダウンロード (3).png

  • とりあえず500epochほど回してみました。遠目で見るとなんとなーくウミウシが生成されている気はします。
  • でも出来栄えとしては正直微妙...
  • 考えられる要因としては「epochが足りなかった?」「画像に余計なもの(背景の岩場など)が含まれすぎていた?」「層が深すぎた?」「やっぱりもう少しシンプルな画像にすればよかった?」など様々な事が考えられます。
  • さらに改善してもう少し回したかったのですがGoogle Colaboratoryで実行しており,接続時間もあってなかなか難しい。
  • Colaboratoryについても書きたい事が少しあるので,次に章を設けます。

Colaboratory

ColaboratoryとはGoogleが提供しているクラウド上で動くJupyter notebook環境で80万円くらいのGPUを利用できます。しかも環境構築やDatalabのような申請は不要です。さらに無料。めちゃくちゃ便利なのですが,その分次のような制限があります。

  • 1日にある程度の時間(ここ最近では4時間[500epoch]程度)GPU接続しているとその日は使えなくなります。(これはColaboratoryのGPUリソースが不足しているためで,対処法はなく待つしかありません。GPUはコンスタンスに利用していないユーザーへ優先的に割り当てられるそうです。)
  • 非アクティブの時に90分,最大で12時間経つとランタイムが切断され,ノートブックの学習結果なども初期化されます。
  • そこで,90分問題解決のためにHyperdashを用いました。これによって90分以上ランタイムを接続する事ができます。
  • Hyperdashへ学習ログを送る事で90分問題に加え,Colaboratory上でログが確認できなくなるBuffered data was truncated after reaching the output size limit.問題も解決できます。
  • Hyperdashはスマホのアプリなので,外出先でもログを確認でき便利です。
  • Hyperdashでは学習経過のプロットやパラメータを確認する事もできますが今回はランタイム切断対策だけが目的なので,下の手順だけでOKです。
# 先にスマホアプリのHyperdashを起動してアカウントを作成しておく
# Hyperdashのインストール

!pip install hyperdash
from hyperdash import monitor_cell
!hyperdash login --email

Hyperdashのメールアドレスとパスワードを求められるので入力します。
名称未設定ファイル (4).png
次にHyperdashを使用するコードを書けばOKです。

# Hyperdashの使用

from tensorflow.keras.callbacks import Callback
from hyperdash import Experiment

class Hyperdash(Callback):
    def __init__(self, entries, exp):
        super(Hyperdash, self).__init__()
        self.entries = entries
        self.exp = exp

    def on_epoch_end(self, epoch, logs=None):
        for entry in self.entries:
            log = logs.get(entry)            
            if log is not None:
                self.exp.metric(entry, log)

exp = Experiment("任意の名前")
hd_callback = Hyperdash(["val_loss", "loss", "val_accuracy", "accuracy"], exp)


~~~訓練実行コード~~~


exp.end()

これでスマホアプリのHyperdashをみてみると,学習のログが表示されているはずです。
Hyperdashの使用により90分問題は解決しましたが,何かしらの理由でランタイムが切断される事もあるので,訓練は小分けにして.ckptで保存しておくのがいいかと思います。この.ckptもランタイムが切断されると消えてしまうので早めに保存しておきましょう。

# 学習結果を.ckptで保存する
saver.save(sess, "/****1.ckpt")

# .ckptで保存された学習結果を読み込み,そこから再開する
saver.restore(sess, "/****1.ckpt")

# .ckptを指定したディレクトリに保存
from google.colab import files
files.download( "/****1.ckpt.data-00000-of-00001" ) 

反省・おわりに

  • DCGANはモデルが複雑になるため,過学習が起こりやすく難しい。もっと層を浅く単純なモデルの構築を1番に考える。
  • 過学習とは直接関係なさそうだけど,上に挙げた「epoch数」「画像をシンプルに」「題材はもっとシンプルなものにする」事にも注意する。
  • あと潜在変数も結構重要なパラメータなのかな?もっと調べて見ようと思う。
  • 自分がやってきた事をダーっと書いただけなので,読みづらい記事だったかもしれません。最後まで読んでいただいてありがとうございます。DCGANは結果が画像として現れるので楽しいです。また改善や変更を加えてチャレンジしてみようと思います。
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away