概要
画像生成モデルにスタイル変換の考え方を持ち込んだStyleGANのバージョン2が出たので論文を読んでみました。
大幅なアーキテクチャの変更を行いつつ、細かい工夫を効果的に入れることで、前回の結果を超えるモデルを構築することに成功しています。また、本手法で生成した画像を簡単に見抜くことができる手法も同時に提案するなど、盛りだくさんの内容です。
以下のような提案がされています。
- さよならAdaIN:AdaINをWeightDemodulationに変更。これにより、特徴的なアーティファクトを除去できた。
- Lazy Regularization:正則化損失の適用を毎ステップ行わずに、数ステップごとに行うことで高速化。
- Path length regularization:潜在変数$\mathbf{w}$の変化によって生成画像がなめらかに変化することを強制する正則化を追加。
- さよならProgressive Growing:Progressive Growingな訓練プロセスをやめ、MSG-GANを元にしたGenerator側とDiscriminator側の構造を再考。
- Fake検出:本手法で生成した画像を検出するための枠組みを提案。
以上のような工夫によって、どこまで改善されたのかが以下の表です。
全体的な傾向としては、FIDスコアやPerceptual Path Lengthは小さくできていること、Precisionの劣化を抑えながら、Recallを上げている、というのが確認できます。
画像生成モデルにおけるPrecisionとRecallは、直感的にはPrecisionは見た目の良さ、Recallは見た目の多様さを表していると思えば良いでしょう。
最後の提案であるFake検出だけ浮いているようにも見えますが、潜在変数と生成画像との対応関係が明確なモデルができた、という観点から見ると、本論文で提案している内容の妥当性を検証する実験という捉え方もできそうです。
以下、それぞれについて説明を試みます。
さよならAdaIN
前回のStyleGANの大きな特徴は、「スタイル変換手法で発明されたAdaptive Instance Normalization(以下AdaIN)は、より広範な画像生成モデルにも使える!」ということだったわけですが、「AdaINよりももっと良い構造がある!」と、WeightDemodulationを使用することを提案しています。
- StyleGANで生じていたアーティファクトの原因を特定
- 原因を取り除くために、ネットワーク構造を再検討し、よりシンプルな構造に変更
- さらに、より直接的に目的を達成するためにWeightDemodulationを提案
という流れで説明します。
StyleGANのアーティファクト
公式の解説動画に、アーティファクトのわかりやすい例が示されています。
生成結果で生じているアーティファクトは、生成の各段階のActivationで生じている真っ白あるいは真っ黒なスポットに対応していることが、観察の結果わかったそうです。
このようなスポットがどのように生じるのかを、StyleGANの構造をよく見て確認します。
「A」はスタイル情報、「B」はランダムなノイズです。
灰色の箱として表現されているStyle blockに注目してみましょう。
- Step1. 上から入力された特徴量マップの各チャネルを、左から入力されたStyle情報Aによってスケールstdを掛け、バイアスmeanを足し、
- Step2. 重み$w$のConv 3x3を適用し
- Step3. バイアス$b$とノイズ$B$を足し
- Step4. 標準化する
という流れになっています。
Step2のConv 3x3の適用後に特定のチャネルの値が0に近い場合を想像してみましょう。そうすると、そのチャネルの値の主な変動要因は、Step3で追加されるバイアスとノイズになります。そのため、Step4の標準化が適用されてしまうと、バイアスとノイズの影響が拡大され、他のチャネルとの相関を壊してしまします。
これがアーティファクトの原因です。
アーティファクトの除去
アーティファクトの原因は、標準化の前にバイアスやノイズを加えてしまうと、その影響が標準化によって拡大されてしまうことでした。これを取り除くためには、標準化の「後」にバイアスやノイズを加えればいいじゃないか、というシンプルな解決策が思いつきます。
加えて、Step1のスタイル情報による各チャネルのスケールとStep4の正規化から、「mean」の役割、つまりバイアスを足したり引いたりという操作を取り払います。この点については、論文中では軽くしか触れられていませんが、特徴量マップの値がゼロを中心にして変動するという強い
前提をおいても問題ない、というふうに理解しました。
WeightDemodulation
ここまでの変更でアーティファクトの除去には成功するのですが、更により直接的な操作に置き換えます。そもそも、InstanceNormalizationの目的は、Conv層の出力の特定のチャネルのスケールが大きくなりやすいのを一定のスケールに正規化する、ということでした。ここでは、その目的をより直接的に達成するために、Conv層の重み$w$の方を正規化します。そうすることで、入力の特徴量マップが正規化されていさえすれば、実際にConv層を適用せずとも、出力の特徴量マップもある統計的性質を満たす範囲内に収まる、すなわち正規化が期待できるようになります。
Conv層の重み$w$の正規化は、以下のように行われます。
- $w_{i j k}^{\prime}=s_{i} \cdot w_{i j k}$
- $w_{i j k}^{\prime \prime}=w_{i j k}^{\prime} / \sqrt{\sum_{i, k} {w_{i j k}^{\prime}}^2 +\epsilon}$
$i, j, k$は、それぞれ、入力チャネル、出力チャネル、Conv層のカーネルを表します。
Conv層を通すことで生じるであろう出力チャネルのスケールの変動$\sigma_{j}=\sqrt{\sum_{i, k} {w_{i j k}^{\prime}}^2}$を求め、それを用いて正規化しておきます。
このようにすることで、Conv層の重みそのもので出力結果の統計的性質を保証するという事により、安定した出力結果が得られる、というわけです。
おまけ:WeightDemodulationの実装
Conv層の重み$w$の正規化ってどうやるの、と思うわけですが、付録にて、Group Convolutionを使えばよい、という解説がなされています。
参考までに、PyTorchでの実装するとしたら、だいたいこうなるなあ、というのを書いてみました。
import torch
import torch.nn as nn
import torch.nn.functional as F
# バッチサイズ、入力チャネル数、出力チャネル数、高さ、幅を適当に定めます。
N = 8
CI = 32
CO = 64
H = 14
W = 14
# Conv層の重みパラメータw
weight = nn.Parameter(torch.rand(CO, CI, 3, 3))
x = torch.rand(N, CI, H, W) # 入力特徴量マップ
y = torch.rand(N, 1, CI, 1, 1) # スタイル情報
weight = weight.unsqueeze(0) # [1, CO, CI, 3, 3]
weight_prime = y * weight # [N, CO, CI, 3, 3]
weight_prime2 = weight_prime / torch.sqrt((weight_prime**2).sum(dim=(2, 3, 4), keepdim=True) + 1e-9) # [N, CO, CI, 3, 3]
weight_prime2 = weight_prime2.reshape(N * CO, CI, 3, 3)
x = x.reshape(1, N * CI, H, W) # バッチサイズをチャネルのグループ数とみなせるように変形
out = F.conv2d(x, weight_prime2, stride=1, bias=None, padding=1, groups=N) # [1, N * CO, H, W]
out = out.reshape(N, CO, H, W)
Lazy Regularization
StyleGANの訓練にはいくつかの正則化のための損失が適用されていました。しかし、そのうちのいくつかは計算時間がかかるため、毎ステップ(各ミニバッチごとに)実行していると計算時間がかかります。そのため、$k$ステップ毎に適用します。これによって計算速度が向上されます。
付録をみると、具体的には、$R_1$ regularizationと、次の項で説明するPath length regularizationが対象のようです。Genetatorの方では$k=8$、Discriminatorの方では$k=16$とされています。
Path length regularization
前のStyleGANの論文では、Perceptual Path Lengthと呼ばれる評価指標が提案されました。これは、潜在変数$\mathbf{w}$の微小変化が出力画像に与える影響を評価する方法でした。
今回の論文では、$\mathbf{w}$の微小変化が与える影響$\mathbf{J}_{\mathbf{w}}=\partial g(\mathbf{w}) / \partial \mathbf{w}$がどんな$\mathbf{w}$でもある程度一定になるように、以下のような新しい損失を提案しています。
$\mathbb{E}{\mathbf{w}, \mathbf{y} \sim \mathcal{N}(0, \mathbf{I})}\left(\left|\mathbf{J}{\mathbf{w}}^{T} \mathbf{y}\right|_{2}-a\right)^{2}$
なお、$\mathbf{y}$は、標準正規分布$\mathcal{N}(0, \mathbf{I})$からサンプルされるマスクです。
また、$a$は、$\left|\mathbf{J}{\mathbf{w}}^{T} \mathbf{y}\right|{2}$の指数移動平均で、訓練中に値は随時更新されていきます。
さよならProgressive Growing
最初の提案点であるWeightDemodulationは、AdaINに代わるミクロな構造の変更でしたが、ここからはマクロな構造について見ていきます。
以下のような流れで説明します。
- Progressive GrowingなGANの訓練がもたらす悪影響を指摘し
- それを克服できる別のアーキテクチャを適用する
- そのアーキテクチャの分析と改善
Progressive Growingの悪影響
Progressive GrowingなGANの訓練、つまり、低解像度のGeneratorとDiscriminatorの訓練から始まり、徐々に高解像度にしていく手法は、画像の生成モデルではメジャーなやり方ですが、そのStyleGANにおける悪影響を指摘しています。
動画を確認するとわかりやすいです。
https://youtu.be/c-NJtV9Jvp0?t=99
左の女性の画像に引いてある青い線は、画像の中央を表しています。そして、歯の境界が必ずそこになるように画像が生成されていることが、動画から確認できます
また、右の女性の青い円は、女性の右目の瞳の位置を示しています。動画で確認すると、生成画像のポーズの変化に対して、右目の瞳の位置が急峻に変化するタイミングがあります。
これらの現象は、低解像度の生成結果が最終的な生成結果に対して与えてしまっている悪影響であると、論文では主張されています。左の女性は、歯の境界と低解像度の生成結果のピクセル境界が一致しており、右の歯女性は、低解像度の生成結果が切り替わるタイミングで右目の瞳の位置が急峻に変化する、というわけです。
注:Progressive Growingな訓練過程では、中間層に過度に高い周波数を強制するため、シフト不変性を損なう、という主張が論文ではなされていますが、正直意味がよくわかっていません。
MSG-GANベースのアーキテクチャの採用
以上のようなProgressive Growingの課題を克服するために、別の構造として、MSG-GANを採用しています。
以下はMSG-GANの論文から抜粋した図です。
MSG-GANでは、複数の解像度の画像を出力し、それをDiscriminatorに入れるというのは今までと一緒ですが、Discriminatorが解像度ごとにではなく、1つのみ用意されており、各解像度の生成画像はDiscriminatorの各段階で適宜挿入されるという形になっています。
今回の手法では、MSG-GANをベースにした構造(下図の(a))の他に、別の可能性も模索しています。
まず、(b)のスキップ構造を持ったアーキテクチャでは、Generatorの場合は、各解像度の出力結果をUpscaleしながら足し合わせることで出力結果を生成します。また、Discriminatorの場合は、その出力結果をダウンスケールしながら中間のチャネルに挿入します。
次に、(C)のResidual構造をもったアーキテクチャでは、Generatorの場合は、各解像度の生成結果をアップサンプルしたものと、残差構造で詳細を埋めていく構造になっています。Discriminatorの場合は、ダウンスケールした画像と残差構造で処理した画像を足し合わせるということを繰り返しています。
以上、3種類のGenerator/Discriminatorの組わせについて、以下のように9種類の組み合わせを評価ししています。
最終的には、Generatorはskip構造(b)を、Discriminatorはresidual構造(c)を採用しています。
Generatorにおける各解像度の貢献
GeneratorとしてSkip構造を採用すると、最終的な生成画像は、各段階の出力を足し合わせたものになります。そこで、各段階の出力のスケール(標準誤差)を算出し、全解像度での合計が100%となるように正規化すると、各解像度での出力が最終出力に与える貢献度のようなものを算出することができます。
下図の左側がその可視化で、横軸が訓練のエポック数、縦軸が各解像度の貢献の積み上げとなっています。
これを見ると、訓練が進むと徐々に低解像度が担う役割がなくなり、高解像度の出力が担う役割が増えていく、ということがわかります。このような可視化から、低解像度の出力結果が最終的な出力に与えるProgressive Growingの悪影響を取り除くことができた、と言うことができます。
と、ここで終わってもいいはずなのですが、著者らはこの時点での出力結果を確認し、リアルな画像に存在するディテールが欠けて、のっぺりしてしまっていることを発見しました。Generatorの最後の解像度のステージでのチャネル数を2倍することで、この問題を解決しています。このバージョンの各解像度の貢献度を確認したのが上図の右で、最終解像度である1024x1024の貢献度が大きくなっています。
Fake検出
今回提案した手法を使って生成した画像と、実際の画像とは、簡単に見分けることができるということを示しています。
手順としては以下のとおりです。
- 入力された画像$\boldsymbol{x}$から、それに対応する潜在変数($w, z$)を推定します($\tilde{g}^{-1} (\boldsymbol{x})$)。
- このように得られた潜在変数を用いて再度画像を生成します($g\left( \tilde{g}^{-1} (\boldsymbol{x}) \right)$)。
- 元画像と再生成された画像とのLPIPS距離を算出します。この距離が小さいとFake画像、大きいとReal画像というふうに見分けることができます。
$\tilde{g}^{-1}$をどうやって求めるのかは、付録で詳細に解説がなされていますが、この記事での紹介はやめておきます。
上段は、旧StyleGANのFake/Real画像のLPIPS距離の分布の違い、下は今回のStyleGAN2のFake/Real画像のLPIPS距離の分布の違いを表しています。
今回のモデルでは、Fake画像のほうがLPIPS距離が小さい位置に分布しています。これは、今回のモデルによって生成されたFake画像から、潜在変数を求めやすい、ということを意味しています。そのため、再生成したときの画像も近い画像が生成される、というわけです。一方で、Real画像は、Fake画像に比べると潜在変数が求めにくく、再生成したときに少し異なる画像が生成される、というわけです。
以上のような結果から、生成結果と潜在変数が比較的ダイレクトに対応している、という確認もできます。
まとめ
StyleGANのバージョン2に当たる論文をまとめてみました。
細かい問題を丁寧に洗い出し、それぞれの課題に対して様々な既存研究を引くことによって克服している論文です。
存在自体も知らない既存研究も多く、付録の分量も多かったため、読むのに非常に骨が折れましたが、大変勉強になる一本でした。
書誌情報
- Karras, Tero, et al. "Analyzing and improving the image quality of stylegan." arXiv preprint arXiv:1912.04958 (2019).
- https://arxiv.org/abs/1912.04958