Step0.はじめに
画像生成に関して、検討した結果があるので記載してみようと思います
検討したアーキテクチャは下記の3つです
- VAE
- VAE-GAN
-
U-Net+VAE-GAN
なので、この3つでの検討結果が参考になれば幸いです
学習方針&データセット
学習方針
- 使用GPUはT4使用で1hr以内と設定
データセット
- CIFAR-10(車🚙,飛行機✈)
学習評価指標
作成した画像の品質を定量的に判断するため,FID(Frechet Inception Distance)を使用することとした
- 画像を多変量ガウス分布と仮定し、その分布間のFrechet 距離を算出し、画像の類似性から品質を評価する
- 本検討ではreal100枚、fake100枚の画像でFIDを算出し、定量評価を行う
- FIDは以下の通り
$$
\text{FID} = |\mu_r - \mu_g|^2_2 + \text{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{\frac{1}{2}}) \tag1
$$
ここで、
- 実データの特徴分布: $X \sim \mathcal{N}(\mu_r, \Sigma_r)$
- 生成データの特徴分布: $Y \sim \mathcal{N}(\mu_g, \Sigma_g)$
- $\mu_r, \mu_g$ はそれぞれの分布の平均ベクトル
- $\Sigma_r, \Sigma_g$ はそれぞれの分布の共分散行列
- $|\cdot|_2$ は2乗ノルム
- $\text{Tr}(\cdot)$ は行列のトレース(対角成分の和)
なお、今回は下記のpytorch-fidを使用して算出した
!pip install pytorch-fid
Step1.VAEによる検討結果
アーキテクチャ
損失関数
損失関数は特に工夫なく、MSEとKLダイバージェンスで実装
再構成損失(Mean Squared Error, MSE)
$$
\text{MSE} = \sum_i | \hat{x}_i - x_i |^2 \tag{2}
$$
- $\hat{x}_i$: 再構成されたデータ
- $x_i$: 元のデータ
KLダイバージェンス
$$
\text{KLD} = -\frac{1}{2} \sum_j \left( 1 + \log(\sigma_j^2) - \mu_j^2 - \sigma_j^2 \right)\tag{3}
$$
- $\mu_j$: 潜在変数の平均
- $\sigma_j^2 = \exp(\text{logvar})$: 潜在変数の分散
総合損失
$$
\mathcal{L} = \text{MSE} + \text{KLD} \tag{4}
$$
- $\mathcal{L}$: モデル全体の損失
RealとVAEで作成した画像の比較
FIDによる評価
FIDスコアはFID:249.76でした
さらに明瞭かつ多様な画像を生成するにはGANの考えが必要かと思い・・・
Step2でVAE-GANを検討することにした
Step2.VAE-GAN
アーキテクチャ
下図の通り.
一般的なGANの場合、Generator(生成器)は通常ノイズ$z$を与えて、画像を生成するアーキテクチャだが、
今回はVAEのDecoderを生成器として構築することとした.
損失関数(GeneratorがDecoderを使用する場合)
GeneratorがDecoderを使用することを考慮した損失関数は以下の通り(式(5))
$$
L_{\text{total}} = \lambda_{\text{VAE}} \cdot \left( \text{MSE}(x, \hat{x}) + \text{KLD}(\mu, \log \sigma^2) \right) + \lambda_{\text{GAN}} \cdot \left( -\mathbb{E}_{z \sim p_z} [\log D(\text{Decoder}(z))] \right) \tag{5}
$$
なお、今回は$\lambda_{\text{VAE}}=1.0$,$\lambda_{\text{GAN}}=0.1$で実施
ここで、以下を明確にします:
- $\hat{x} = \text{Decoder}(\text{Encoder}(x))$: VAEの再構成データ
- $G(z) = \text{Decoder}(z)$: Generator(=Decoder)による生成データ
- $D(G(z)) = D(\text{Decoder}(z))$: Discriminatorが生成データを識別するスコア
RealとVAEで作成した画像の比較
んん、やはりVAEのDecoderだから結局ぼけたまま・・・
FIDによる評価
FIDスコアはなんと・・・FID:254.7でVAE単体よりも悪化
VAE単独とあまり変わらない結果・・・畳み込みの結果が悪いのかと思い、U-Netのようなスキップ接続が解決策になるのではと考え、Step3に続く
Step.3 U-Net+VAE-GAN
従来のVAE-GANでは、GANの考え方を取り込んだとしても、Decoderが十分に機能していないのか生成される結果が明瞭になりませんでした・・・
この課題を解決するために、詳細な特徴を保持しながら高品質な生成タスクを実現できる可能性があるU-NetのEncoderとDecoderを活用したVAE‐GANで学習することにした
特に、U-Netのスキップ接続を活用することで、低解像度の潜在表現と高解像度の局所的な特徴を効果的に統合し、より鮮明で意味のある生成結果が期待できると考え、検討してみました
といことでアーキテクチャは下図の通り
損失関数
損失関数は式(5)と全く同じなので割愛
なお、今回は$\lambda_{\text{VAE}}=1.0$,$\lambda_{\text{GAN}}=0.1$で実施
損失関数の推移は下図の通り
RealとVAEで作成した画像の比較
これまでの検討では学習後VAEのDecoderに潜在変数$z$に乱数を与え、画像を生成して評価してきましたが今回はU-Netなので少し注意が必要です
U-Net Decoderに与える値
1. 潜在変数 ( z )
潜在変数 ( z ) は従来通りランダムノイズを使用
$$
z \sim \mathcal{N}(0, I)\tag6
$$
2. スキップ接続
スキップ接続 $ \text{skip}_1, \text{skip}_2, \text{skip}_3, \text{skip}_4 $ は以下の手順で算出
- 教師データ $ x_r $ をEncoderに通して得られるスキップ接続の出力を基準とする
- 各スキップ接続に乱数を加えることで変動を付与
$$
\text{skip}_i = \text{Encoder}_i(x_r)・(1+α・\epsilon_i)\tag7
$$
ここで,$\text{Encoder}_i(x_r)$はEncoderの出力によるスキップ接続の特徴マップ、$\epsilon_i \sim \mathcal{N}(0,I)$はスキップ接続に加える乱数、$\alpha$はノイズスケールの係数とする
3. Decoderへの入力
U-NetのDecoderに与える値は、潜在変数 ( z ) とスキップ接続 $ \text{skip}_1, \text{skip}_2, \text{skip}_3, \text{skip}_4 $ の組み合わせとなる
$$
\hat{x} = \text{Decoder}(z, \text{skip}_1, \text{skip}_2, \text{skip}_3, \text{skip}_4)\tag8
$$
上記を使って、やっと明瞭な画像が生成できるようになってきた・・・
FID
そしてFIDスコアは162.29と定量的にも改善が見て取れます
U-Net+VAE-GANの課題
生成する際に教師データを一度エンコーダに通して、スキップ接続を得なければならないところに課題があります.幾ら乱数を与えて多様な生成ができるとは言えども,スキップ接続に与える値を完全に乱数に置き換えられない点が課題と考えます.
まとめ
以上、検討の結果は表の通りです
U-Netのスキップ接続を活用することで、低解像度の潜在表現と高解像度の局所的な特徴を効果的に統合し、より鮮明で意味のある生成結果ができると考えます
アーキテクチャ | FID |
---|---|
VAE | 249.76 |
VAE-GAN | 254.7 |
U-Net+VAE-GAN | 162.29 |
コードについて後日Githubで公開します