LoginSignup
18
15

More than 1 year has passed since last update.

GANの訓練はGとD同時にできる! 訓練を33%以上高速化するOne Stage GAN(OSGAN)の紹介

Last updated at Posted at 2021-12-12

2021年のディープラーニング論文を1人で読むAdvent Calendar12日目の記事です。今回紹介するのはGANの訓練の高速化です。発想はとても単純で、GとDを同時に訓練して訓練ステップを1回で済ませてしまおうというものです。

著者は浙江大学、シンガポール国立大学、アリババなどからなるチームです。CVPR2021に採択されています。

これまでのGAN(Two Stage GAN:TSGAN)

これまでのGANはG(Generator)とD(Discriminator)を別々に訓練します。公式コードから改変して抜粋します。

    for i, (x, _) in enumerate(train_loader):
        batch_size = x.size(0)
        x = x.cuda()

        y_real = torch.ones(batch_size).cuda()
        y_fake = torch.zeros(batch_size).cuda()

        # ------------ Training Discriminator ------------
        z = torch.randn((batch_size, zdim, 1, 1)).cuda()
        fake = generator(z).detach()

        pred_real = discriminator(x)
        pred_fake = discriminator(fake)

        loss_real = F.binary_cross_entropy_with_logits(pred_real, y_real)
        loss_fake = F.binary_cross_entropy_with_logits(pred_fake, y_fake)

        loss_d = loss_real + loss_fake

        optimizer_d.zero_grad()
        loss_d.backward() # DのBackprop
        optimizer_d.step()

        # ------------ Training Generator ------------
        z = torch.randn((batch_size, zdim, 1, 1)).cuda()
        fake = generator(z)
        pred_fake = discriminator(fake)

        loss_g = F.binary_cross_entropy_with_logits(pred_fake, y_real)

        optimizer_g.zero_grad()
        loss_g.backward() # GのBackprop
        optimizer_g.step()

このように「DのBackprop(微分)」「GのBackprop」と1つのバッチに対して2回微分を行います。2回微分のステージがあるので「Two Stage GAN(TSGAN)」と、この論文では呼んでいます。

このBackpropを「DとGあわせて1回でやればいいじゃん」って言っているのがこの論文です。そんなことできるのでしょうか。

1回のBackpropで訓練するGAN(One Stage GAN:OSGAN)

こちらも公式コードから改変して抜粋です。

    scaler = GradientScaler.apply
    for i, (x, _) in enumerate(train_loader):
        batch_size = x.size(0)
        x = x.cuda()

        y_real = torch.ones(batch_size).cuda()
        y_fake = torch.zeros(batch_size).cuda()

        z = torch.randn((batch_size, zdim, 1, 1)).cuda()
        fake = generator(z)

        fake_neg = scaler(fake) # 勾配のスケール

        pred_real = discriminator(x)
        pred_fake = discriminator(fake_neg)

        loss_real = F.binary_cross_entropy_with_logits(pred_real, y_real)
        loss_fake = F.binary_cross_entropy_with_logits(pred_fake, y_fake, reduction='none') 
        loss_d = loss_real + torch.mean(loss_fake)

        # -log(D(G(z))), not included in loss_d, confuse fake image to real
        loss_g = F.binary_cross_entropy_with_logits(pred_fake, y_real, reduction='none')

        #-- 勾配をスケールするための定数を計算
        gamma = get_gradient_ratios(loss_g, loss_fake, pred_fake)

        grad_d_factor = 1.0 / (1.0 - gamma)

        loss_pack_fake = loss_fake - loss_g
        scaled_loss_pack_fake = loss_pack_fake * grad_d_factor
        loss_pack = loss_real + torch.mean(scaled_loss_pack_fake)

        GradientScaler.factor = gamma # スケール定数の更新
        #--

        optimizer_d.zero_grad()
        optimizer_g.zero_grad()

        loss_pack.backward() # Backpropは1回だけ

        optimizer_d.step()
        optimizer_g.step()

