10
15

More than 1 year has passed since last update.

【PyTorch】VAE + Wasserstein-GAN

Last updated at Posted at 2021-03-17

今回はいつも以上にPyTorchよくわからん!けどなんとか動かす!感が満載です。

目次

飛べます
今回はVAE + GAN
GAN
 - Wasserstein-Gan
 - 識別器の学習
 - 生成器の学習
 - 超ポイント1
 - 超じゃないけどまぁまぁポイント
実装
 - パラメータと最適化関数の設定
 - FeedForward
 - 目的関数
 - 識別器の学習
 - VAEの学習
 - 敵対学習
超ポイント2
超ポイント3
動かしてみた

今回はVAE + GAN

皆大好きVariational AutoEncoder(VAE)は(というか生成モデルは)、
出力データのバラツキが小さいという「過剰平滑化(Over Smoothing)」が欠点で上げられます。


Q.これが起きるとどうなるの?
A.画像ならボヤけた感じに、音声ならブザーっぽい音になる


僕の小さなツルツルの脳みそで理解してるところによると、
過剰平滑化が怒ってしまう理由は、
出力に「最もそれらしい値(期待値)」を吐くので似たような値になってしまうからです。

その対応策として、音声分野だと
Gloval Variance(GV)と呼ばれる特徴量の分散を広げる手法や、
Moduration Spectrum(MS)と呼ばれる特徴量のスペクトルの振幅を補償する手法があります。

そして皆大好きなGenerative-Adversarial-Nets(GAN)もそのうちの1つとして利用できます。
オリジナルの特徴量とモデルが出力した特徴量の散布図を比較したやつを見たことありませんか?
オリジナルは散らばってるのに出力は1箇所に密集してる〜っみたいなやつです。
GANを組み込むとそれが良い感じに散らばります。

いろんな制約や使い方をすればGANはもっと他にも恩恵をもたらしてくれるでしょうね。

GAN

生成器と識別器の2つを交互に学習するネットワークです。以上です。本当に以上です。

色々なGANがありますが、今回扱うGANに関しては、次の意気込みを持ってます。
① 生成器(VAEのデコーダー)・・・潜在特徴量からちゃんと特徴量を推定するぞ!
② 識別器・・・生成器で生成したデータなのか、本物データなのか見極めるぞ!

つまり、理論的に言い換えると
① 生成器(VAEのデコーダー)・・・潜在特徴量の分布を推定する、潜在特徴量から入力特徴量を復元する、ように学習
② 識別器・・・偽物なら小さい値、本物なら大きい値を返すよう学習する


このとき、生成器の意気込みに、
”識別器が「本物データだ」と勘違いするように生成するぞ!(識別器が大きな値を吐くように生成するぞ)”
というのを足して上げて、生成器と識別器を交互に学習するだけです。

Wasserstein-GAN

識別器にも色々種類があります。今回はWasserstein距離という尺度を使って、「本物?VAEが復元した偽物?」を測ります。
Wasserstein距離の説明は僕からはしません。

この識別器はDNNとします。

一応、本記事で、W-GANに関する簡単な説明はします。それ抜きにコードだけ見ても「?」だと思うので。

説明

\begin{align}
\hat{\boldsymbol{x}} &= G_\theta(\boldsymbol{z}) \\
d &= D_\phi(\boldsymbol{x})
\end{align}
  • $G_\theta \cdots \hspace{5pt}$生成器。潜在特徴量など、ある特徴量 $\boldsymbol{z}$ から、特徴量 $\hat{\boldsymbol{x}}$ を生成する。
  • $D_\phi \cdots \hspace{5pt}$識別器。入力特徴量がオリジナルか生成器で生成されたものか識別。

image.png
貧相ながらアルゴリズムを↑
$\Large{\style{background-color:yellow;}{エンコーダパラメータ\psi に関する更新式直してます!}}$

識別器の学習

識別器のパラメータ $\phi$ の最適化は、目的関数$L_{critic}$を最大化するよう行います。
$L_{critic}$の第1項はオリジナルの特徴量を識別器にかけた値で、
第2項は生成器で生成した特徴量を識別器にかけた値です。
前者は大きな値を、後者は小さな値を取るように学習するわけです。


