概要
画像のドメイン変換に連続的な変化を与えることができる手法SAVI2Iを提案している論文を読んだので紹介してみます。
画像のドメイン変換は、pix2pix1に代表される、入力画像を目的となるドメインへと変換するタスクです。本論文は新しい画像のドメイン変換手法SAVI2Iを提案していますが、以下のような特徴があります。
- マルチモーダル:ドメイン変換の方向が多対多である。
- マルチドメイン:複数のドメインを統一的に扱えるモデルを構築できる。
- 連続的:ドメインとドメインの間の補間的なドメインの画像を生成できる。
マルチモーダルかつマルチドメインという手法は結構存在しますが、それに加えて以下のような異なるドメイン画像間の連続的な変換ができるというのが本手法の特徴になっています。
このような補間は今までにもあったような気がしますが、それらのほとんどはドメイン内(intra-domain)での補間に留まっており、本手法はドメイン間(inter-domain)の補間もできるのだと主張しています。
本論文で提案されている新規性は大きく以下の2点です。
- Signed Attribute Vector(SAV)による、複数ドメインのスタイル特徴ベクトル(Attribute Vector)の同一空間での表現。
- Sign-Symmetrical Attribute Vectorによる、ドメイン間補間画像生成の訓練方法。
最近の多くのドメイン変換モデルと同様に、本手法でも大量の損失関数が使用されます。上記の2点について説明した後、それらの損失関数にも簡単に触れることとします。
書誌情報
- Mao, Qi et al. “Continuous and Diverse Image-to-Image Translation via Signed Attribute Vectors.” (2020).
- arxiv
- 公式実装(PyTorch)
- プロジェクトページ
モデルの全体像
まず、本手法で訓練するモデルの全体像を確認しておきます。
(a)2つのEncoder
画像からコンテンツ(Content)とスタイル(Attribute)の情報を抽出する2つのEncoder$E_c, E_a$を使用します。
(b)訓練の枠組み
左上にあるSigned attribute vectors(SAV)が、本手法のキモとなるモジュールです。ここでは、標準正規分布からサンプルされたランダムなベクトルに対し、指定したドメイン$\hat{y}$を表す符号を付与し、スタイル特徴ベクトルを得ています。SAVの細かい内容については後ほど説明します。
左真ん中の男性から$E_c$によってコンテンツ特徴を抽出し、左下の女性から$E_a$によってスタイル特徴を抽出します。ここまでで、1つのコンテンツ特徴と、2つのスタイル特徴が得られたことになりますので、Generator$G$によって2つの生成結果が得られます。
$\mathcal{L}_{\text{MMD}}$と$\mathcal{L}_{\text{adv}}^{\text{domain}}$という2つの損失関数が見えますが、前者は実画像から得られるスタイル特徴ベクトルが、望ましい分布に収まるようにするためのMMD制約損失2を表し、後者は生成された画像が指定したドメインの画像として適切であるかを判定するMultitask-Discrminatorを通じて得られる敵対的損失です。
なお、上図には示されていませんが、コンテンツ特徴にスタイルに関する情報が紛れ込むことを避けるのに使用するDiscriminator $D_c$も存在します。
(c)推論時の動作
推論時には、ある画像から$E_c$を通じてコンテンツ特徴を抽出し、任意のスタイル特徴を適用することで新しい画像を生成できます。この時、目標となるスタイル特徴は、別ドメインの実画像から$E_a$によって抽出することもできますが、SAVによって生成することもできます。また、2つのスタイル特徴の線形補間を行い、それをコンテンツ特徴と組み合わせれば、ドメイン間を連続的に変化する生成結果が得られます。
Signed Attribute Vector(SAV)
SAVは、複数ドメインのスタイルを同一の空間内に埋め込むシンプルな方法です。
SAVでは、まず$Nd$次元のランダムなベクトル$\mathbf{z}^p$を生成します。ここで、$N$はドメイン数、$d$は各ドメインに割り当てられた次元数とします。$\mathbf{z}^p$の各要素は、標準正規分布からサンプルされているため、正の値のときもあれば、負の値のときもあります。
次に、指定されたドメインに対応する要素は全て正に、その他のドメインに対応する要素は全て負になるような操作を加えます。この操作により、SAVから得られるベクトル$\mathbf{z}^{s}$は、ランダムでありながら、特定のドメインに属することが明確であるベクトルになります。
\begin{array}{l} \mathcal{O}_{y}\left(\mathbf{z}^{p}\right)=\left[-\left|z_{1}^{1}\right|,-\left|z_{2}^{1}\right|, \cdots,-\left|z_{d}^{1}\right|, \cdots\right. \\ + \left|z_{1}^{y}\right|,+\left|z_{2}^{y}\right|, \cdots,+\left|z_{d}^{y}\right|, \cdots \\ \left.\quad-\left|z_{1}^{N}\right|,-\left|z_{2}^{N}\right|, \cdots, \left|z_{d}^{N}\right|\right] \end{array}
まとめると、SAVは以下のように数式で表されます。
\mathbf{z}^{s}=\mathcal{O}_{y}\left(\mathbf{z}^{p}\right) \quad \mathbf{z}^{p} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}), y \in\{1 \ldots N\}
Sign-Symmetrical Attribute Vector
訓練の枠組みで示した訓練方法でも、マルチモーダル・マルチドメインのドメイン変換に使用できるEncoderとGeneratorは訓練できますが、ドメイン間の補間的な変換が可能になる保証がありません。
そこで、本手法では、先ほどのSAVによるスタイル特徴表現$\mathbf{z}^{s}$を活用し、2つのドメインの線形補間的なスタイルベクトルを作成することを考えます。$\mathbf{z}^{s}$が表すドメイン$y$とは異なるドメイン$\hat{y}$を用意し、それに対応する要素だけを正にしたベクトル$\mathbf{z}_{\mathrm{sym}}$を作成します。
\begin{array}{c} \mathbf{z}_{\mathrm{sym}}=\mathcal{O}^{r}
\left(\mathbf{z}^{s}\right)=
\left[\cdots,+\left|z_{1}^{\hat{y}}\right|,+\left|z_{2}^{\hat{y}}\right|, \cdots,+\left|z_{d}^{\hat{y}}\right|, \cdots\right. \\ \left.-\left|z_{1}^{y}\right|, -\left|z_{2}^{y}\right|, \ldots,-\left|z_{d}^{y}\right|, \cdots
\right] \end{array}
下図の(a)は、$\mathbf{z}^{s}, \mathbf{z}_{\mathrm{sym}}$という2つのベクトルについて、ドメイン$y, \hat{y}$に関する要素のみを軸として取り出し、プロットしたものです。
2つの点$\mathbf{z}^{s}$と$\mathbf{z}_{\mathrm{sym}}$は、2つの異なるドメインに対応するスタイル特徴表現でありながら、ドメイン$y, \hat{y}$に対応する要素以外は全て同一です。また、ドメイン$y, \hat{y}$に対応する要素では、符号が反転しているだけで絶対値は等しくなっています。そのため、2つの点$\mathbf{z}^{s}, \mathbf{z}_{\mathrm{sym}}$は、原点を中心として対称の位置に存在し、この2点の線形補間点は、必ずドメイン$y, \hat{y}$のいずれかに属することになります(上図(b, d))。より厳密には、完全な中点のみ、2つのドメインのどちらともいえない点となります(上図(c))。
こうして、$\mathbf{z}^s$を元に、対照的な$\mathbf{z}_{\mathrm{sym}}$を作成し、これらを使って2つのドメイン間の線形補間上の点を表すベクトルが得られることがわかりました。このスタイル補間ベクトルを用いることで、Generatorはドメイン補間的なスタイル表現から、Discriminatorを騙せる生成結果を出力できるよう訓練される、というわけです。
損失関数
最後に、損失関数についてざっくりと確認します。
スタイルに関するMultitask-Discriminator$D$と、コンテンツに関するDiscriminator$D_c$に関して、以下のような損失が使われます。
\begin{aligned}
\mathcal{L}_{D, D_{c}} &=\lambda_{\mathrm{adv}}^{\text {content }} \mathcal{L}_{\mathrm{adv}}^{\text {content }}+\lambda_{\mathrm{adv}}^{\text {domain }} \mathcal{L}_{\mathrm{adv}}^{\text {domain }} \\
&+\lambda_{\mathrm{adv}}^{\mathrm{rvs}} \mathcal{L}_{\mathrm{adv}}^{\mathrm{rvs}}+\lambda_{\mathrm{adv}}^{\text {interp }} \mathcal{L}_{\mathrm{adv}}^{\text {interp }}
\end{aligned}
また、2つのEncoder$E_c, E_a$とGenerator $G$については以下のような損失が使われます。
\begin{aligned}
\mathcal{L}_{G, E_{c}, E_{a}} &=-\left(\lambda_{\mathrm{adv}}^{\text {content }} \mathcal{L}_{\mathrm{adv}}^{\text {content }}+\lambda_{\mathrm{adv}}^{\text {domain }} \mathcal{L}_{\mathrm{adv}}^{\text {domain }}\right.\\
&\left.+\lambda_{\mathrm{adv}}^{\mathrm{rvs}} \mathcal{L}_{\mathrm{adv}}^{\mathrm{rvs}}+\lambda_{\mathrm{adv}}^{\text {interp }} \mathcal{L}_{\mathrm{adv}}^{\text {interp }}\right) \\
&+\lambda_{\mathrm{MMD}} \mathcal{L}_{\mathrm{MMD}}+\lambda_{\text {style }} \mathcal{L}_{\text {style }}+\lambda_{1}^{\mathrm{cc}} \mathcal{L}_{1}^{\mathrm{cc}} \\
&+\lambda_{1}^{\text {recon }} \mathcal{L}_{1}^{\text {recon }}+\lambda_{1}^{\text {latent }} \mathcal{L}_{1}^{\text {latent }}+\lambda_{\mathrm{ms}} \mathcal{L}_{\mathrm{ms}}
\end{aligned}
本当にこんなに大量の損失関数が必要なのか、という疑問は当然湧いてきますが、そのあたりはAbration Studyで検証されており、定性的にも定量的にもそれぞれ必要であるという実験結果が得られています。
以下、各損失に関する簡単な説明です。
- $\mathcal{L}_{\mathrm{adv}}^{\mathrm{content}}$: コンテンツ特徴にスタイルに関する情報が紛れ込まないようにするための損失。
- $\mathcal{L}_{\mathrm{adv}}^{\mathrm{domain}}$: $E_a(x_{\hat{y}})$や$\mathbf{z}^s$を用いて生成した画像が、指定したスタイルを反映しているか否かを表す損失。
- $\mathcal{L}_{\mathrm{adv}}^{\mathrm{rvs}}$: $\mathbf{z}_{\text{sym}}$を用いた生成結果が、指定したスタイルを反映しているか否かを表す損失。
- $\mathcal{L}_{\mathrm{adv}}^{\text {interp }}$: $\mathbf{z}^s$と$\mathbf{z}_{\text{sym}}$の補間点を用いて生成した結果が、指定したスタイルを反映しているか否かを表す損失。
- $\mathcal{L}_{\mathrm{MMD}}$: $E_a$の出力結果が、$\mathbf{z}^s$の分布に近づくようにするための損失。
- $\mathcal{L}{\text{style}}$: $E_a(x_{\hat{y}})$を用いて生成した画像と元のスタイル参照画像$x{\hat{y}}$とのスタイル特徴が近づくようにするための損失。ここでのスタイル特徴とは、訓練済みVGGの中間層のGram行列を意味する。
- $\mathcal{L}_{1}^{\mathrm{cc}}$: 一度別のドメインに変換された画像を、元のドメインに戻した時に、元画像と同一になるかを表すCycle-consistency損失。
- $\mathcal{L}_{1}^{\text {recon }}$: ある画像$x$から得られるコンテンツ特徴$E_c(x)$とスタイル特徴$E_a(x)$を用いて、元の画像を再構成できるかを表す再構成誤差。
- $\mathcal{L}_{1}^{\text {latent }}$: $\mathbf{z}^s$を用いて生成した画像から、$E_a$を通してスタイル特徴を抽出した時に、元の$\mathbf{z}^s$と近づくようにする損失。
- $\mathcal{L}_{\mathrm{ms}}$: モード崩壊を避けるためのMode Seeking3損失。元となるスタイル特徴に差異があれば、生成される画像でも差異が生じることを促す正則化。
まとめ
簡単にですが、ドメイン間の連続的な変換を実現するSAVI2Iについて紹介しました。SVAのアイディアはシンプルですが、スタイル特徴の表現方法として理にかなったものになっており、Sign-Symmetrical Attribute Vectorと組み合わせることでドメイン間補間という難しい課題に対するある程度の解決をもたらしています。
細かいネットワーク構造などは論文中にも記載がありますし、実装も公開されているので、気になる方はチェックしてみてください。
-
Isola, Phillip et al. “Image-to-Image Translation with Conditional Adversarial Networks.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5967-5976. ↩
-
Zhao, S., Song, J., Ermon, S.: InfoVAE: Information maximizing variational autoencoders. arXiv preprint arXiv:1706.02262 (2017) ↩
-
Mao, Q., Lee, H.Y., Tseng, H.Y., Ma, S., Yang, M.H.: Mode seeking generative adversarial networks for diverse image synthesis. In: CVPR (2019) ↩