初見だとなにやっているのかわかりませんが、注目すべきは.backward()が1回しか出てこないことです。これが本論文の提唱手法で、1つのステージで訓練しているのでOne Stage GAN(OSGAN)と呼んでいます。

コードのgamma=の部分はGのロスの勾配を、Dのロス勾配で表すための定数を計算しています。「何らかの定数を計算して、DとGのロスをくっつけて1回のBackpropで回すんだな」ぐらに思っておいてください。

「get_gradient_ratiosで実は微分計算してるんでしょ」と思うかもしれませんが、実はそんなことなくて、

def get_gradient_ratios(lossA, lossB, x_f, eps=1e-6):
    grad_lossA_xf = torch.autograd.grad(torch.sum(lossA), x_f, retain_graph=True)[0]
    grad_lossB_xf = torch.autograd.grad(torch.sum(lossB), x_f, retain_graph=True)[0]
    gamma = grad_lossA_xf / grad_lossB_xf

    return gamma

と、最終層の勾配比を計算しているだけです。2ステージGANの場合は、ネットワーク全体をフルに2回微分計算する必要がありますが、最終層の微分計算して定数求めるだと圧倒的に計算量が少ないです。これがOne Stage GANの計算量削減のキモです。

でも1ステージでやったら生成画像のクォリティ落ちるんでしょ?

いえいえ、そんなことはありません。

12_01.png

各データセットやモデルで、1ステージ、2ステージで訓練したときの学習曲線です。左軸がFID、右軸がlog(KID)でいずれも低いほうが良いです。

どのデータセットやモデルに対しても、単純に高速化しているだけで特に性能は損なっていないのがわかります。高速化の割合は実測で1.4~1.7倍。2回を1回に落としているのでかなり高速化できています。性能に関してはむしろ1ステージのほうが安定してそうな雰囲気さえもあります。定量評価だと、

12_02.png

と、1ステージが2ステージと同等の性能を出しているどころか、むしろ1ステージのほうが大半のケースで良いということがわかります。

1ステージのほうが2ステージより、FIDが優れている理由について、論文では「2ステージでは、Gの学習中に以前の最適化情報が失われ、非効率的な最適化が行われる。GとDを同時に更新すれば、この問題を効率的に回避できるから」と考察しています。

対称GAN(Symmetric GANs)と非対称GANs(Asymmetric GANs)

定義

「なぜこのような単純なスケーリングでうまくいくのか」ということが気になります。本論文ではこの理論的な導出が大半を占めています。

本論文ではGANの損失関数に応じて、対称GAN(Symmetric GANs)非対称GANs(Asymmetric GANs)と2つの場合分けをしています。対称、非対称限らず、多くのGANでは以下のような損失関数をしています。

\begin{align*}\mathcal{L}_\mathcal{D}(x, \hat{x})=\mathcal{L}_\mathcal{D}^r(x)+\mathcal{L}_\mathcal{D}^f(\hat{x}) &\\ \mathcal{L}_\mathcal{G}(\hat{x})=\mathcal{L}_\mathcal{G}(\mathcal{G}(z))\end{align*}

1つ目がDiscriminatorを訓練するときのロス、2つ目がGeneratorを訓練するときのロスです。ここで、$x$は本物のデータ、$\hat{x}$がノイズ$z$から生成した偽物のデータになります。$\mathcal{L}_\mathcal{D}^r, \mathcal{L}_\mathcal{D}^f$はそれぞれ本物(real)、偽物(fake)をDで見分けたときのロスになります。

ここで、対称GANと非対称GANの区分をします。

  • 対称GAN:$\mathcal{L}_\mathcal{G}=-\mathcal{L}_\mathcal{D}^f$を満たす
  • 非対称GAN:これを満たさない

DCGANの場合

もう少し具体的に見てみましょう。例えばDCGANでも対称GANの場合と、非対称GANの場合があります。DCGANの場合は、

