2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Gated CNN + VAEで強くなる

Last updated at Posted at 2021-06-19

Gated CNN + VAEで強くなる

目次

飛べます
今回はGated CNN + GAN
Gated CNN(GCNN)
VAEに組み込んでみる
GANで強くなる
ちなみに

今回はGated CNN + VAE

今までVAEについて書いてきました。ただ、
-``ある時刻 $t$ におけるフレーム''の入出力しか考慮していない

  • 特徴量次元間の相関しか考慮してない
    という感じです。

当たり前ですが音声特徴量は近隣範囲のフレームに跨って影響を及ぼし合ってるはずです。
なので、特徴量次元間の相関だけじゃなく時間方向の相関も捉えられるモデルが必要です。

そこでCNNを使いましょうっていう話です。

今回はLSTMのようなゲート機構を持ったCNNを用います。

ちなみにGANも組み込もうとしましたが全然上手くいかず3ヵ月経ちました。

Gated CNN(GCNN)

$\boldsymbol{h_l} = (\boldsymbol{W*h_{l-1}} + \boldsymbol{b_l}) \otimes \sigma(\boldsymbol{V*h_{l-1}} + \boldsymbol{d_l})$

という伝播で、$\boldsymbol{h_l}$に対してBatch Normalizationを施した$\boldsymbol{h_l'}$が$l$層目の中間層になります。
上の伝播において、前半部分が普通の畳み込み層で、後半がシグモイド関数で制御されたゲートです。

VAEに組み込んでみる

##主要なコードだけ
デコーダーは下のような感じです。
中間層を3層にしてます。エンコーダーはこの逆をやれば良いです。
nn.ConvTranspoose2d( a, b, (c, d) ~)のa, bは入力チャンネル数と出力チャンネル数です。

あ、この辺の引数については他所で丁寧にたくさん説明してる記事があるのでここでは割愛します。

ちなみに引数は適当です。参考にした論文はありますが、辻褄合わなかったので微妙に変えてます。
paddingはマジで帳尻合わせのために数値決めてます。

class Decoder(nn.Module):
  def __init__(self, lr=0.001, betas=(.9, .999)): 
    super().__init__()

    # setting for decoder convolutional parameters
    self.decW1 = nn.ConvTranspose2d(5,16,(5,9),stride=(1,9), padding=(3,1))
    self.decW1_bn = nn.BatchNorm2d(16)
    self.decW2 = nn.ConvTranspose2d(16,16,(8,4),stride=(2,2),padding=(3,0))
    self.decW2_bn = nn.BatchNorm2d(16)
    self.decW3 = nn.ConvTranspose2d(16,8,(8,4),stride=(2,2),padding=(3,0))
    self.decW3_bn = nn.BatchNorm2d(8)
    self.decW4 = nn.ConvTranspose2d(8,1,(9,3),stride=(1,1))
    
    # setting for decoder gate parameters
    self.decV1 = nn.ConvTranspose2d(5,16,(5,9),stride=(1,9), padding=(3,1))
    self.decV1_bn = nn.BatchNorm2d(16)
    self.decV2 = nn.ConvTranspose2d(16,16,(8,4),stride=(2,2), padding=(3,0))
    self.decV2_bn = nn.BatchNorm2d(16)
    self.decV3 = nn.ConvTranspose2d(16,8,(8,4),stride=(2,2), padding=(3,0))
    self.decV3_bn = nn.BatchNorm2d(8)
    self.decV4 = nn.ConvTranspose2d(8,1,(9,3),stride=(1,1))

    self.optimizer = optim.Adam(self.parameters(), lr=lr)


  #---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---
  def forward(self, z):
    h1 = torch.mul(self.decW1_bn(self.decW1(z)), torch.sigmoid(self.decV1_bn(self.decV1(z))))
    h2 = torch.mul(self.decW2_bn(self.decW2(h1)), torch.sigmoid(self.decV2_bn(self.decV2(h1))))
    h3 = torch.mul(self.decW3_bn(self.decW3(h2)), torch.sigmoid(self.decV3_bn(self.decV3(h2))))
    mu = self.decW4(h3)
    lvar = self.decV4(h3)

    return mu, lvar

メインはこんな感じです。

class GCNN(nn.Module):
  def __init__(self, lr=1e-3):
    super().__init__()
    self.vae_loss = []

    self.encoder = Encoder(lr=lr)
    self.decoder = Decoder(lr=lr)

  #---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---
  def compute_vae_loss(self, x, z_mu, z_lvar, y):
    # p(x)
    rec_loss_ = np.log(2*np.pi) + torch.square(x - y)
    rec_loss = -0.5 * torch.sum(rec_loss_, axis=(1,2,3))

    # KL div
    latent_loss_ = 1 + z_lvar - z_lvar.exp() - torch.square(z_mu)
    latent_loss = 0.5 * torch.sum(latent_loss_, axis=(1,2,3))
    
    return -1. * torch.mean(rec_loss + latent_loss)

そんで、学習はこんな感じです。同じGCNNクラスのメソッドです。 学習データXは、$X = [X_1, X_2, ..., X_N]$のようなリストになっていて、$X_n$は512行36列の2次元配列です。 ミニバッチは8、エポック数は50にしてます。25で十分に収束しますが。 食わせたデータは36次元のメルケプストラムで、総データ数5600くらいです。 1000発話くらいを窓幅512フレーム、オーバーラップ256フレームで分割してます。
  #---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---
  def vae_train(self, X, nepoch=25, nbatch=8):
    N, H, D = len(X), X[0].shape[0], X[0].shape[1]
    batch_array = np.arange(0, N-nbatch, nbatch) # [0, 8, 16, ... ]
    
    print('vae training (nepoch: %d  nbatch: %d  ndata: %d)'%(nepoch, nbatch, N))
    for epoch in range(nepoch):
      perm = np.random.permutation(N) # [0, N)をランダムに並び替えた配列

      loss_batch = np.zeros(len(batch_array))
      for k in range(len(batch_array)):
        i = batch_array[k]
        batch_x = X_[perm[i:i+nbatch]]
        batch_x_ = torch.tensor(batch_x.astype(np.float32), requires_grad=False)

        self.encoder.optimizer.zero_grad()
        self.decoder.optimizer.zero_grad()
    
        z, z_mu, z_lvar = self.encoder.forward(batch_x_)
        y_, _ = self.decoder.forward(z)
    
        l = self.compute_vae_loss(batch_x_, z_mu, z_lvar, y_)
        l.backward()
        
        self.encoder.optimizer.step()
        self.decoder.optimizer.step()
        loss_batch[k] = l.item()

      self.vae_loss.append(np.mean(np.mean(loss_batch)))

知見① 損失関数

前回もそうでしたがVAEの損失計算時に見にバッチ方向には平均をとってますが、特徴量次元方向には和を取ってます。
和を取っても特徴量次元数で割れば平均取った場合と等価なので意味合いは違いませんが
lossのスケールが違います。今回なら512x36=18432で10の4乗も違います。

今回は平均取ると学習がうまくいかず、和を取ったらうまくいきました。難しいっすね。

VAEに組み込んでみた結果

なんかめっちゃ上手くできてます。マジで。
Fully Connection型の通常VAE(【PyTorch】Variational AutoEncoder)よりも過剰平滑化が緩和されており音質も良かったです。

ここで満足しましたがGANでもっと良くなるなら・・・と。
image.png

GANで強くなる

お待ちかねのGANです。今回はLeast Square GAN (LSGAN) を用いました。
前回の(過去URL)ではWasserstein GANでしたが処理時間が短いのでLSGANを採用してます。

一応説明しておきますが、LSGANは

  • 自然音声なら1
  • 生成音声なら0
    を返す識別器です。
    なので識別器の損失$\mathcal{L}_D$、生成器の損失$\mathcal{L}_G$はそれぞれ
\begin{align}
\mathcal{L}_D &= \mathbb{E}[||D(\boldsymbol{x}) - 1||_2] + \mathbb{E}[||D(\boldsymbol{y}) - 0||_2] \\
\mathcal{L}_G &= \mathcal{L}_{decoder} + \mathbb{E}[||D(\boldsymbol{y}) - 1||_2]
\end{align}

になります。

LSGANについて主要コード

さっきのメインクラスGCNNをGCGANとかにして、Discriminatorクラスも呼び出します。
今回はcriticっていう名前で呼び出してます。

当初はWasserstein GANで実装してたのでcriticのままになってます笑

is_lsは「label smoothing」をするかどうかの旗です。
label smoothingについては後述します。

forwardメソッドのis_feature_matchingはfeature matchingを用いるかどうかで、これも後述します。

################################################################################
class Discriminator(nn.Module):
  def __init__(self, lr=0.0001, is_ls=False, is_fm=False):
    super().__init__()

    # setting for discriminator
    self.W1 = nn.Conv2d(1, 4, (8,4), stride=(4,2), padding=(3,1))
    self.V1 = nn.Conv2d(1, 4, (8,4), stride=(4,2), padding=(3,1))
    self.W1_bn = nn.BatchNorm2d(4)
    self.V1_bn = nn.BatchNorm2d(4)
    self.W2 = nn.Conv2d(4, 4, (8,4), stride=(2,2), padding=(3,1))
    self.V2 = nn.Conv2d(4, 4, (8,4), stride=(2,2), padding=(3,1))
    self.W2_bn = nn.BatchNorm2d(4)
    self.V2_bn = nn.BatchNorm2d(4)
    self.W3 = nn.Conv2d(4, 4, (8,4), stride=(4,2), padding=(3,0))
    self.V3 = nn.Conv2d(4, 4, (8,4), stride=(4,2), padding=(3,0))
    self.fc4 = nn.Conv1d(4, 1, (16,1), stride=(8), padding=(0,0))

    self.optimizer = optim.Adam(self.parameters(), lr=lr, betas=betas)
    self.is_ls = is_ls
    self.is_fm = is_fm

    if is_ls:
      self.a, self.b, self.c = .9, 0.1, 1
    else:
      self.a, self.b, self.c = 1., 0., 1.

  #---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---
  def forward(self, input):
    h1 = torch.mul(self.W1_bn(self.W1(input)), torch.sigmoid(self.V1_bn(self.V1(input))))
    h2 = torch.mul(self.W2_bn(self.W2(h1)), torch.sigmoid(self.V2_bn(self.V2(h1)))) 
    h3 = torch.mul(self.W3(h2), torch.sigmoid(self.V3(h2)))

    d = self.fc4(h3)

    if self.is_fm:
      return d, h3
    else:
      return d

損失はこんな感じです。

  #---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---#---
  def compute_adv_loss(self, x, z, z_mu, z_lvar, y, lam=1.):
    z_ = torch.tensor(z.clone().detach().numpy().astype(np.float32), requires_grad=False)
    y_, _ = self.decoder.forward(z_)

    # p(x)
    rec_loss_ = np.log(2*np.pi) + torch.square(x - y)
    rec_loss = -0.5 * torch.sum(rec_loss_, axis=(1,2,3))

    # KL div
    latent_loss_ = 1 + z_lvar - z_lvar.exp() - torch.square(z_mu)
    latent_loss = 0.5 * torch.sum(latent_loss_, axis=(1,2,3))

    E1g = torch.mean(-1. * rec_loss) # decoder
    E1e = torch.mean(-1. * latent_loss) # encoder


    # ADV
    if self.critic.is_fm:
      w2, h2 = self.critic.forward(y_)
      w1, h1 = self.critic.forward(x)
      E2 = torch.mean(torch.sum(torch.square(h2 - h1), axis=(1,2,3)))
    else:  
      w2 = self.critic.forward(y_)
      E2 = torch.mean(torch.sqrt(torch.square(w2-self.critic.c))) # decoder, discriminator
    
    alpha = lam * np.abs(E1g.item()/E2.item())

    loss = E1g + E1e + alpha*E2
    return loss, (E1g+E1e).item()

知見② Adamのパラメータ

pytorchのデフォルトだと$\beta_1$と$\beta_2$はそれぞれ0.9と0.999です。
$\beta_1$を0.5にすると良い感じになります。

知見③ 学習バランス

識別器と生成器の学習バランスですが、よく見るのは「識別器をある程度更新してから生成器を1回更新」っていう流れです。
でもこれって難しくて、識別器が最強になるともう生成器はへんな音声しか産まなくなる場合があります。
なので今回は1:1でやってます。僕の場合はこれが一番良かった気がします。

結果

事前学習としてVAEをエポック数25、ミニバッチ8、学習率0.001、OptimizerにAdam(モメンタム0.9)で学習し、
識別器をエポック数5、ミニバッチ8、学習率0.0001、OptimimzerにAdam(モメンタム0.9)で学習しました。

本番の学習は、エポック数25、ミニバッチ8、VAEの学習率を0.001、識別器の学習率を0.00002、Optimizerは共にAdam(モメンタム0.5)で学習しました。

メルケプストラムの概形や散布図を見ると、あんま敵対学習の効能が見られません。
「は?」と思いましたが変調スペクトルは補償されてるので無駄では無さそうっす。
音質は良くなった点もありつつ劣化した部分もあってとんとんかなー?

あと、よーく見ると高次成分が復元されてる感じ。
微細変動が蘇ってます。

image.png

識別器の出力が0.5付近じゃなくて0.2付近で拮抗してるので、生成器はあんまですね。

ただ、今回はそもそも音質が良いので学習が難しいGANを使う必要なくて
もっと音質が崩壊するようなシチュエーションのときにこそ使うべきなのかなとも思いました。

それにしても識別器が死んでますねー笑

ちなみに

今回はGANの効能が弱っちいと思い、様々な学習アルゴリズムも試しました。まぁ最終的にはシンプルな方法で今回は十分だったので全て外しましたが。

Feature Matching

識別器の学習の際、D(x)とD(y)を用いるのではなく、$\mathcal{L}_D = ||f(\boldsymbol{x})^l - f(\boldsymbol{y})^l||_2$とします。
$f(\boldsymbol{x})^l$はDの$l$層目における出力です。
$l=0$ならただの再構築誤差と等価です。

label smoothing

識別器の学習の際、自然音声に対しては1、生成音声に対しては0を、
生成器の学習の際、敵対項の生成音声に対しては1を返すのが普通のLSGANです。
ただ、これを1, 0ではなく0.9, 0.1など識別器のラベルを濁す方が精度が良いらしいです。
片方だけを連続値にするのをOne-sided label smoothingと言って双方の場合はTwo-side~って言うそうです。

mini-batch Standard deviation

GANの問題の一つに「モード崩壊(mode collapse)があります。
これは生成器がひたすら同じ出力をしてしまうっていうやつです。
これが起きるとどんなデータを食わせても再構築したメルケプは毎回同じような値になります。

そこで!
「生成音声はミニバッチ間の分散が小さくなる」ことを識別器の学習に加えます。
詳細は(https://medium.com/@hirokisakuma1209/gansynth-adversarial-neural-audio-synthesis-4dcc4e4ac9bb)

おしまい

Gated CNNは最強でした。
GANは難しい。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?