0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

GANの訓練ループでよく見るdetachは何をしている?

Last updated at Posted at 2022-07-31

どんな記事?

以下は"Inside Deep Learning"という本に載っているGANの訓練ループの一部です.(見やすいように一部編集しています)

optimizerD = torch.optim.AdamW(D.parameters())
optimizerG = torch.optim.AdamW(G.parameters())

G_losses = []
D_losses = []

for epoch in tqdm(range(num_epochs)):
    for data, class_label in tqdm(train_loader, leave=False):
        real_data = data.to(device)
        y_real = torch.full((batch_size,1), real_label, dtype=torch.float32, device=device)
        y_fake = torch.full((batch_size,1), fake_label, dtype=torch.float32, device=device)
        
        # Discriminatorの訓練
        # readデータで勾配計算
        D.zero_grad()
        errD_real = loss_func(D(real_data), y_real)
        errD_real.backward()
        
        # fakeデータで勾配計算
        z = torch.randn(batch_size, latent_d, device=device)
        fake = G(z)
        # 元のNotebookにあったコメント↓
        #Why do we detach here? Because we don't want the gradient to impact G. 
        #Our goal right now is to update _just_ the discriminator. 
        #BUT, we will re-use this fake data for updating the discriminator, so we want to save the 
        #non-detached version! 
        errD_fake = loss_func(D(fake.detach()), y_fake)
        errD_fake.backward()

        errD = errD_real + errD_fake
        optimizerD.step()

        # Generatorの訓練
        G.zero_grad()
        errG = loss_func(D(fake), y_real)
        errG.backward()
        optimizerG.step()
        
        G_losses.append(errG.item())
        D_losses.append(errD.item())

ちょっと気になるのはDiscriminatorの訓練のerrD_fake = loss_func(D(fake.detach()), y_fake)です.どうしてここでfake.detach()するんだ?どうしてGeneratorの訓練ではfake.detach()しないんだ?"Inside Deep Learning"の本文中とNotebookのコメントにこのdetach()に関する説明がありましたが,納得がいきません(納得いかない理由は後ほど).

このfake.detach()が出現するのはこの本特有のことではなく,他のところでもよく出てきます(例えばPyTorchのexample).

この記事では,このfake.detach()について調べます.

結論

This is not true. Detaching fake from the graph is necessary to avoid forward-passing the noise through G when we actually update the generator. If we do not detach, then, although fake is not needed for gradient update of D, it will still be added to the computational graph and as a consequence of backward pass which clears all the variables in the graph (retain_graph=False by default), fake won't be available when G is updated.

つまり,「fakeをDとGの訓練で使いまわしたいが,detachしないとDの勾配を計算したときに計算グラフが消され,Gの勾配計算ができなくなってしまう」ということらしいです.

もう少し細かく説明すると,

  • fakeはDiscriminatorとGeneratorの訓練の両方で使うので,Discriminatorで計算したfake = G(z)をGeneratorの訓練でも再利用したい.
  • Dの訓練でdetachせずにfakeを使って損失を計算(errD_fake = loss_func(D(fake.detach()), y_fake))し,Dの勾配を計算(errD_fake.backward())すると,計算グラフが消されてしまう.
  • 計算グラフが消された状態で,Gの訓練時に先ほどのfakeを再利用してGの損失を計算(errG = loss_func(D(fake), y_real))し,Gの勾配を計算(errG.backward())しようとする.しかしfakeの計算グラフはDの勾配を計算したときに消えているため,Gの勾配を計算することができない(エラーを吐く).

つまり,detachする以外の方法として

  • DとGの訓練でfakeを使い回すのをやめて,毎回fake = G(z)を計算する.
  • errD_fake.backward(retain_graph=True)として計算グラフを消さないようにする.(TORCH.TENSOR.BACKWARD

などが考えられます(効率が良いかは別).

自分の認識が合っているか少し自信がないので,もし間違っていらたご指摘ください.

以下蛇足です.

コメントに対する疑問

先ほどのコードには,fake.detach()に関してこんなコメントがありました.

Why do we detach here? Because we don't want the gradient to impact G. Our goal right now is to update just the discriminator. BUT, we will re-use this fake data for updating the discriminator, so we want to save the non-detached version!

なるほど.この説明はおそらく次のような意味でしょう.

「DiscriminatorDのfakeデータに対するLosserrD_fakeの計算は,もしもdetachしないのであれば,errD_fake = loss_func(D(fake), y_fake)となります.そもそもfakeはlatent vectorzをGeneratorGに渡すことで生成されていました(fake = G(z)).そのため,勾配を計算(errD_fake.backward())すると,GとDの両方の重みに対して勾配が計算されてしまいます.しかし今はDを訓練しているので,Gの重みは更新したくありません.そこでfake.detach()とすることによってGの勾配の計算を行わないようにしています」

一瞬納得しそうになりましたが,よく考えてみるとちょっと納得いきません.

optimizerD = torch.optim.AdamW(D.parameters()); optimizerG = torch.optim.AdamW(G.parameters())としているなら,Dの訓練ではDの重みだけ,Gの訓練ではGの重みだけを更新することができているのではないか?

考えてもわからないのでググったらStackOverflowに全く同じ疑問を持っている人がいました.

that's because if you don't use fake.detach() in output = netD(fake.detach()).view(-1) then fake is just some middle variable in the whole computational Graph, which tracks gradients in both netG and netD. and when you call netD.backward() the graph will be released. which means no more gradient information about netG() in the computational Graph. then when you use errG.backward() later, it will cause an error something like
"Trying to backward through the graph a second time"
if you don't use fake.detach(), you can use netD.backward(retain_graph=True)

GitHubにも同じような質問がありました.

This is not true. Detaching fake from the graph is necessary to avoid forward-passing the noise through G when we actually update the generator. If we do not detach, then, although fake is not needed for gradient update of D, it will still be added to the computational graph and as a consequence of backward pass which clears all the variables in the graph (retain_graph=False by default), fake won't be available when G is updated.

つまり,「結論」で書いたような感じらしいです.

Inside Deep Learning

detachの説明では混乱しましたが,この本は説明が超丁寧でわかりやすく,僕のように雰囲気でDeep Learningをやっている人にもおすすめです.

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?