\min_{\mathcal{G}}\max_{\mathcal{D}}\mathbb{E}_{x\sim p_d}\bigl[\log\mathcal{D}(x)\bigr]+\mathbb{E}_{z\sim p_z}\bigl[\log(1-\mathcal{D}(\mathcal{G}(z)))\bigr]

という損失関数において、

\begin{align*}\mathcal{L}_\mathcal{D}=-\log\mathcal{D}(x)-\log(1-\mathcal{D}(\mathcal{G(z)})) &\\ \mathcal{L}_\mathcal{G}=\log(1-\mathcal{D(\mathcal{G}(z))})\end{align*}

とするのは、$\mathcal{L}_\mathcal{G}=-\mathcal{L}_\mathcal{D}^f$を満たすため、対称GANとなります。

ところがこの損失関数は、Gで勾配消失を起こすため、以下のような飽和しないための損失関数とするのが一般的です。

\begin{align*}\mathcal{L}_\mathcal{D}=-\log\mathcal{D}(x)-\log(1-\mathcal{D}(\mathcal{G(z)})) &\\ \mathcal{L}_\mathcal{G}=-\log(\mathcal{D(\mathcal{G}(z))})\end{align*}

この場合は対称GANの定義を満たさないので、非対称GANとなります。

対称GANの場合1ステージ化は容易

対称GANの場合は、$\mathcal{L}_\mathcal{G}=-\mathcal{L}_\mathcal{D}^f$を満たすため、偽のサンプル$\hat{x}$についての勾配$\nabla_{\hat{x}}\mathcal{L}_\mathcal{G}=-\nabla_{\hat{x}}\mathcal{L}_\mathcal{D}^f$も同様の関係で表されます。対称GANの場合は、GパラメーターのアップデートをDのロスでできるということを意味します。

実装的には、こちらのコードより

    for i, (x, _) in enumerate(train_loader):
        batch_size = x.size(0)
        x = x.cuda()

        y_real = torch.ones(batch_size).cuda()
        y_fake = torch.zeros(batch_size).cuda()

        z = torch.randn((batch_size, zdim, 1, 1)).cuda()
        fake = generator(z)

        pred_real = discriminator(x)
        pred_fake = discriminator(fake)

        loss_real = F.binary_cross_entropy_with_logits(pred_real, y_real)
        loss_fake = F.binary_cross_entropy_with_logits(pred_fake, y_fake)

        loss_adv = loss_real + loss_fake

        optimizer_g.zero_grad()
        optimizer_d.zero_grad()

        loss_adv.backward()

        optimizer_g.step()
        optimizer_d.step()

Dのロスだけで訓練すればOKです。

非対称GANの場合は、GのパラメーターをDのロスで単純にアップデートできないため、工夫が必要になります。

対称GAN・非対称GANの具体例

論文の補助資料にはGANの損失関数ごとに、対称か非対称かの具体例が示されていました。

12_03.png

「sym」が対称GAN、「asym」は非対称GANを示します。最近のGANでよく使われるHingeロスはGeoGANの派生です。実践的にはほとんどが非対称GANと捉えておけばいいです。

非対称GANの1ステージ化

結論としては冒頭で示したコードでOKなのですが、非対称GANの1ステージ化には理論的な補助が必要となります。難しかったら飛ばしても構わないです。興味のある方だけ読んでください。

1ステージ化した損失関数を、

\mathcal{L}=\mathcal{L}_\mathcal{D}-\mathcal{L}_\mathcal{G}=\mathcal{L}_\mathcal{D}^r+\mathcal{L}_\mathcal{D}^f-\mathcal{L}_\mathcal{G}

で考えます。マイナスとしているのは$\mathcal{L}_\mathcal{D}^f$と$\mathcal{L}_\mathcal{G}$の間での勾配の衝突を避けるためです。ここで、偽のサンプルに関するロスの項を$\mathcal{L}_f$とし、

\mathcal{L}_f=\mathcal{L}_\mathcal{D}^f-\mathcal{L}_\mathcal{G}

