はじめに
GANは気になっていたのですが、勉強できていなかったので今回勉強したことをまとめていきたいと思います。
今回の記事は実装はありません。概念的なところにフォーカスしています。
以下に内容のリストを書いておきます。
1. GANとは
2. GANを勉強する上で必要な数学的知識
3. GANの評価について
4. 今回紹介するGANの一覧
1. GANとは?
GAN(Generative Adversarial Network)とは、生成器(Generator)と識別器(Discriminator)の2つからなります。
生成器(G)は、偽物の画像を生成します。この時できる限り本物に近い画像を生成するように学習させます。識別器(D)は、本物の画像か生成器(G)で生成された画像か分類します。つまり、
生成器(G) : Dを騙せるように画像を生成する
識別器(D) : 本物の画像かGが生成した画像か判別する
GANの学習
GANでは、基本的にGとDを学習します。
①識別器(D)の学習について
ノイズzを生成器Gに入れ偽物の画像を生成(G(z))した画像と、実際の画像xを識別器Dに判定させます。このとき識別器Dは、本物の画像なら1を、偽物なら0を出力するように学習します。その誤差を利用して識別器Dを学習させていきます。
モデル | 目的 |
---|---|
識別器D | できるだけ正確な結果を返す。本物のサンプルについてはできるだけを正解ラベルである1を返す。偽のサンプルについてはできる限り0に近い値を返す。 |
②生成器(G)の役割
次に生成器の学習です。生成器Gは、ノイズzを生成器Gに入れ生成された偽物の画像を識別器Dに判定させます。このときの誤差を逆伝搬して生成器Gをupdateしていきます。この際に、識別器Dが学習されるのは良くないので学習しないように重みは固定します。
モデル | 目的 |
---|---|
生成器G | 識別器が間違うような画像を生成する。D(G(x))をできるだけ1に近づける |
③目的関数
GANの目的関数は以下の式です。生成器は目的関数を最小化、識別器は目的関数を最大化するのを目指します。
\min_{G} \max_{D} V(D, G) = E_{x \sim p_{data(z)}}[logD(x)]+E_{x \sim p(z)}[log(1-D(G(z))]
G : generator 生成器
D : discriminator 識別器
z : noise
第一項
E_{x \sim p_{data(z)}}[logD(x)]
第二項
E_{x \sim p(z)}[log(1-D(G(z))]
識別器D
できるかぎる上記の目的関数を最大化する。第一項目には、$D(x)=1$ を出力するように第二項目は$ D(G(z))=0 $となれば、$log(1-D(G(z)))$ は最大化できる。
生成器G
生成器は上記の目的関数を最小化するように学習していく。第一項目は、識別器に関係する項なのでスルーします。第二項目は、$D(G(z)) = 1$ となるのが目標です。つまり、生成器が生成した画像を識別器が間違えて判定してくれれば良いことになります。
#2. GANを勉強する上で必要な数学的知識
情報量
以下でKL divergence と JS divergenceの説明をするがその前に、情報量について少し考えてみます。
まず、情報量とは、その情報がどれだけ価値があるかということです。小さいほど価値があります。
ここで、予想しにくさを$p$とおきます。つまり、$ p = 0.1$は予想するのが難しく、 $ p = 0.9$は予想しやすいみたいな感じです。ここでやりたいことは小さい値の時は価値があるから重要で、大きい値の時は価値が小さいようにすることです。それを表現することは簡単で、逆数を取れば良いのです$\frac{1}{p} $。$ \frac{1}{0.1} = 10, \frac{1}{0.9} = 1.11..$。良い感じで表現できましたが今後の計算のし易さを考慮して対数をとります。$log(\frac{1}{p(x)}) = -log(p(x))$
情報量 | 数式 | 説明 |
---|---|---|
①自己情報量 | $-log(p(x))$ | 単一事象のみ |
②平均情報量 | $-\sum_{x} P(x)log(P(x))^*$ | 標本空間全体の情報量 |
*各事象の予想しにくさの加重平均
また、エントロピーという言葉がよく使われますが、予想のしにくさを表しています。例えば、表と裏があるコインの表が出る確率は0.5だとすると実際にコイントスしたときにどちらがでるか予想しにくいですよね。このわからない(予想しにくい)時にエントロピーは大きくなります。
本題に入ります。
Kullback-Leibler(KL) divergence
- 交差エントロピーとエントロピーの差
- 確率分布の距離を測ることができる(正確には距離でないが)
- 確率分布$p(x), q(x)$ が同じならKLはゼロ、違うほど大きな値になる
D_{KL}(p \mid\mid q) = -\int p(x)\log \frac{q(x)}{p(x)}dx
Jensen-Shannon(JS) divergence
- KL divergenceが対称性がなかったので、距離として扱えるように改善
- 0に近いほど2つの確率密度関数は一致している
M=\frac{1}{2}(p+q)\\
D_{JS}(p \mid\mid q) = \dfrac{1}{2}D_{KL}( p \mid\mid M ) + \dfrac{1}{2}D_{KL}( q \mid\mid M )
#3. GANの評価について
生成した画像をどのように評価するのか?私自身も疑問を抱く点でした。
有名な指標を以下で2つ上げます。
1. Inception Score (IS)
ISでは、生成された画像の良さを以下の2つで測ります。
① 識別器が識別しやすい
② 物体(オブジェクト)のクラスのバリエーションが豊富
以下で式を示します。
exp(\frac{1}{N}\sum_{i}D_{KL}(p(y \mid x_{i}) \mid\mid p(y)))
$x_{i}$を画像、yをラベルとするとしています。
この計算では、(1)本物と偽物を分布との間でKL divergenceを計算。(2)この冪乗をとる
ということをしています。
2つの分布間の距離が大きいほど良い画像が生成できている。ISは大きいほど性能が良い。
2. Frechet Inception Distance (FID)
- GANの生成画像の問題として、生成されるサンプルのバリエーションが足りないというものがある
- これに対し、FIDが提案された
- FIDはISをよりノイズに強くし、クラス内でサンプルが反映されなくなっていく現象を検出できるように改良
- 実画像と生成画像の分布間の距離をはかる → Frechet距離
- 特徴表現、つまり特徴マップや層を比較する
- 本物と生成物に埋め込まれた分布の平均値や分散、共分散といった特徴量の距離をはかる
特徴ベクトルの平均と共分散行列が実画像についてそれぞれ、$m_{w}$、$C_{w}$、生成画像について $m$, $c$と得られているとすると
\|\boldsymbol{m}-\boldsymbol{m_w}\|^2_2 + {\rm Tr}\bigl(\boldsymbol{C}+\boldsymbol{C_w}-2(\boldsymbol{C}\boldsymbol{C_w})^{1/2}\bigr)
- 画像は高次元に埋め込まれているので簡単に分布の距離を計測できない
- モデルを使って画像を低次元に埋め込んでその空間で分布の距離をはかる
4. 今回紹介するGANの一覧
- GAN(2014.6) (上記で説明済み)
- cGAN(2014.11)
- DCGAN(2015.11)
- Unrolled GAN(2016.7)
- pix2pix(2016.11)
- StackGAN(2016.12)
- WGAN(2017.1)
- PGGAN(2017.10)
- SNGAN(2018.2)
- SAGAN(2018.5)
- CycleGAN (2018.11)
- StyleGAN(2018.12)
今後随時更新、改善していく予定です。
CGAN(2014.11)
- 生成器と識別器にいくつかの追加情報を与えて、条件づけができるように訓練を行う敵対性ネットワーク
- 生成器は訓練データ内の各ラベルに応じたリアルなサンプルを学習
- 識別器は偽のサンプルとラベルの組、本物のサンプルとラベル組を見分ける
- 識別器を騙すためには、CGANの生成器はリアルな画像を生成するだけではダメでラベルにもマッチしなくてはいけない
- 生成器の訓練が終われば自分が生成したい画像を生成することができる
(1)CGAN生成器
G(z,y) = x^* \mid y
- yを条件とした$x*$を生成
(2)CGAN識別器
- ラベル付きの本物の組$(x,y)$と、偽物の画像にそれを生成するためのラベルが付いた組$(x^* | y, y)$を受ける
- 本物のサンプルとラベル組では、識別器は本物のサンプルをどう認識するかということと、マッチする組み合わせをどう認識するかの両方を学習
- 識別器Dを騙すには、CGANの生成器Gはリアルな画像を生成➕ラベルにもマッチが必要
\min_{G} \max_{D} V(D, G) = E_{x \sim p_{data(z)}}[logD(x|y)]+E_{x \sim p(x)}[log(1-D(G(z|y))]
DCGAN(2015.11)
- DCGANはGANの生成器と識別器として畳み込みニューラルネットワークを用いたもの
- バッチ正規化の利用
- LeakyReLUの利用。GANでの学習で勾配消失にならないようにLeakyReLUを使用。ReLUでは、逆伝搬での重み更新で、0以下の誤差が上層に伝搬できない。LeakyReLUでは、0以下でも出力が0にならない。その分LeakyReLUのハイパラが増える。
- Pooling層がない
####技術やキモは?
-
バッチ正規化
-
勾配消失を防ぐ
-
学習を早く進行できる
\hat{x} \leftarrow \frac {x - \mu_B}{\sqrt{\sigma^2+ \epsilon}}\\
y \leftarrow γ \hat{x} + β
γとβは訓練させるパラメータ
Unrolled GAN(2016.7)
- GANやDCGANでは学習の際に識別器が早く学習してしまう
- UnrolledGANでは、生成器の学習の際にKステップ学習した後の識別器を利用して勾配を計算
- 生成器に先取り学習させて、識別器とのバランスをとるイメージ
pix2pix(2016.11)
- 教師ペアが必要なドメイン間の変換
- CGANの一種で条件として画像を与える
- 生成器の出力を正解ラベルとなるようにL1Lossを採用
- L1Lossは、ピクセル単位で正解画像に近いような画像を生成
- U-Netを利用して画像変換を行う。画像変換では、変換前と変換後で画像は一定レベルで同じ要素を持つのでU-Netのencoderで圧縮した情報だけでなく、skip connectionの情報も加え合わせることでより具体的な情報をもとにデコードした。
- 識別器による判定を画像全体を見て一度だけでなく、小さい領域に分割して判定を行う(PatchGAN)。
StackGAN(2016.12)
WGAN(2017.1)
- GANの訓練誤差としてJS divergencceが使われるが低次元多様体でサポートされたデータでは機能しない
- EM距離を利用
JS divergenceの問題点
- 真の分布$P_{data}$ とモデルの分布$p_{g}$の分布が重ならない➡︎勾配消失
- モード崩壊が起こる
- Loss関数の収束性に問題があり、学習が不安定
技術やキモは?
- Eath-Mover距離(Wassertein距離)
WGAN-gp
- Wassertein距離によって本物のデータと生成データの確率分布の差を測り、最小化していく
- 識別器Dの役割を「real or fakeの判定」から「real or fake のwassertein距離を推定する」に変更
- 損失関数 = (Wassertein距離) + (Gradient Penelty)
- クリッピングという方法でリプシッツ連続性を保証
PGGAN(2017.10)
- 従来のGANとは異なり、同時に学習可
- 識別器と生成器ネットワークを訓練中にプログレッシブに成長させる
- 訓練を安定化させ、より多様で高品質、高解像度の画像を出力
- 徐々に解像度を上げていくことで、latent vectorsからのマッピングを発見するというゴールをいきなり達成するよりは簡単なタスク
技術やキモは?
- ミニバッチ標準化
- 学習率の平坦化
- ピクセル毎の特徴正規化
SNGAN(2018.2)
- 大事なのは、リプシッツ連続であって別にWassertein距離でなくても良くない?
- 学習の安定化。BNをSpectral Normに変更。リプシッツ制約を満たしGANの安定性が向上。
- 損失関数はHinge Loss
- データ数が多くても少なくてもモード崩壊しない
技術やキモは?
- Spectral Normalization
- 係数行列の特異値分解を使ったNormalization
SAGAN(2018.5)
- Self-attentionを用いて画像の大域的な依存関係を抽出。ある要素に注目した時に他に注目するところはどこか。
技術やキモは?
-
Attentionの導入
-
TTUR
-
GとDで異なる学習率を適応
CycleGAN (2018.11)
- pix2pixの教師なしバージョン
- CGANのアイデアを拡張して画像全体を条件として扱う
技術やキモは?
-
サイクル一貫性損失
-
写像Gと逆写像Fの矛盾を防ぐために導入
-
敵対性損失
-
生成画像がリアルになることを保証。
-
サイクル内の最初のDは特に重要
-
同一性損失
-
正規化のための項を導入して生成画像の色合いが元画像と一貫したものになるようにする。
-
CyclyGANが画像全体の色構造を保存するようにした。
-
画像に対する不要な変更にペナルティを課す
StyleGAN(2018.12)
技術やキモは?
-
Progressive Growing
-
高解像度の画像を生成する手法
-
低解像度の画像の生成から始め、徐々に高解像用のG,Dを追加していく
-
解像度をあげるネットワークを追加しても低解像度の画像を生成するGと判別するDはパラメータを固定せず学習させ続ける
-
AdaIN
-
スタイル変換用の正規化手法
-
スタイルとコンテンツ画像の統計量のみで正規化を行い、学習パラメータを使用しない
StyleGAN2
- AdaINの代わりにCNNの重みを正規化することでdropletの除去
- Progressive Growingの代わりにskip connection を持った階層的なGを用いることで、低解像度から順々に学習させる
その他のGAN
更新できたらしていきます。
- pix2pixHD
- BigGAN
- COCO-GAN
- AutoGAN
GAN MEMO
- GANを解析するときは、GとDのアーキテクチャについてまず考える
- GANの安定性 = リプシッツ連続
- original GANは間接的には、JS divergenceを最小化している