#はじめに
先日、某美容院が出している「AIが“似合う髪形”を提案するアプリ」に新機能としてその髪型を、ユーザーの顔写真に合成する機能を追加していたのですが、GANでやったらもう少し綺麗に変換できそうだと思い制作してみました。
#忙しい方へ
##やったこと
男女すべての髪型のデータセットを一人で作るのはさすがにしんどいので今回は女性のロングヘアーとショートヘアーの変換を目標にします。CycleGanをベースにいくつかの変更を加え学習をします。
###変更点
- 損失関数をMSEに変更
- Spectral Normalizationの追加
- Discriminatorで大局的な情報も見るようにする
- 顔情報を別でGeneratorに入力する(最終的にはやめました)
- bCRの追加
###結果
左が元画像で右が変換後の画像です。
※画像はフリー素材を使用しています。
#実装
##データ収集
上にも書きましたが今回は女性のロングヘアーとショートヘアーの変換なのでこの2種類の画像をスクレイピングで集めます。大体一万枚ずつくらい集めましたが、汚いデータが多いのでopenCVで顔を検知し、検知できたものは顔回りを256×256で切り取り、できなかったものは削除していきます。残ったのは半分くらいですね。
まだ誤検知で残った汚いデータもありますが、めんどくさいのでとりあえず見なかったことにします。
##CycleGan
今回は1対1の画像の変換なのでCycleGanをベースに作っていくことにします。そもそもCycleGanとはなにかということですが、2つの画像データセットのdomainの関係を学習して画像変換を実現する手法のことです。有名な例だと馬とシマウマの変換や景観写真とモネの絵画の変換、夏の景観と冬の景観の変換などです。Pix2pixとは違い学習データがアンペアな画像であっても変換できるのが特徴です。
CycleGanは2つのGenerator $G$ , $F$ と、2つのDiscriminator $D_{X}$ , $ D_{Y} $ を使います。$X$をロングヘアー、 $Y$をショートヘアーだとすると、
- $G$はロングヘアーをショートヘアーに変換するGenerator
- $F$はショートヘアーをロングヘアーに変換するGenerator
- $D_{X}$は入力を本物の$X$か、$,F$によって変換された偽物の$X$か判別するDiscriminator
- $D_{Y}$は入力を本物の$Y$か、$G$によって変換された偽物の $Y$か判別するDiscriminator
となります。Generatorは
- Discriminatorを騙すようにする__「Adversarial Loss」__
- $G$($,F$)で変換して$F$($,G$)でもとに戻るようにする__「Cycle Consistency Loss」__
- $G$($,F$)に $Y$($X$)を入力しても変化しないようにする__「Identity Mapping Loss」__
詳しい説明が見たい方は本家の論文か、分かりやすく解説してるサイトがたくさんあるのでそちらを見てみてください。実装は本家のを参考にしようかと思っていたのですが、こちらの方の実装が見やすくてよかったのでこれを参考にベースモデルを作っていきます。ちなみに何も変更なしの通常のCycleGanだとこんな感じです。
想像以上に変換できてて驚きましたが、うっすら消した跡だったりが残ってますね。ここから頑張って精度を上げていきたいと思います。
##修正1
とりあえず損失関数をMSEに変更して、Spectral Normalizationを追加していきます。これによりGANの学習が安定します。Spectral NormalizationはSAGANや、BigGANなどで使われてますね。最近のGANにはわりと使われているらしいです。
###Spectral Normalizationの追加
Spectral Normalizationの前提としてGANの安定性/出力画像の質の向上には、$θ→P_{θ};$ のマッピングが滑らかでなければいけない$=$$D$の損失関数の連続性が重要という考えがあります。これは$D$にリプシッツ連続という強い制約を置くことで実現するのが今のところ主流です。$D$がリプシッツ連続性を有している場合、入力画像が少し変化していても$D$の出力はほとんど変化しなくなります。
Spectral NormalizationではDのNormalizationレイヤーをSpectral Normという係数行列の特異値分解を使ったNormalizationレイヤーに置き換えることで、この制約を実現しているようです。(実際は特異値分解の計算を行うと遅くなるのでPower iterationというアルゴリズムで近似値を出しているらしいです。)
実装はPytorchに公式実装があるのでそれを使います。
##修正2
実はCycleGanは形状の変化を苦手としていると論文で言及されています。CycleGanで有名な馬とシマウマの変換は大きな形状の変化が必要ないので比較的きれいに変換できますが、犬と猫のような形状に大きな変化がある変換は上手くいかないとされています。
###Discriminatorでglobalな情報も見るようにする
なにか方法はないかと調べていたら、「GANで犬を猫にできるか~cycleGAN編(1)~」という記事を発見しました。こちらではDiscriminatorの層を深くし大域的な情報と局所的な情報の2つをフィードバックさせることでかなり高精度な変換を実現していました。敬意をもってパクりたいと思います。
とはいえコードはのってなかったので雰囲気で組んでいこうと思います。Discriminatorの層は普通に深くするとして、損失関数はそれぞれのlossをただ足すだけでいいのか迷いました。調べてるとUNetGANなるものを見つけました。UNetGANも仕組みは違いますが大域的な情報と局所的な情報を返すみたいなので似たようなもんですね。損失関数は重みなどを付けずに足してるだけだったので普通に足して学習することにします。
顔がぐっちゃぐちゃになりました。
層を調整しつつデータセットの精査をします。
###顔情報を別でGeneratorに入力する
今回は髪型の変換なので顔は変化してほしくないわけです。データが十分にある場合は問題ない気もしますが、現状データが少ないうえこれ以上集めるのも大変なので、ごり押しですがGeneratorにひと工夫します。openCVを使い、顔まわりのみのデータを作ります。これをそのまま変換後の画像に張り付けると違和感がすごくなるのでUNetと同じような構造で画像をできる限りくずさないまま送ります。
ついでにGeneratorの損失関数に項目を追加します。顔情報のみの画像は変換前も変換後も同じになるはずなので、2枚の画像の差分のL2ノルムが0になるように学習していきます。
###結果
割ときれいになりましたがopenCVの誤検知で学習が不安定になってる気がします。やはりデータ不足に対応するためにはData Augmentationを使用したいです。
##修正3
機械学習は基本的にデータの数と質が命です。しかし実際問題、毎回十分な量のデータをそろえるのは難しいのでData Augmentation(以降DAと略す)を用いてかさましするのが一般的です。GANではDAに加えてConsistency Regularizationという手法を用いることで高精度な生成ができたという論文があります。今回はConsistency Regularizationを本物画像と偽物画像の両方に用いるbCRという手法を使ってみます。
###bCRの追加
bCRはDAを本物画像と偽物画像の両方に適用し、損失関数に正則化項を追加するだけです。正則化項はDAを適用しようがしなかろうが真偽の判定は変わらないはず、ということからさっきGeneratorでもやったのと同じようにDAを適用した画像の$D$の出力結果と何もしていない画像の$D$出力結果の差分のL2ノルムを0に近づけます。
\begin{align}
LX_{\mathrm{real}}&=||D_{X}(x) - D_{X}(T(x))||^2\\
LX_{\mathrm{fake}}&=||D_{X}(F(y)) - D_{X}(T(F(y)))||^2\;
\end{align}
$LY$もほぼ同じなので省略します。これに重みをかけたものを損失関数に足します。ちなみに$T$はDAを指しています。
###結果
若干劣化した感が否めないです。やっぱりDAになにを使うかもしっかり考えないとだめですね。UNetGANではDiscriminatorをUNetにしてCutMixを使っているのが結構効いてるみたいなので、今度はその辺を流用してみたいですね。
#最後に
パラメーターのチューニングとか、もっとしっかりやれば多少は精度上がるかもしれないです。出力画像に多様性を持たせてみたりとか、やりたいことはまだいくつかありましたが、また気が向いたときにやろうと思います。
#参考
Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros, Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks, arXiv preprint arxiv:1703.10593, 2017.
U-Netを識別器に!新たなGAN「U-NetGAN」を解説!