本書は画像生成の各種モデルのベースとなっているGANについて説明します。(CNNの基礎を理解している前提で記載しています。まだ理解していない方は別冊のCNNの基礎を先に読んでください)
1.概要
GAN は、Generator(以降、生成器)と Discriminator(以降、識別器)という2つの機構から構成される生成モデルです。モデルアーキテクチャは図1のようになります。
生成器は乱数から偽物画像を生成し、識別器は生成された偽物画像と本物画像を入力し本物か偽物かを識別させ、識別器が判別できなくなるまでお互いを交互に学習させることで、最終的には生成器が本物に近い画像を生成できるようにするものです。
下図2は、モデルアーキテクチャを見やすくするためにモデルをサマリしたものです。
まず、生成器は潜在変数Zを入力値として受け取り画像を生成して出力します。通常Zは100次元ベクトル(各要素は0~1までの値を取る変数)の変数であり、一様分布や正規分布からランダムサンプリングします(=画像を生成するための種になる)。
識別器は画像データを入力値として受け取り、そのデータが本物のデータなのか、生成器から生成された偽物のデータかを出力値「0.0(偽物)~1.0(本物))として返します。
次に各機器の学習について説明します。
2.識別器(Discrimator)の学習
識別器は画像(本物または偽物)を入力データとして、予測結果(本物か偽物)を出力するので教師データはその画像が「本物(1)または偽物(1)」かになります。すなわち以下のラベルを用いて学習します。
・本物画像入力 -> 教師ラベル=1
・偽物画像入力 -> 教師ラベル=0
3.生成器(Generator)の学習
生成器の学習ではあくまで生成器のみ学習させるので、識別器の重みを固定させて学習させることにポイントがあります(Kerasフレームワークの場合はtrainable=Falseを動的に指定することで実装が可能)。
もう一つ重要なポイントは、生成された画像は当然偽物なので本来ならば教師ラベルは0(偽物)ですが、識別器に対して本物だとだますように学習させるので「教師ラベル=1(本物)」として学習させます。
従って生成器の学習イメージは下記のサマリになります。
4.損失関数
5.GANの学習での課題について
GANの学習については、以下のような課題があります。
(1) 損失関数の収束性
GANは、識別器と生成器の2人プレイヤーゼロサムゲームになっていますが、このゲームの最適解は、ナッシュ均衡点になります。また、2人プレイヤーゼロサムゲームのナッシュ均衡点は、鞍点になります。
式1の形状が、上図のような凸関数であれば、SGDによってナッシュ均衡点(=鞍点)に収束されることが保証されますが、非凸関数の場合は保されません。また、GANではニューラルネットワークにより損失関数を表現するので、関数の形状は非凸関数になります。
このような非凸関数に対して、SGDで最適化(=鞍点の探査)を行っていくとナッシュ均衡点(=鞍点)に行かず、振動する可能性があります。
(2) モード崩壊
学習が不十分な識別器に対して生成器を最適化した場合や、生成器への入力ノイズZの潜在変数としての次元が足りていない場合などにおいて、生成器による生成画像がある特定の画像に集中してしまい、学習用データが本来持っている多様な種類の画像を生成できなくなってしまいます。これをモード崩壊といいます。
(3) 勾配損失問題
学習が十分でない識別器に対して、生成器を最適化するとモード崩壊が発生します。これを防ぐために、ある生成器Gの状態に対しての識別器を完全に学習すると、今度は勾配損失問題が発生してしまいます。
GANでは、モード崩壊と勾配損失問題が互いに反して発生しまうというジレンマを抱えています。
(4) 生成画像の品質判断
生成画像の品質の評価は、損失関数の結果からでは判断が難しい課題があります。
6.おわりに
以上がGANの基礎になります。GANには様々な拡張モデルがあります。今後、高画質な画像生成を行うStyleGAN・StyleGAN2について、CPU上で動かす実証をしつつ執筆予定です。