に注目します。ここでの目標は、$\nabla_{\hat{x}}\mathcal{L}_f$から$\nabla_{\hat{x}}\mathcal{L}_{\mathcal{G}}$をどうやって取り出すか、つまり、$\nabla_{\hat{x}}\mathcal{L}_{\mathcal{G}}$と$\nabla_{\hat{x}}\mathcal{L}_{\mathcal{D}}$の関係式を得たいのです。

Backpropの類型化

関係式を得るためには、Backpropを類型化して考える必要があります。ニューラルネットワークの$l$番目から$l-1$番目のレイヤーについてのBackpropは、

$$\nabla_{x^{l-1}}\mathcal{L}=\mathcal{P}\cdot\mathcal{F}(\nabla_{x^l}\mathcal{L})\cdot\mathcal{Q}\tag{1}$$

で表されます。ここで$\mathcal{F}$については、

$$\mathcal{F}(y_1+y_2) = \mathcal{F}(y_1)+\mathcal{F}(y_2)\tag{2}$$

という関係式を満たします(ただし、これはBatchNormに関しては満たしません)。

具体的なニューラルネットワークのレイヤーに対して、$\mathcal{P, F, Q}$がそれぞれどんな値になるかというと以下の通りです。

12_04.png

$\mathbf{I}$は恒等変換(行列)を表します。実際に使うレイヤーの多くが式(1)で類型化できることが確認できます。

偽のサンプルに関して、最終層のGとDの勾配の比を考えれば良い

式(2)から$y_1=y_2=0$とおくと、

12_05.png

という式が得られます。この式より、微積を行い、

12_06.png

という式が得られます。ここでポイントなのは、$\mathcal{F}(\cdot)$同士の割り算は、そのまま中身の割り算に置き換えられるということです。式(1)よりG, Dの偽のサンプルについてのBackpropは、

\nabla_{x^{l-1}}\mathcal{L}_\mathcal{G}=\mathcal{P}\cdot\mathcal{F}(\nabla_{x^l}\mathcal{L}_\mathcal{G})\cdot\mathcal{Q} \\ \nabla_{x^{l-1}}\mathcal{L}_\mathcal{D}=\mathcal{P}\cdot\mathcal{F}(\nabla_{x^l}\mathcal{L}_\mathcal{D})\cdot\mathcal{Q}

で表されます。GとDの勾配の比を取ると、

\frac{\nabla_{x^{l-1}}\mathcal{L}_\mathcal{G}}{\nabla_{x^{l-1}}\mathcal{L}_\mathcal{D}}=\frac{\mathcal{P}\cdot\mathcal{F}(\nabla_{x^l}\mathcal{L}_\mathcal{G})\cdot\mathcal{Q}}{\mathcal{P}\cdot\mathcal{F}(\nabla_{x^l}\mathcal{L}_\mathcal{D})\cdot\mathcal{Q}}=\frac{\mathcal{F}(\nabla_{x^l}\mathcal{L}_\mathcal{G})}{\mathcal{F}(\nabla_{x^l}\mathcal{L}_\mathcal{D})}=\frac{\nabla_{x^l}\mathcal{L}_\mathcal{G}}{\nabla_{x^l}\mathcal{L}_\mathcal{D}}

これはGとDの勾配の比率が、全レイヤーを通して一定であることを意味します。すなわち、$\nabla_{x^{l-1}}$という特定のレイヤーについてだけでなく、$i$番目の偽のサンプル$\hat{x}$についての係数を$\gamma_i$とすれば、

\frac{\nabla_{\hat{x}_i}\mathcal{L}_\mathcal{G}}{\nabla_{\hat{x}_i}\mathcal{L}_\mathcal{D}}=\cdots=\frac{\nabla_{\hat{x}_i^l}\mathcal{L}_\mathcal{G}}{\nabla_{\hat{x}_i^l}\mathcal{L}_\mathcal{D}}=\cdots=\frac{\nabla_{\hat{x}_i^L}\mathcal{L}_\mathcal{G}}{\nabla_{\hat{x}_i^L}\mathcal{L}_\mathcal{D}}=\gamma_i\tag{3}

