CycleGANは画像の文脈を保ったまま画像変換を行うことができる手法です(以前に書いたCycleGANの解説記事)。しかし、形状の変化を伴うような画像変換は苦手としていました。
これに対処するため、物体のマスク情報を明示的に使って画像だけでなくマスクも変換しようというInstaGANが提案されています。画像に対してインスタンスごとにマスクが与えられた状態のデータを必要としてますが、面白い手法なので解説したいと思います。
#InstaGAN
InstaGAN: Instance-aware Image-to-Image Translation
著者の実装
また、著者の実装はかなり煩雑なのでこちらに再現実装を置いておきました。
記号は以下の意味で用います。
$D_{X/Y}$: ドメイン$X/Y$についてのDiscriminator
$G_{XY/YX}$: ドメイン$X→Y/Y→X$の変換を行うGenerator
$(x, a)$: ドメイン$X$の画像とマスクの組
$(y, b)$: ドメイン$Y$の画像とマスクの組
$(x', a')$: $G_{YX}$により生成されたドメイン$X$の画像とマスクの組
$(y', b')$: $G_{XY}$により生成されたドメイン$Y$の画像とマスクの組
通常のCycleGANでは2つのGenerator $G_{XY/YX}$はそれぞれドメイン$X/Y$の画像を入力としてドメイン$Y/X$へと変換します。InstaGANでは画像だけでなく画像とマスクの組に対して変換を試みます。
2つのgenerator $G_{XY}/G_{YX}$はそれぞれドメイン$X/Y$の画像とマスクの組をそれぞれ$Y/X$ドメインに変換します。一方で、Discriminator $D_{X}/D_{Y}$はそれぞれ与えられた画像が元からドメイン$X/Y$に属しているものかGeneratorによって変換されたものかを判別します。
Generator
Generatorは4つのモジュール(括弧内は画像中での表記)からなります。
- Image Encoder ($f_{GX}$)
- Mask Encoder ($f_{GA}$)
- Image Decoder ($g_{GX}$)
- Mask Decoder ($g_{GA}$)
Image EncoderとMask Encoderはそれぞれ画像とインスタンスごとのマスクから特徴を抽出します。
各インスタンスのマスク特徴の総和を取ったもの(総和マスク特徴、斜線)を取っておきます。
Image Decoderは画像特徴と総和マスク特徴を使って変換後の画像を生成します。
Mask Decoderはそれらに加えて各インスタンスのマスク特徴を使ってそれぞれ対応する変換後のマスクを生成します。
Discriminator
Discriminatorは
- Image Encoder ($f_{GX}$)
- Mask Encoder ($f_{GA}$)
- Head ($g_{GX}$)
からなります。
2つのEncoderはGeneratorのものと同様に画像と各インスタンスのマスクから特徴を抽出します。
Head部分はDiscriminatorの最終的な出力を行うモジュールです。
損失関数
InstaGANのGeneratorの損失はAdversarial Loss ($\mathcal{L}_{LSGAN}$)、Cycle Consistency Loss ($\mathcal{L}_{cyc}$)、Identity Loss ($\mathcal{L}_{idt}$)、Context Preserving Loss ($\mathcal{L}_{ctx}$) の4つの項からなります。
Adversarial Loss
Adversarial LossはGeneratorの生成した画像がDiscriminatorを騙せた度合いに関する損失です。論文中ではLSGANを使用しています。
$ \mathcal{L}_{LSGAN} = (D_{X}(x, a) - 1)^2 + (D_{X}(G_{YX}(y, b))^2 + (D_{Y}(y, b) - 1)^2 + (D_{Y}(G_{XY}(x, a))^2$
Cycle Consistency Loss
Cycle Consistency Lossは変換後の画像が元の画像の文脈を残せているかに関する損失で、CycleGANの論文中で提案されています。Cycle Consistency Lossは図のように$X→Y→X$のような変換を行うと元の画像に戻ってくるようGeneratorに制約を課しており、これにより元の画像の文脈を維持した変換方法を学習することが期待できます。
$\mathcal{L}_{cyc} = || G_{XY}(G_{YX}(x, a)) - (x, a) ||_{1} + || G_{YX}(G_{XY}(y, b)) - (y, b) ||_{1}$
Identity Loss
2つのGenerator $G_{XY/YX}$ はそれぞれ$X→Y, Y→X$の変換だけをしてほしい、言い換えれば画像の対象ドメイン以外の領域には変化を加えないことが望ましいです。そこで、$G_{XY/YX}$に変換対象ではないドメイン$Y/X$の画像を与えても何も変化させないよう制約を与えるための損失です。これもCycleGANで使用されていたものです。
$\mathcal{L}_{idt} = || G_{XY}(y, b) - (y, b) ||_{1} + || G_{YX}(x, a) - (x, a) ||_{1} $
Context Preserving Loss
インスタンスマスク以外の領域に極力変化させないような制約を与える損失です。
$\mathcal{L}_{ctx} = || w(a,b') \bigodot (x - y') ||_1 + || w(b,a') \bigodot (y - x') ||_1$
$w(a,b')$および$w(b,a')$はそれぞれ変換前後でのマスクの合併を除いた領域、つまり変換前でも変換後でも背景となっている領域を示しています。$x - y', y - x'$は変換前後の画像の差分。 $\bigodot$は要素積を表しており、全体として変換対象となっているマスク領域以外を極力変化させないような制約を与えています。
Total Loss
$\mathcal{L}_{InstaGAN} = \mathcal{L}_{LSGAN} + \lambda_{cyc}\mathcal{L}_{cyc} + \lambda_{idt}\mathcal{L}_{idt} + \lambda_{ctx}\mathcal{L}_{ctx}$
$\lambda_{cyc}, \lambda_{idt}, \lambda_{ctx}$はそれぞれの項の重みを調整する係数です。
Sequential Mini-batch Translation
InstaGANではMask EncoderとMask Decoderは個々のインスタンスごとにマスクの変換を行います。これはその画像中に変換対象となるインスタンスが少数しかない場合は上手くいきますが、数が増えるとメモリの使用量が増えていってしまいます。それを解決するための手法としてSequential Mini-batch Translationを提案しています。
Sequential Mini-batch Translationを使う場合、Generatorは与えられたインスタンスを一度に全て変換するのではなく、数個ごとに変換を行います。学習時の誤差逆伝播は実線の部分だけで行います。iterationをまたぐ点線部分はdetachされており誤差情報が伝わりません。
Result
InstaGANではヒツジとキリンやカップとボトルのような形状変化を伴う変換が上手くできています。また、スカートとパンツの変換などもできており仮想試着のようなことができています。