### 生成器の学習 生成器のパラメータ$\theta$の最適化は、目的関数$L$を最大化するよう行います。 $L_{gen}$は生成器自体の学習に用いる目的関数です。$L$の第2項は、生成した特徴量を識別器にかけた値です。 先ほども述べましたが、GANでは、生成器が「識別器を騙すぞ!」と意気込んでるので、 生成器で生成したデータは識別器にかけるとオリジナル特徴量同様に大きな値を返して欲しいです。 なので$D_\phi(G_\theta(\boldsymbol{z}))$を加えます。

ポイントなのは識別器の更新と生成器の更新を何対何で行うかです。
だいたい5:1だったり1:1です。多分。

一応、今回はVAEを扱ってるので、エンコーダー $F$ のパラメータ $\psi$ も更新しておきます。
その目的関数$L_{vae}$を最大化すればOKです。
VAEじゃなければこの行は無くて良い訳ですが。
※$L_{vae}$は一個前の記事や他の人の記事を見てください

超ポイント1

さっき意気込みを書いた時、「生成器」を「VAEのデコーダー」と書きました。
生成器はVAEそのものではなくてVAEのデコーダーです。ここが重要です。
ここを誤ると生成器の更新でエンコーダーのパラメータも含めてしまうので学習がこけるっぽいです。
僕はこれに気付けず1週間悩みました・・・ははは

超じゃないけどまぁまぁポイント

図には書いてませんが、あらかじめ識別器とVAEは個別に学習しておきます。
まずVAEのみを適当なエポック数で学習し、その後に識別器のみを学習します。

#########################################################################

実装しまーす

###パラメータと最適化関数の設定

#coding:utf-8
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import sys, time, pickle
    
class VAWGAN(nn.Module):
  def __init__(self, I=39, J=64, K1=256, K2=128, D1=64, D2=32, L=3, lr=0.001):
    super().__init__()
    self.I, self.J, self.K, self.L = I, J, K, L
    self.lr=lr
    self.alpha = None
    self.stat = None

    rnd = np.random.randn
    # setting parameters for encoding
    k = np.asarray([[I, K1], [K1, K2], [K2, J]])
    scale = np.asarray([2./K1, 2./K2, 2./J])
    self.W_enc = \
        [nn.Parameter(torch.Tensor(rnd(k[l,0], k[l,1])*scale[l])) for l in range(L-1)]
    self.b_enc = \
        [nn.Parameter(torch.Tensor(np.zeros(k[l,1]))) for l in range(L-1)]

    self.W_mu = nn.Parameter(torch.Tensor(rnd(k[-1,0], k[-1,1])*scale[-1]))
    self.b_mu = nn.Parameter(torch.Tensor(np.zeros(J)))
    self.W_lvar = nn.Parameter(torch.Tensor(rnd(k[-1,0], k[-1,1])*scale[-1]))
    self.b_lvar = nn.Parameter(torch.Tensor(np.zeros(J)))

    # setting parameters for decoding
    scale = np.asarray([2./K2, 2./K1, 2./I])
    self.W_dec = \
        [nn.Parameter(torch.Tensor(rnd(k[-1-l,1], k[-1-l,0])*scale[l])) for l in range(L)]
    self.b_dec = \
        [nn.Parameter(torch.Tensor(np.zeros(k[-1-l,0]))) for l in range(L)]

    # setting parameters for discriminator
    k = np.asarray([[I, D1], [D1,D2], [D2, 1]])
    scale = np.asarray([2./D1, 2./D2, 2.])
    self.M = \
        [nn.Parameter(torch.Tensor(rnd(k[l,0], k[l,1])*scale[l])) for l in range(L)]
    self.c = \
        [nn.Parameter(torch.Tensor(np.zeros(k[l,1]))) for l in range(L)]
    

    self.encoder_params = self.W_enc + self.b_enc +\
                          [self.W_mu] + [self.b_mu] + [self.W_lvar] + [self.b_lvar] 
    self.decoder_params = self.W_dec + self.b_dec
    self.critic_params = self.M + self.c
    
    self.critic_optimizer = optim.Adam(self.critic_params, lr=lr)
    self.encoder_optimizer = optim.Adam(self.encoder_params, lr=lr)
    self.decoder_optimizer = optim.Adam(self.decoder_params, lr=lr)