という式で表されます。これは偽のサンプルについて最終層のGとDの勾配の比を取れば、$\nabla_{\hat{x}}\mathcal{L}_{\mathcal{G}}$と$\nabla_{\hat{x}}\mathcal{L}_{\mathcal{D}}$の関係式が得られることを意味します。もしこの値が得られれば、対称GANと同様にDのロスでGをアップデートできるというわけです。

$\mathcal{L}_f=\mathcal{L}_\mathcal{D}^f-\mathcal{L}_\mathcal{G}$でしたので、

\nabla_{\hat{x}_i^l}\mathcal{L}_f = \nabla_{\hat{x}_i^l}\mathcal{L}_{\mathcal{D}}^f-\nabla_{\hat{x}_i^l}\mathcal{L}_{\mathcal{G}}

これに(3)式を適用すれば、

\nabla_{\hat{x}_i^l}\mathcal{L}_{\mathcal{D}}^f=\frac{1}{1-\gamma_i}\nabla_{\hat{x}_i^l}\mathcal{L}_f \\ \nabla_{\hat{x}_i^l}\mathcal{L}_{\mathcal{G}}=\frac{\gamma_i}{1-\gamma_i}\nabla_{\hat{x}_i^l}\mathcal{L}_f

という式で表されます。ここで新たな損失項$\mathcal{L}_\mathcal{D}^{ins}, \mathcal{L}_\mathcal{G}^{ins}$を考えます。insはインスタンス単位の損失関数を意味します。

\begin{align*}\mathcal{L}_\mathcal{D}^{ins}=\mathcal{L}_\mathcal{D}^r+\frac{1}{1-\gamma_i}\bigl(\mathcal{L}_\mathcal{D}^f-\mathcal{L}_\mathcal{G}\bigr) &\\ \mathcal{L}_\mathcal{G}^{ins}=\frac{\gamma_i}{1-\gamma_i}\bigl(\mathcal{L}_\mathcal{D}^f-\mathcal{L}_\mathcal{G}\bigr)\end{align*}

となります。これは対称GANと同じ形なので、非対称GANも対称GANのように1ステージで訓練することが可能となりました。

ただし、

\nabla_{\hat{x}_i^l}\mathcal{L}_{\mathcal{G}}^{ins}=\gamma_i\cdot \nabla_{\hat{x}_i^l}\mathcal{L}_{\mathcal{D}}^{ins}

とBackprop時に勾配のスケーリングを行う必要があります。

1ステージ化の実装

擬似コードで書けば非対称GANの1ステージ化は次のようになります。

12_07.png

ロスの計算部分以降を冒頭に示したコードでもう一度見てみましょう。

        y_real = torch.ones(batch_size).cuda()
        y_fake = torch.zeros(batch_size).cuda()

        z = torch.randn((batch_size, zdim, 1, 1)).cuda()
        fake = generator(z)

        fake_neg = scaler(fake) # 勾配のスケール

        pred_real = discriminator(x)
        pred_fake = discriminator(fake_neg)

        loss_real = F.binary_cross_entropy_with_logits(pred_real, y_real)
        loss_fake = F.binary_cross_entropy_with_logits(pred_fake, y_fake, reduction='none') 
        loss_d = loss_real + torch.mean(loss_fake)

        # -log(D(G(z))), not included in loss_d, confuse fake image to real
        loss_g = F.binary_cross_entropy_with_logits(pred_fake, y_real, reduction='none')

        #-- 勾配をスケールするための定数を計算
        gamma = get_gradient_ratios(loss_g, loss_fake, pred_fake)

        grad_d_factor = 1.0 / (1.0 - gamma)

        loss_pack_fake = loss_fake - loss_g
        scaled_loss_pack_fake = loss_pack_fake * grad_d_factor
        loss_pack = loss_real + torch.mean(scaled_loss_pack_fake)

        GradientScaler.factor = gamma # スケール定数の更新
        #--

        optimizer_d.zero_grad()
        optimizer_g.zero_grad()

        loss_pack.backward() # Backpropは1回だけ

        optimizer_d.step()
        optimizer_g.step()

