1. Qiita
  2. Items
  3. Chainer

初心者がchainerで線画着色してみた。わりとできた。

  • 922
    Like
  • 24
    Comment

デープラーニングはコモディティ化していてハンダ付けの方が付加価値高いといわれるピ-FNで主に工作担当のtai2anです。
NHKで全国放送されたAmazon Picking Challengeでガムテべったべたのハンドやロボコン感満載の滑り台とかを工作してました。
とはいえ、やっぱりちょっとディープラーニングしてみたいので1,2か月前からchainerを勉強し始めました。
せっかくなので線画の着色をしたいなーと思って色々試してみました。

線画の着色は教師あり学習なので線画と着色済みの画像のデータセットが(できれば大量に)必要です。
今回はOpenCVでカラーの画像から線画を適当に抽出しています。

抽出例
6c55da4e7961024012e78b38db048f53ce8e3a31.jpg → 6c55da4e7961024012e78b38db048f53ce8e3a31.jpg

カラーの画像を集めて線画を作ればデータセットの完成です。(今回は60万枚くらい使っています)

ネットワークの形ですが、U-netという最初の方でコンボリューションする時の層の出力を最後の方でデコンボリューションする時にマージして使うネットワークを使いました。線画を参照してそれに着色する時に向いていそうです。(px2pxでもU-netが使われています)
数はとりあえず正解との誤差の二乗をとってそれを最小化するように学習させてみました。
ネットワークの定義自体は層の出力と次の層の入力を一致させるようにすれば大体OKだというのは分かったのですが、自分で定義するdatasetの作り方がsampleにあまりなかったのでちょっとわかりにくいかなと思いました。

一晩ほど学習してみて、テストデータを食わせてみた結果がこちら↓

スクリーンショット 2016-12-25 15.36.09.png

うーん、「肌色はなんとなくわかるけどそれ以外は知らねえよ、イラストキャラの髪とか服の色なんて当てられるわけねーだろ」というニューラルネットの気持ちが伝わってきます。。。

そこで登場するのがアドバーザリアル・ネット。通称アド婆。
アド婆は本物の画像とニューラルネットの着色した画像の差を学習してケチをつけてくるやつです。
なので、肌色セピアっぽい画像ばっかり出してると一発で学習されてダメ出しを食らうようになってしまいます。

ただ、アド婆が強すぎると、へそを曲げた着色側がグレるので要注意

2e3a61e3da383137ae334afe52c1efacd7249a86.jpg 2e3a61e3da383137ae334afe52c1efacd7249a86.jpg 2e3a61e3da383137ae334afe52c1efacd7249a86.jpg

