オミータです。ツイッターで人工知能のことや他媒体で書いている記事など を紹介していますので、人工知能のことをもっと知りたい方などは気軽に@omiita_atiimoをフォローしてください!
他にも次のような記事を書いていますので興味があればぜひ!
U-Netを識別器に!新たなGAN「U-NetGAN」を解説!
画像生成分野で物凄い成果を出し続けているモデルとしてGenerative Adversarial Networks、通称GANがあります。GANは基本的に 「生成器」と「識別器」の2つのネットワークを用意してお互いに戦わせることでより良い生成器を手に入れよう、というモデルです。GANについては「1.1 GAN」にて簡単に触れます。一方でセマンティックセグメンテーションなどで広く使われているモデルにU-Netというものがあります。U-Netはオートエンコーダーのエンコーダーとデコーダー同士をスキップ結合でつなげたようなモデルです。U-Netについても「1.1.2 U-Net」で簡単に触れます。
そして今回紹介する論文「"A U-Net Based Discriminator for Generative Adversarial Networks", Schönfeld, E., Schiele, B.,Khoreva, A., (CVPR'20)」はGANとU-Netを組み合わせたモデルU-NetGANを提案しています。CVPR2020に採択されています。ではどう組み合わせるか、ということですがタイトルの通り「DiscriminatorにU-Netを使う」のです(ちなみにImage-to-Image変換を行うPix2Pix[Isola, P.(CVPR'17)]ではGeneratorにU-Netを用いていたり、BEGAN[Berthelot, D.(2017)]はDiscriminatorにオートエンコーダーを用いていたりします)。単純な仕組みで、FFHQや本論文オリジナルデータセットのCOCO-AnimalsなどでBigGANを大きく改善させることに成功しています。それではその中身を見ていきましょう!
本記事の流れ:
- 忙しい方へ
-
U-NetGANの説明
- おさらい
- U-NetGAN
- U-NetGANの実験結果
- まとめと所感
- 参考
原論文: "A U-Net Based Discriminator for Generative Adversarial Networks", Schönfeld, E., Schiele, B.,Khoreva, A., (CVPR'20)
実装: 未提供
0. 忙しい方へ
- U-NetGANの特徴は次の2つだよ
- DiscriminatorがU-Net
- CutMixを用いたConsistency Regularizationを導入
- U-NetGANは大域的なフィードバックと局所的なフィードバックの両方をGに伝えることで性能を向上させたよ
- Generatorはいじってないよ
- 新たなデータセットCOCO-Animalsを提案したよ
- BigGANと比べてFFHQで4.00、COCO-Animalsで2.64、CelebAで1.67のゲイン(FID)だよ
- CelebAではSoTAモデルたちよりも良い性能(FID/IS両方)を示したよ
1. U-NetGANの説明
1.1 おさらい
ここではU-NetGANの理解に必要な次の3つについて少しおさらいしていきます。
- GAN [Goodfellow, I.(NIPS'14)]
- U-Net [Ronneberger, O.(MICCAI'15)]
- CutMix [Yun, S.(ICCV'19)]
1.1.1 GAN
GANは上図のように生成器と呼ばれるG(=Generator)と識別器と呼ばれるD(=Discriminator)の2つのネットワークを有しています。これらを交互に学習させることで、最終的に本物の画像と見分けがつかないような画像をGに生成してもらうモデルです。
よく用いられる例えを踏襲すると、Gは偽札製造者でDは鑑定士となります。Gが作った偽札(上図の$G(z)$)をDが本物か偽物かを見極めます(つまり、Dは二値分類)。この時Dには偽札だけでなく本物の札(上図の$x$)もちゃんと見せてあげます。Dの判断結果を基にGはさらに偽札を作ります。それを見て再びDが判断して・・・というのをずっと繰り返すことでGもDもお互いに切磋琢磨しながらGは生成能力をDは識別能力を成長させていきます。ここで$z$とは単なるノイズベクトルです($z\sim\mathcal{N}(0,I)\in\mathbb{R}^{128}$などが用いられることが多いです)。学習済みのGは単なるノイズ$z$をリアルな画像$G(z)$に変換させるので恐るべしです。DとGの最も基本的な損失関数は次のように書かれます。
GANは[Goodfellow, I.(NIPS'14)]で提案され、画像分野ではGにもDにもCNNが用いられることが多いです。このCNNを用いた画像生成GANは[Radford, A.(ICLR'16)]でDCGAN(=Deep ConvolutionalGAN)として提案されています。GANについてもう少し詳しく知りたいよ、という方はこちらの記事で丁寧に書かれていますのでそちらをご覧ください。
今回のU-NetGANでは、Dの部分をU-Netにしたということです。それでは続いてU-Netについておさらいしましょう。
1.1.2 U-Net
まず、U-Netの前にオートエンコーダーから触れます。オートエンコーダとは、エンコーダーとデコーダーで構成されており、その間にボトルネックのようなものを有した上図のようなニューラルネットワークのことでしたね。入力値を再現するようなオートエンコーダーを作ることで途中のボトルネックの出力を「画像を圧縮した潜在表現」のように扱えるのが特徴です。ここで、入力からデコーダーまでの経路が遠いのでデコーダーにも入力に近い情報を直接与えてあげたいです。これを実現させるには単にエンコーダーからデコーダーへのスキップ結合を結んであげれば良さそうです。
上図で点線がスキップ結合を表しています。単に点線を加えただけです。具体的にはエンコーダーの入力に一番近い層をデコーダーの出力に一番近い層に結んであげます。この時、エンコーダーからの値をデコーダーの値に連結(=concatenation)させます。入力から2番目に近い層とデコーダーの出力から2番目に近い層同士、3番目の層同士、、、という具合で同様の処理を各層で行います。これによって出来上がるネットワークは下図のように図示できます。
画像内左側がエンコーダーで右側がデコーダーですね。Uの形をしていますね。そうです、これこそがU-Netです。上図はU-Netの論文[Ronneberger, O.(MICCAI'15)]からそのまま引用したものです。こうすることでフォワードでは入力付近の値をデコーダーに渡せて、バックプロップでは出力付近の値をエンコーダーに効率的に渡すことができそうですね。繰り返しになりますが、GANのDiscriminatorをこのU-NetにするというのがU-NetGANの醍醐味になっています。
1.1.3 CutMix
Images from Pixabay
CutMixは[Yun, S.(ICCV'19)]で提案されたデータオーギュメンテーション手法で、非常に強いです。個人的に用いることが多いのですが、画像分類においてはCutout[DeVries, T.(2017)]やMixup[Zhang, H.(ICLR'18)]など他のDA手法と比べてもとても強いです。そしてCutMixの凄さは強いのにもかかわらず非常にシンプルなことです。上の図で中央の画像がCutMixで作った画像を表しています。CutMixとは、ある画像(上図:図書館)に他の画像(上図:犬)の一部を貼り付けるだけです。新たにできたこの画像(上図:中央)のラベルは元の画像と貼り付けた画像同士の面積比(上図:70%図書館、30%犬)で決めます。これで画像認識モデルの性能が大きく向上するので驚きです。公式実装(PyTorch)はこちらにありますのでぜひお試しください(自己実装も簡単です)。
今回のU-NetGANでは真画像と偽画像をCutMixして、最終的なラベルを偽としています。細かいところは1.2.2にて説明します。このCutMixを用いることでU-NetGANではFIDのさらなるゲインを得ています。
1.2 U-NetGAN
満を持して、U-NetGANの説明です。U-NetGANの大きな特徴は次の2つです。U-NetGANではDiscriminatorのみをいじっています。
- DiscriminatorがU-Net
- CutMixを用いたConsistency Regularizationを導入
1.2.1 DiscriminatorがU-Net
先ほどの「1.おさらい」を読んでいれば、上図はすんなりと入ってくるかと思います。Discriminatorの部分がU-Netになっているだけですね。特に、U-Netのエンコーダー部分を$D_{\mathrm{enc}}^U$、デコーダー部分を$D_{\mathrm{dec}}^U$とします。「U-NetGANは通常のGANにデコーダーをくっつけただけ」であると考えるとよりわかりやすいかもしれません。
このデコーダー部分をつけることで何が良いのでしょうか。これはデコーダーの出力にある心霊写真のような画像を見るとわかります。この画像はピクセル毎の真偽を表しています。明るい部分はDが真であると判断した箇所で暗い部分は偽であると判断した箇所になっています。このピクセル毎の真偽を判定させることで、より局所的な細かいフィードバックをGに与えることができるようです。一方で従来のDは画像全体を見て最終的に真か偽かを判定しているということです。ただ、この全体を見る大域的な判定によるフィードバックもGにとって大切なことは直感的にわかります。なので、U-NetGANではU-Netのボトルネックの部分で従来通りの真偽判定を行なうことでGに大域的なフィードバックも渡しています。下図は学習中の出力を左から右に順に表しています。Dの最終出力においてしっかりと偽物っぽいところは暗くなっており本物っぽいところは明るくなっていることがわかりますね。
それでは損失関数も見てみましょう。まずは通常のGANの損失関数を再掲します。
そしてU-NetGANの損失関数ですが、結論から言うとほぼ元々のGANの損失関数と変わりません。Dがエンコーダーとデコーダーの2つあるのでそれぞれについて通常の損失を取って、足し合わせるだけです。エンコーダーについてはただの通常のDの損失関数です。
デコーダーは各ピクセルごとに真偽判定をするのでそれぞれのピクセルで予測させてあとは平均してスカラーにするだけです。$i,j$はピクセルの位置を示しています。
最終的にDiscriminatorにおける損失はエンコーダーとデコーダーの損失を足し算をしているだけです。ここで、エンコーダーとデコーダーの損失には重みなどを付けずに単純に足し算をするのが良いようです。(実験などによる検証については無記述。)
同時にGeneratorはエンコーダーとデコーダーいずれも騙さないといけないので、次のように書けます。Generatorは単純に両方ともを騙そうとしているだけです。
1.2.2 CutMixを用いたConsistency Regularizationを導入
GANにおいて、Discriminatorへの入力直前で画像のクラスが変わるような変換が加えられても、Discriminatorは真偽判定のみなのでその結果は常に一貫性を持っておくべきです。この一貫性(=Consistency)をDiscriminatorに強制させているものこそが、U-NetGANのもう1つの要素である「CutMixを用いたConsistency Regularization(本記事では以下、CutMix-CRとして参照)の導入」です。もう1つの要素と述べましたが、後述のAblation Studyを見るとCutMix-CRが性能向上にかなり貢献していることがわかります。
ここでのCutMixは以下の3点だけ注意が必要です。
- CutMixは真画像と偽画像の混合
- エンコーダー用のラベルはFakeを用いる
- デコーダー用のラベルは混合で用いたバイナリマスクM
そしてCutMix-CRは 「2枚の画像をCutMixしてからDに通した場合」と「2枚の画像をDに通したのちCutMixした場合」は同じものが出来上がるはずである、という仮定をDに強制させているものになります。図で表すとわかりやすいかと思います。Discriminatorを$D(\cdot)$、CutMixの処理を$T(\cdot)$とすると、CutMix-CRは下図のようなアニメーションで表されます。
A U-Net Based Discriminator for Generative Adversarial Networks, CVPR 2020 (10 min overview)を元に作成
この図で、青色の経路(Discriminatorを通したのちCutMix)と緑色の経路(CutMixを適用したのちDiscriminator)とでは結局同じものが出来上がるはず、ということです。図中にも書いてありますが、$D(T(X)) = T(D(X))$が成り立つようにするということですね。
$D(T(X)) = T(D(X))$を成り立たせればいいわけですから、Discriminatorに$||D(T(X))-D(T(X))||^2$を加えてあげるだけでCutMix-CRの完成です。論文中では以下の式で最終的なDの損失としています。
ちなみに、CutMix画像に対するデコーダーの出力結果は下図のようになっています。マスクの白い部分に真画像、黒い部分に偽画像が貼られています。この結果を見るとしっかりと真のところが明るく、偽のところが暗くなっていることがわかります。
2. U-NetGANの実験結果
2.1 実験条件
-
データセット
- FFHQ: Unconditional, 256x256, 顔
- Celeb-A: Unconditional, 128x128, 顔
- COCO-Animals: Conditional, 128x128, 動物
ここでCOCO-Animals(サイズ:128x128)は、CIFAR-10(サイズ:32x32)の高画質バージョンという位置付けで提案されています。 鳥、猫、犬、馬、牛、羊、キリン、シマウマ、象、猿の10種類の動物が38,000枚ほど含まれています。
-
アーキテクチャ
画像は本論文より引用しております。-
BigGAN: [Brock, A.(ICLR'19)]で提案されたモデル。本論文では以下のアーキテクチャとなっています。(128x128と256x256それぞれ用意されています。)
-
U-NetGAN: BigGANをベースとして、Discriminatorにデコーダーをつけたモデルとなっています。GeneratorはBigGANをそのまま用いているので、Discriminatorだけ示すと下図。
-
-
評価方法:Inception Score(IS)およびFréchet Inception Distance(FID)
実験は大きく4つに区別することができ、以下の順番で順に触れていきます。
- BigGANとの比較
- SoTAとの比較
- アブレーションスタディ
- エンコーダーとデコーダーによる予測結果の一致性
2.2 BigGANとの比較
このテーブルでもうわかりますが、U-NetGANがBigGANを大きく上回っていることがわかりますね。FFHQおよびCOCO-Animalsいずれにおいても見事にBigGANを改良することができています。CelebAにおいてもFIDに対しての性能向上が報告されています。
また、5回の学習過程をFIDカーブで比較してみると下図になります。高い性能もさることながらU-NetGANの分散が小さいことがわかります。U-NetGANがより安定した学習を実現できているということです。
続いて、学習過程をGとDの損失で見てみると下図のようになります。
左図がBigGAN、右図がU-NetGANです。BigGANのDiscriminatorは早い段階で一気に損失が小さくなっていますが、U-NetGANではDの損失はゆるやかに落ちています。Dの損失が一気に損失が落ちているBigGANの方が良さそうですが、実はそれが逆に悪いことだったりします。GANを学習させているとよく「D強すぎる問題」というのにぶち当たります。Dの識別能力がGの生成能力を完全に凌駕してしまい、Gにとって有益なフィードバックが与えられなくなってしまうのです。そのため、Dの識別能力が緩やかに上がっていくようなU-NetGANはGANによる画像生成に向いているということですね。
それではU-NetGANによる生成結果も見てみましょう。まずはCOCO-Animalsから。
リアルですね。馬には人すら乗っています。続いてFFHQです。
かなりリアルです。各行で両端の画像の内挿を行っているのですが、内挿がとても滑らかであることもわかります。ちなみにFFHQデータセットでの生成を行うと、上図中左下のようなピンク髭おじさんが一定数出現するそうです。これは髪がピンク色の人がFFHQに一定数含まれているためだそうです。
2.3 SoTAとの比較
上表はCelebAによる比較ですが、U-NetGANがFIDを2.95まで下げていることが確認できます。論文中では「StyleGAN系に対してもU-NetのDiscriminatorを使ってみるのも面白いだろう」という言及がされています。
2.4 アブレーションスタディ
アブレーションスタディとして、「U-Net Discriminator」、「CutMix」「CutMix-CR」のパーツをそれぞれ加えて行った時にどれだけ結果(FID)に貢献しているかを見てみます。データセットはCOCO-AnimalsとFFHQを用いています。上表から言えるのは次の3つのことがわかります。
- U-Net Dで性能向上
- CutMixは単にDAとして用いても性能向上
- CutMix-CRはとても強い
U-Net DによるゲインはCOCO-Animals/FFHQに対して0.69/1.56ですが、CutMix-CR(上表のCutMixとConsistency Regularizationのゲインの合算)は1.99/3.23のゲインを獲得しているように、CutMix-CRは驚異的な強さを持っていることがわかります。
2.5 エンコーダーとデコーダーによる予測結果の一致性
エンコーダーが真と判断した画像に対してデコーダーも真と判断しているのでしょうか。それともエンコーダーとデコーダーの判断には乖離があるのでしょうか。上図はその結果を示しています。横軸がエンコーダーの予測結果で、縦軸がデコーダーの予測結果です。GANでは真であれば1、偽であれば0というのが慣習となっているので、ここでの値は「真である確率」と考えられます。デコーダーはピクセル全体の平均値をプロットしています。青点それぞれが画像1枚1枚に対応しています。
プロットを見るとエンコーダーとデコーダーの予測結果は一致しているというよりも割とばらつきがあることがわかります。論文著者はこのことを、エンコーダーとデコーダーがそれぞれ異なるフィードバックをしてくれている、と捉えています。上図においてオレンジ色のデータは、「デコーダーは真と言っているがエンコーダーは偽と判断している」ことを意味しています。つまり、「局所的には真っぽいが、大域的に見ると偽である」ということです。実際に画像では、細かく見るとキリンのようなパターンがあり結構本物に見えますが、大域的に見るとこんな象はいないのであきらかに偽物です。このようにU-NetGANはそれぞれがそれぞれの重要な仕事をしていることがわかります。
3. まとめと所感
U-NetGANは、DiscriminatorにU-Netを用いることでBigGANの性能を改善させました。損失関数もシンプルで私にとっては理解しやすくてありがたいです。ただ、U-NetGANの凄さは、U-Netによる大域的かつ局所的なフィードバックもあるとは思いますが、2.4のアブレーションスタディからもわかるように、CutMix-CRを適用できるというところにその凄さがあるように思えます。GANにCRを適用させるというCR-GAN[Zhang, H.(ICLR'20)]やそれを改善した[Zhao, Z.(2020)]など、CRにも目が離せません。あいにく公式の実装はありませんが、本記事でもアーキテクチャを載せたのでそちらを参考に実装してみてはいかがでしょうか!
Twitterで人工知能のことや他媒体で書いている記事などを紹介していますのでぜひフォロー@omiita_atiimoしてください!
こちらもどうぞ:
4. 参考
-
"A U-Net Based Discriminator for Generative Adversarial Networks", Schönfeld, E., Schiele, B.,Khoreva, A., (CVPR'20)
原論文 -
A U-Net Based Discriminator for Generative Adversarial Networks, CVPR 2020 (10 min overview)
論文著者によるCVPR2020での発表動画 -
【論文メモ】 A U-Net Based Discriminator for Generative Adversarial Networks
日本人の方による本論文のまとめがnoteで書かれています。 -
A U-Net Based Discriminator for Generative Adversarial Networks #17
日本人の方による本論文のまとめが書かれています。