このコードでloss_pack_fakeが$\mathcal{L}_f$にあたります。これを$\frac{1}{1-\gamma}$でスケーリングしているので、$\mathcal{L}_\mathcal{D}^{ins}$の部分をBackpropで回していることがわかります。2ステージのときのDのロスだけでBackpropを回すというのは、対称GANのときも同じでした。

非対称GANで追加されているのはfake_neg = scaler(fake)GradientScaler.factor = gammaという勾配のスケーリングです。勾配のスケーリングは前者でやっています。このGradientScalerがどういう実装になっているかというと単純で、

class GradientScaler(torch.autograd.Function):
    factor = 1.0

    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        factor = GradientScaler.factor
        return factor.view(-1, 1, 1, 1)*grad_output

Forepropは入力をそのまま、Backpropのときのみ勾配をスケールするという実装です。いま、fake_neg = scaler(fake)をGによる画像生成とDによる識別の間に入れました。これは、G側の更新時に勾配はスケールされるが、D側の更新時には勾配はスケールされないことを意味します。

擬似コードの9行目では、$\nabla_{\bar{x}_i}\mathcal{L}_\mathcal{G}^{ins}=\gamma_i\cdot\nabla_{\bar{x}_i}\mathcal{L}_{\mathcal{D}}^{ins}$で更新するとあったため、この実装は擬似コードの9行目にちょうどマッチします。

これで非対称GANを1ステージ化することができました。

訓練が33%以上高速化される理由

1ステージ化による計算量について見てみましょう。

12_08.png

左が2ステージのGAN、右が1ステージのGANとします。ここでForepropとBackpropの計算量を同じと仮定し、Gの計算量を$\mathcal{T}^g(z)$、Dの計算量を$\mathcal{T}^d(x)$とします。理論的な計算量は、

  • 2ステージの場合:$4\mathcal{T}^g(z)+8\mathcal{T}^d(x)$
  • 1ステージの場合:$3\mathcal{T}^g(z)+6\mathcal{T}^d(x)$

であることがわかります。2ステージの場合を1ステージの場合で割ると、$4/3$となるため、最悪計算量で33%以上高速化されるということがわかります。

蒸留でも有効

敵対的な訓練はGANだけでなく、蒸留(Distillation)でもあります。教師をResNet34、生徒をResNet18とします。CIFAR-10、CIFAR-100の蒸留を行ったところ、既存の研究よりも良い結果が得られたとのことです。

12_09.png

1ステージ化するメリットは単なる速度面だけではないことがわかります。

まとめと感想

この論文では、GANのDとGを同時に更新するというGANの1ステージ化について提唱しています。GANの計算量はかなり膨大で、実際に自分もかなり悩んだことがあったので、理論的には33%、実践的にはほぼ1.5倍高速化できるというのはかなりありがたい研究です。理論展開が若干ややこしいですが、証明がかなり丁寧な論文なので、ゆっくり読んでいけば理解できる内容となっています。

個人的な疑問は、Gに敵対的なロス以外(例えばL1ロス)がついていた場合どう訓練するのかという点です。Image to image translationではこういったロスがよく出てきます。GのロスをDのロスに統合して、Backpropで回してGの勾配をスケールするという内容なので、ちゃんと考えればできるのではないかと思います(すぐに他の研究が出てきそう)。

これまでのアドベントカレンダーで、StyleGAN2の有用性はかなり示されてきたので、例えばこの研究を使って「StyleGAN2の訓練を1.5倍高速化できる」というとかなりインパクトが大きいのではないかなと思います。コードにするとそこまで難しくないですし、簡単なGANでぜひ実装してみたい研究です。

告知

このアドベントカレンダーが本になりました!
https://koshian2.booth.pm/items/3595424
Amazonでも扱いあります詳しくは👉 https://shikoan.com

18
15
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
18
15