こんなかんじで色は確かについてるけど線画とはなんだったのかみたいなアート系になったりします。
(まあ、これはこれでこういうアートの道もあるかもしれませんが

元画像との差分成分とのバランスに気をつけつつ学習を回してあげると。。。

スクリーンショット 2016-12-25 17.08.51.jpg

色がついてきました!
ひゅー

調子にのってもう一段!
512x512pxの線画を縮小して1段階目で128x128pxを塗って、2段階目で512x512pxを着色するネットワークも学習させました(2段階目もネットワークの形はほとんど同じだが、入力を4chにして新たにトレーニングし直し。こっちはアド婆なしです)

結果がこちら↓ 

スクリーンショット 2016-12-25 15.11.38.jpg

(こちらの線画は@lio8644さんの公開しているものお借りしました。)

いいじゃん。

スクリーンショット 2016-12-25 15.56.10.jpg

悪くないじゃん。

抽出線画をテストに使う限りはかなりよさそうだが実際の線画だとどうか?

pixivの塗っていいのよ系のタグから線画を拝借してテスト。
(基本全部CNNなのでアスペクト比が多少変わっても対応できます。

ででん

スクリーンショット 2016-12-25 15.03.39.jpg

すばらしい。

スクリーンショット 2016-12-25 14.58.21.jpg

カラフルな感じの怪物になったけど、まあこういうやつもいるよね。

スクリーンショット 2016-12-25 17.29.48.jpg

無難にまとめた感じ

ところで、やっぱりここはこの色に塗ってほしいとかそういうのありますよね?
というわけで、一段目の入力も4chに変更して、追加でヒントを与えられるようにしました。

つ ま り

スクリーンショット 2016-12-25 17.49.07.jpg

茶髪でセーター水色、スカート紺で、みたいな要求を設定できるわけです。

ざっくりとこんな感じにしてほしい感を伝えるとけっこうよしなにしてくれる。

スクリーンショット 2016-12-25 14.57.25.jpg

スクリーンショット 2016-12-25 15.03.03.jpg

(もはや線画が良けりゃそりゃ色塗ったらキレイだよねという気分になってきますがw

結構細かく沢山ヒントを与えることも可能です。
(ちょっとわかりにくいですが、結構ぽつぽつヒント点を打ってます

スクリーンショット 2016-12-25 14.53.10.jpg

メリー・クリスマス!!

これで俺も工作担当から図画工作担当に昇格だぜ!

ということで、線画の自動着色とヒントつき着色が、かなりできたのではないかと思います。

絵師の方たちのちゃんとした塗りにはまだまだ勝てないと思いますが、ざっくり塗ってみたりするには便利だと思います。
漫画とかでもトーン貼るよりざっと着色するほうが早ければ便利ですよね。
(こんかいのニューラルネットは肌色にはつよい・・・いいたいことはわかるな?)

一応補足ですが、弱点もまだいくつかあります。

例えばアドバーザリアルネットとヒントを同時に学習しているので、ヒントを与えた時の波及効果が不安定になる場合があります。

スクリーンショット 2016-12-25 16.35.04.jpg

↑水着だけ違う色で塗ってほしかったのに、他の部分の色も大きく変わっています。
簡単着色ツールとしてだけ使うなら、ヒントのみで学習させたほうが安定するかもしれません。

また、一度縮小してから色を塗っているため、線画が太すぎる/細すぎる場合も、線が潰れたり飛んだりして上手くいかないケースがあったり、細かい塗り分けをヒントで指示しても反映されないケースもあります。

細かい所に全部一つのNNで対応できれば格好良いですがツールとして使ってもらうときには用途ごとに調整していく必要はありそうです。

線画拝借先
http://www.pixiv.net/member_illust.php?mode=medium&illust_id=31274285
http://www.pixiv.net/member_illust.php?mode=manga&illust_id=43369404
http://www.pixiv.net/member_illust.php?mode=medium&illust_id=56689287
http://www.pixiv.net/member_illust.php?mode=medium&illust_id=40487409
http://www.pixiv.net/member_illust.php?mode=medium&illust_id=10552795
https://twitter.com/lio8644

※抽出線画についてはちょっと元の参照先をすぐには見つけれなかったので割愛させていただいております。申し訳ありません

ちなみに今回のネットワーク構造は一段目も2段目も構造は同じでこんな感じ

unet.py
    class UNET(chainer.Chain):
        def __init__(self):
            super(UNET, self).__init__(
                c0 = L.Convolution2D(4, 32, 3, 1, 1),
                c1 = L.Convolution2D(32, 64, 4, 2, 1),
                c2 = L.Convolution2D(64, 64, 3, 1, 1),
                c3 = L.Convolution2D(64, 128, 4, 2, 1),
                c4 = L.Convolution2D(128, 128, 3, 1, 1),
                c5 = L.Convolution2D(128, 256, 4, 2, 1),
                c6 = L.Convolution2D(256, 256, 3, 1, 1),
                c7 = L.Convolution2D(256, 512, 4, 2, 1),
                c8 = L.Convolution2D(512, 512, 3, 1, 1),

                dc8 = L.Deconvolution2D(1024, 512, 4, 2, 1),
                dc7 = L.Convolution2D(512, 256, 3, 1, 1),
                dc6 = L.Deconvolution2D(512, 256, 4, 2, 1),
                dc5 = L.Convolution2D(256, 128, 3, 1, 1),
                dc4 = L.Deconvolution2D(256, 128, 4, 2, 1),
                dc3 = L.Convolution2D(128, 64, 3, 1, 1),
                dc2 = L.Deconvolution2D(128, 64, 4, 2, 1),
                dc1 = L.Convolution2D(64, 32, 3, 1, 1),
                dc0 = L.Convolution2D(64, 3, 3, 1, 1),

                bnc0 = L.BatchNormalization(32),
                bnc1 = L.BatchNormalization(64),
                bnc2 = L.BatchNormalization(64),
                bnc3 = L.BatchNormalization(128),
                bnc4 = L.BatchNormalization(128),
                bnc5 = L.BatchNormalization(256),
                bnc6 = L.BatchNormalization(256),
                bnc7 = L.BatchNormalization(512),
                bnc8 = L.BatchNormalization(512),

                bnd8 = L.BatchNormalization(512),
                bnd7 = L.BatchNormalization(256),
                bnd6 = L.BatchNormalization(256),
                bnd5 = L.BatchNormalization(128),
                bnd4 = L.BatchNormalization(128),
                bnd3 = L.BatchNormalization(64),
                bnd2 = L.BatchNormalization(64),
                bnd1 = L.BatchNormalization(32)
        )

    def calc(self,x, test = False):
        e0 = F.relu(self.bnc0(self.c0(x), test=test))
        e1 = F.relu(self.bnc1(self.c1(e0), test=test))
        e2 = F.relu(self.bnc2(self.c2(e1), test=test))
        e3 = F.relu(self.bnc3(self.c3(e2), test=test))
        e4 = F.relu(self.bnc4(self.c4(e3), test=test))
        e5 = F.relu(self.bnc5(self.c5(e4), test=test))
        e6 = F.relu(self.bnc6(self.c6(e5), test=test))
        e7 = F.relu(self.bnc7(self.c7(e6), test=test))
        e8 = F.relu(self.bnc8(self.c8(e7), test=test))

        d8 = F.relu(self.bnd8(self.dc8(F.concat([e7, e8])), test=test))
        d7 = F.relu(self.bnd7(self.dc7(d8), test=test))
        d6 = F.relu(self.bnd6(self.dc6(F.concat([e6, d7])), test=test))
        d5 = F.relu(self.bnd5(self.dc5(d6), test=test))
        d4 = F.relu(self.bnd4(self.dc4(F.concat([e4, d5])), test=test))
        d3 = F.relu(self.bnd3(self.dc3(d4), test=test))
        d2 = F.relu(self.bnd2(self.dc2(F.concat([e2, d3])), test=test))
        d1 = F.relu(self.bnd1(self.dc1(d2), test=test))
        d0 = self.dc0(F.concat([e0, d1]))

        return d0

アド婆

adv.py
class DIS(chainer.Chain):
    def __init__(self):
        super(DIS, self).__init__(
                c1 = L.Convolution2D(3, 32, 4, 2, 1),
                c2 = L.Convolution2D(32, 32, 3, 1, 1),
                c3 = L.Convolution2D(32, 64, 4, 2, 1),
                c4 = L.Convolution2D(64, 64, 3, 1, 1),
                c5 = L.Convolution2D(64, 128, 4, 2, 1),
                c6 = L.Convolution2D(128, 128, 3, 1, 1),
                c7 = L.Convolution2D(128, 256, 4, 2, 1),
                l8l = L.Linear(None, 2, wscale=0.02*math.sqrt(8*8*256)),

                bnc1 = L.BatchNormalization(32),
                bnc2 = L.BatchNormalization(32),
                bnc3 = L.BatchNormalization(64),
                bnc4 = L.BatchNormalization(64),
                bnc5 = L.BatchNormalization(128),
                bnc6 = L.BatchNormalization(128),
                bnc7 = L.BatchNormalization(256),
        )

    def calc(self,x, test = False):
        h = F.relu(self.bnc1(self.c1(x), test=test))
        h = F.relu(self.bnc2(self.c2(h), test=test))
        h = F.relu(self.bnc3(self.c3(h), test=test))
        h = F.relu(self.bnc4(self.c4(h), test=test))
        h = F.relu(self.bnc5(self.c5(h), test=test))
        h = F.relu(self.bnc6(self.c6(h), test=test))
        h = F.relu(self.bnc7(self.c7(h), test=test))
        return  self.l8l(h)