基本的な設計は「【PyTorch】Variational AutoEncoder」(https://qiita.com/ryo_he_0/items/dd3ea0d46113ce088665)
と同じです。
エンコーダー、デコーダー、識別器(discriminatorって呼んだりcriticって呼んだり)の隠れ層は全て2層です。別になんだって良いです。各層の次元数も層の数も好きにしてください。

optimizerですが、ただのVAEならencoderとdecoderで分けたりせずOKですが
今回は敵対学習中だけデコーダーとエンコーダーで更新則が違うので分けます。
最適化関数にAdamを用いてますが、ここももちろん好きにしてください。学習率も同じです。

FeedForward

##############################################################################
  def discriminator(self, x):
    h = F.relu(self.linear(x, self.M[0], self.c[0]))
    for l in range(1, self.L-1):
      h = F.relu(self.linear(h, self.M[l], self.c[l]))

    return self.linear(h, self.M[-1], self.c[-1])
  ##############################################################################
  def encode(self, x):                
    N = x.shape[0]
    h = F.relu(self.linear(x, self.W_enc[0], self.b_enc[0]))
    for l in range(1, self.L-1):
      h = F.relu(self.linear(h, self.W_enc[l], self.b_enc[l]))

    z_mu = self.linear(h, self.W_mu, self.b_mu)
    z_lvar = self.linear(h, self.W_lvar, self.b_lvar)
                                              
    eps = torch.tensor(np.random.randn(N, self.J).astype(np.float32), requires_grad=False)
    z = eps * torch.sqrt(z_lvar.exp()) + z_mu
    return z, z_mu, z_lvar

##############################################################################
  def decode(self, h):
    for l in range(self.L-1):
      h = F.relu(self.linear(h, self.W_dec[l], self.b_dec[l]))
    y = self.linear(h, self.W_dec[-1], self.b_dec[-1])

    return y

##############################################################################
  def linear(self, x, w, b):
    return torch.matmul(x, w) + b

どれもFeedForward型で定義してますが、ここをconvolutionalにしても良いです。
僕はせっかちなのでconvolutionalみたいに時間がかかる手法はやる気が起きない為まだ勉強してません。
マシンのスペックの事情もありますし・・・

目的関数

  ##############################################################################
  def compute_vae_loss(self, y, z_mu, z_lvar, x):
    # reconstruction loss
    rec_loss_ = np.log(2*np.pi) + torch.square(x - y)
    rec_loss = -0.5 * torch.mean(rec_loss_, axis=1)
    
    # KL loss
    latent_loss_ = 1 + z_lvar - z_lvar.exp() - torch.square(z_mu)
    latent_loss = 0.5 * torch.mean(latent_loss_, axis=1)
    loss = torch.mean(-1. * (rec_loss + latent_loss)) 

    return loss

  #---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#-
  def compute_adv_loss(self, y, z_mu, z_lvar, x):
    w2 = self.discriminator(y)

    # reconstruction loss
    rec_loss_ = np.log(2*np.pi) + torch.square(x - y)
    rec_loss = -0.5 * torch.mean(rec_loss_, axis=1)
    
    E1 = torch.mean(rec_loss)
    E2 = torch.mean(w2)
    alpha = np.abs(E1.item() / E2.item())
    return -(E1 + alpha*E2)

  #---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#-
  def compute_critic_loss(self, x, y):
    w1 = self.discriminator(x)  # natural
    w2 = self.discriminator(y)  # generated

    e = torch.tensor(np.random.randn(x.shape[0], x.shape[1]), requires_grad=False)
    x_ = e * x + (1 - e) * y
    x_ = torch.tensor(x_.clone().detach().numpy().astype(np.float32), requires_grad=True)
    
    Dx_ = torch.sum(self.discriminator(x_))
    Dx_.backward()
    dx_ = x_.grad
    dx_ = torch.tensor(dx_.clone().detach().numpy().astype(np.float32))
    l2norm_dx_ = torch.sqrt(torch.mean(torch.square(dx_), axis=1))
    gp = torch.square(l2norm_dx_ - 1)
    self.critic_optimizer.zero_grad() # juuyou!!!
  
    # w2:D(y) =>  minimize w1:D(x) => maximize
    w1_ = w1.clone().detach().numpy().astype(np.float32)
    w2_ = w2.clone().detach().numpy().astype(np.float32)
    r = np.abs(np.mean(w1_)/np.mean(w2_))
    L = w1 - w2*r + torch.mean(w1-w2*r)/torch.mean(gp) * gp 
    loss = -1 * torch.mean(L)

    return loss

compute_vae_lossはただのVAE学習の目的関数です。


compute_adv_lossは敵対学習の目的関数です。
$L_{gen}$ と $D(G_\theta(\boldsymbol{z}))$ の和になりますが、
両者のスケールを合わせるために、$\alpha$を掛けます。これをしないと、片方のロスに飲み込まれます。


compute_critic_lossは識別器の目的関数です。
勾配消失を避けるため、gradient-penaltyという手法を用いてます。
無視したい人は

L = w1 - w2

と思ってください。ただし、Wasserstein GANは識別モデルがK-Lipschits関数であるという制約があるので、更新後の識別モデルパラメータをクリッピングして[-0.01, 0.01]に抑えてください。


さて、詳細はコードについてしか説明しませんが、gradient-penaluty法では$\nabla_{\boldsymbol{x'}}D(\boldsymbol{x'})$を使います。 このとき勾配は$\boldsymbol{x'}$に関する偏微分です。 .clone().detach().numpy()とか使ってますが、これはx_を求める際に辿ってきたネットワークの履歴を全て取り除きたかっただけです。 あくまでx_を求めるまでに辿って来たVAEの経路履歴は関係なく、あくまでこれからx_を識別器にかけるという事だけ考えます。

ちゃんとお勉強してないのでよくわかりませんが、途中で辿ってきた履歴のせいで、
知らないところで関係ないところの勾配やパラメータが参照されたり変更されてたら嫌なので
とりあえずx_をVAEから切り離してます。ははは

L = w1 - r*w2 + ~*gp

の r は、w1とw2間でスケールが違うのでそれを調整してます。w2の方が基本的にかなり小さいので。
gpに数字かけてるのも同じ理由です。gpは死ぬほど大きい値なので飲まれます。

あとどれも、本来「最大化」問題なので、PyTorchのライブラリの仕様に合わせて最小化問題に帰着するため-1をかけます。

識別器の学習

##############################################################################
  def critic_train(self, train_x, nepoch=5, nbatch=128):
    N = train_x.shape[0]
    if self.stat is None:
      train_x_, mm, std = self.normalize(train_x)
      stat = {'mm':mm, 'std':std}
      self.stat = stat
    else:
      train_x_ = (train_x - self.stat['mm'])/self.stat['std']
    
    ### critic training ###
    print('discriminator training')
    for epoch in range(nepoch):
      perm = np.random.permutation(N)

      for i in range(0, N-nbatch, nbatch):
        batch_x = train_x_[perm[i:i + nbatch]]
        x_ = torch.tensor(batch_x.astype(np.float32), requires_grad=False)
      
        z, z_mu, z_lvar = self.encode(x_)
        y = self.decode(z)
      
        # update discriminator
        self.critic_optimizer.zero_grad()
        critic_loss = self.compute_critic_loss(x_, y)
        critic_loss.backward(retain_graph=True)
        self.critic_optimizer.step()

ただのDNN学習みたいなものです。

###VAEの学習

  ##############################################################################
  def vae_train(self, train_x, nepoch=25, nbatch=128):
    N = train_x.shape[0]
    if self.stat is None:
      train_x_, mm, std = self.normalize(train_x)
      stat = {'mm':mm, 'std':std}
      self.stat = stat
    else:
      train_x_ = (train_x - self.stat['mm'])/self.stat['std']
    
    ### vae training ###
    print('vae training')
    for epoch in range(nepoch):
      perm = np.random.permutation(N)

      for i in range(0, N-nbatch, nbatch):
        batch_x = train_x_[perm[i:i + nbatch]]
        x_ = torch.tensor(batch_x.astype(np.float32), requires_grad=False)
      
        z, z_mu, z_lvar = self.encode(x_)
        y = self.decode(z)
      
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        vae_loss = self.compute_vae_loss(y, z_mu, z_lvar, x_)
        vae_loss.backward()
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()        

ただのVAE学習です。

敵対学習

  ##############################################################################
  def adv_train(self, train_x, nepoch=25, nbatch=128):
    N = train_x.shape[0]                  
    if self.stat is None:
      train_x_, mm, std = self.normalize(train_x)
      stat = {'mm':mm, 'std':std}
      self.stat = stat
    else:
      train_x_ = (train_x - self.stat['mm'])/self.stat['std']


    ### adversarial training ###
    print('adversarial training')
    for epoch in range(nepoch):
      perm = np.random.permutation(N)

      for i in range(0, N-nbatch, nbatch):
        batch_x = train_x_[perm[i:i + nbatch]]
        x_ = torch.tensor(batch_x.astype(np.float32), requires_grad=False)
      
        z, z_mu, z_lvar = self.encode(x_)
        y = self.decode(z)
      
        # update discriminator
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        
        critic_loss = self.compute_critic_loss(x_, y)
        critic_loss.backward(retain_graph=True)
        self.critic_optimizer.step()
        

        # update VAE
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()

        adv_loss = self.compute_adv_loss(y, z_mu, z_lvar, x_)
        vae_loss = self.compute_vae_loss(y, z_mu, z_lvar, x_)
        adv_loss.backward(retain_graph=True)
        vae_loss.backward()
        self.decoder_optimizer.step()
        self.encoder_optimizer.step()

識別器を1バッチ更新したらVAEも1バッチ更新してます。
encoder_optimizer.zero_grad()〜で各optimizerの勾配をゼロにしてますが、
そんな毎回全てのoptimizerの勾配をゼロにしないといけないかはちゃんと考えてないのでわかりません。ははは。
とりあえず更新はそれぞれ独立に行うので毎回全てゼロにしておけば間違いはないでしょってスタンスです。ははh

超ポイント2

adv_trainの下の方、

adv_loss.backward(retain_graph=True)
vae_loss.backward()

としてます。ここ、backwardを続けて書かず、間に~.step()を入れてパラメータ弄ると詰みます。
エラー出て意味わからず泣きそうでした。
原因よく知りませんが、in-placeな書式はダメだよっていうエラーだったんですが、
そんな書き方してねーしっていう。
ダメ元でbackwardを間髪入れず並べたら走ったのでマジで意味わかってません。

超ポイント3

何回か言ってますが、backward(retain_graph=True)のように、retain_graph=Trueが無いとエラー吐きます。
詳細よく知りませんが、そのあともすぐbackward使うならそうしとけってお婆ちゃんが言ってました。
詳細調べたんですが忘れました。はhh・・・

動かしてみた

学習データはメルケプストラム39次元、データ数は4~5万ほど。学習条件はコードに書いてあるデフォルト値とします。

VAE vs VAWGAN

image.png
上図がVAWGANで生成した特徴量で、下がVAEです。振幅がだいぶ蘇ってますね。


![image.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/156754/8ca853fc-063b-df13-f0eb-83df87a8cc50.png) 特徴量の散布図です。VAEは過剰平滑化により散布図が縮こまってましたが VAWGANはそれが解消されてます。

VAWGANのVAE と Discriminatorを見てみる

image.png
上図はVAEの目的関数です。最初の25エポックは事前学習で、後半の25エポックは敵対学習です。
26回目からは敵対学習で$D_\phi(G_\theta(\boldsymbol{z}))$をデコーダー更新時に髪してるので、後半になって最初はロスがグッと上がってます。
ただそのあときちんと下がっていってますね。

下図は$D_\phi(\boldsymbol{x})$ (青)と $D_\phi(\hat{\boldsymbol{x}})$ (赤)です。

image.png
よく見ると、最初の5エポック(識別器の事前学習)ではきちんと識別器の目的を果たせてます。


# おわり 読んでいただきありがとうございます。 何か間違い等あれば指摘していただけると幸いです。 自分自身、まだまだ勉強中なので何卒よろしくお願い申し上げます。
10
15
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
15