#Variational AutoEncoderについて
これで何かするのは結構大変だなぁという印象です。
というのも、実装自体は難しくないのですが、
学習を円滑に進めるためのハイパーパラメータや初期値の設定が大変です。
VAEはエンコード及びデコードの途中で、隠れ層の活性化関数に「ReLu」を使います。
ReLu(x) = max(x, 0)
つまり、その過程でほとんどが0になる可能性があります。そうなれば勾配が消失し、学習が止まってしまいます。その結果、再構築した特徴量までスパースになります。
このあと実装メモを書きながら知見や調べたことのメモも書きます。
####################################################
#まず定義
class VAE(nn.Module):
def __init__(self, I=39, J=64, K1=256, K2=128, lr=0.001):
super().__init__()
self.I, self.J, self.K1, self.K2 = I, J, K1, K2
self.lr = lr
self.stat = None
self.error = None
scale = np.asarray([2./K1, 2./K2, 2./J])
rnd = np.random.randn
k = np.asarray([[I, K1], [K1, K2], [K2, J]])
# setting parameters for encoding
self.W_enc = \
[nn.Parameter(torch.Tensor(rnd(k[l,0], k[l,1])*scale[l])) for l in range(L-1)]
self.b_enc = \
[nn.Parameter(torch.Tensor(np.zeros(k[l,1]))) for l in range(L-1)]
self.W_mu = nn.Parameter(torch.Tensor(rnd(k[-1,0], k[-1,1])*scale[-1]))
self.b_mu = nn.Parameter(torch.Tensor(np.zeros(J)))
self.W_lvar = nn.Parameter(torch.Tensor(rnd(k[-1,0], k[-1,1])*scale[-1]))
self.b_lvar = nn.Parameter(torch.Tensor(np.zeros(J)))
# setting parameters for decoding
scale = np.asarray([2./K2, 2./K1, 2./I])
self.W_dec = \
[nn.Parameter(torch.Tensor(rnd(k[-1-l,1], k[-1-l,0])*scale[L-1-l])) for l in range(L)]
self.b_dec = \
[nn.Parameter(torch.Tensor(np.zeros(k[-1-l,0]))) for l in range(L)]
self.params = [self.W_mu] + [self.b_mu] + [self.W_lvar] + [self.b_lvar] + \
self.W_enc + self.b_enc + self.W_dec + self.b_dec
self.optimizer = optim.Adam(self.params, lr=lr)
今回はエンコーダーとデコーダー共に隠れ層は2層とします。
入力層、隠れ層1、隠れ層2、潜在特徴量の次元がそれぞれI、K1、K2、Jになります。
lrは学習率です。
statはあとで使います。
scaleはWeightパラメータの初期値の分散をどれだけ抑えるかというハイパーパラメータです。
kはl番目のレイヤーのinput次元とoutput次元を表してます。
paramsは全てのネットワークパラメータを1つのリストとしてまとめ上げてます。
optimizerは最適化手法で今回はAdamです。第一引数は更新したいパラメータ(リスト)です。
この更新したいパラメータの定義は「【PyTorch】自由にネットワークを設計するよ」(https://qiita.com/ryo_he_0/items/ad5c7b15e0d280ef32e3 )を読んでください。
さて、VAEのポイントはWeightの初期値のスケールと学習率とネットワークの大きさです。
- Weightの初期値:どれだけ分散を小さくするか
- 学習率:どれだけ小さくするか
- ネットワークの大きさ:各層の次元数をどれだけ大きくするか
各層の次元数が大きくなると、次の層に伝播される際の振り幅が大きくなります。
なのでこのとき負の方向に一気に振れてしまうとReluで軒並み0にされてしまうので
Weightの分散をその分小さくしてあげると良い感じになります。
例えば今回みたいに$(\frac{2}{K})^2$にしてあげるとかです。
とは言え分散を小さくし過ぎてしまうと今度はWeightがほぼゼロ行列になってしまうので要注意です。
こればっかりは学習率、分散、次元数のそれぞれによって最適な組み合わせは何とも言えないので
色々やってみるしかないですね。
################################################################
#FeedForward
def encode(self, x):
N = x.shape[0]
h = F.relu(self.linear(x, self.W_enc[0], self.b_enc[0]))
for l in range(1, self.L-1):
h = F.relu(self.linear(h, self.W_enc[l], self.b_enc[l]))
z_mu = self.linear(h, self.W_mu, self.b_mu)
z_lvar = self.linear(h, self.W_lvar, self.b_lvar)
eps = torch.tensor(np.random.randn(N, self.J).astype(np.float32), requires_grad=False)
z = eps * torch.sqrt(z_lvar.exp()) + z_mu
return z, z_mu, z_lvar
def decode(self, z):
for l in range(self.L-1):
h = F.relu(self.linear(h, self.W_dec[l], self.b_dec[l]))
y = self.linear(h, self.W_dec[-1], self.b_dec[-1])
return y
def linear(self, x, w, b):
return torch.matmul(x, w) + b
###エンコーダー
encodeはエンコーダーです。入力$\boldsymbol{x}$(N行I列)から
h = F.relu(self.linear(・, self.W_enc[l], self.b_enc[l]))
の過程を経て末端の隠れ層の値を得ます。
そして潜在特徴量の平均パラメータ及び対数分散パラメータのz_muとz_lvarを得ます。
潜在特徴量は平均z_mu、分散$\exp(z_lvar)$の正規分布からサンプリングされたノイズと定義されます。
###デコーダー
decodeはデコーダーです。こやつは潜在特徴量から可視特徴量(入力特徴量)の期待値を返します。
################################################################
#目的関数
def compute_loss(self, y, z_mu, z_lvar, x):
# reconstruction loss
rec_loss_ = np.log(2*np.pi) + torch.square(x - y)
rec_loss = -0.5 * torch.sum(rec_loss_, axis=1)
# KL loss
latent_loss_ = 1 + z_lvar - z_lvar.exp() - torch.square(z_mu)
latent_loss = 0.5 * torch.sum(latent_loss_, axis=1)
loss = torch.mean(-1. * (rec_loss + latent_loss))
return loss
xは入力特徴量、yは再構築値です。z_mu、z_lvarはそれぞれエンコーダーで得られた潜在特徴量分布の平均と対数の分散です。
################################################################
#コード全体
class VAE(nn.Module):
def __init__(self, I=39, J=64, K1=256, K2=128, lr=0.001):
super().__init__()
self.I, self.J, self.K1, self.K2 = I, J, K1, K2
self.lr = lr
self.stat = None
self.error = error
scale = np.asarray([2./K1, 2./K2, 2./J])
rnd = np.random.randn
k = np.asarray([[I, K1], [K1, K2], [K2, J]])
# setting parameters for encoding
self.W_enc = \
[nn.Parameter(torch.Tensor(rnd(k[l,0], k[l,1])*scale[l])) for l in range(L-1)]
self.b_enc = \
[nn.Parameter(torch.Tensor(np.zeros(k[l,1]))) for l in range(L-1)]
self.W_mu = nn.Parameter(torch.Tensor(rnd(k[-1,0], k[-1,1])*scale[-1]))
self.b_mu = nn.Parameter(torch.Tensor(np.zeros(J)))
self.W_lvar = nn.Parameter(torch.Tensor(rnd(k[-1,0], k[-1,1])*scale[-1]))
self.b_lvar = nn.Parameter(torch.Tensor(np.zeros(J)))
# setting parameters for decoding
scale = np.asarray([2./K2, 2./K1, 2./I])
self.W_dec = \
[nn.Parameter(torch.Tensor(rnd(k[-1-l,1], k[-1-l,0])*scale[l])) for l in range(L)]
self.b_dec = \
[nn.Parameter(torch.Tensor(np.zeros(k[-1-l,0]))) for l in range(L)]
self.params = [self.W_mu] + [self.b_mu] + [self.W_lvar] + [self.b_lvar] + \
self.W_enc + self.b_enc + self.W_dec + self.b_dec
self.optimizer = optim.Adam(self.params, lr=lr)
#------------------------------------------------------------#
def encode(self, x):
N = x.shape[0]
h = F.relu(self.linear(x, self.W_enc[0], self.b_enc[0]))
for l in range(1, self.L-1):
h = F.relu(self.linear(h, self.W_enc[l], self.b_enc[l]))
z_mu = self.linear(h, self.W_mu, self.b_mu)
z_lvar = self.linear(h, self.W_lvar, self.b_lvar)
eps = torch.tensor(np.random.randn(N, self.J).astype(np.float32), requires_grad=False)
z = eps * torch.sqrt(z_lvar.exp()) + z_mu
return z, z_mu, z_lvar
#------------------------------------------------------------#
def decode(self, z):
for l in range(self.L-1):
h = F.relu(self.linear(h, self.W_dec[l], self.b_dec[l]))
y = self.linear(h, self.W_dec[-1], self.b_dec[-1])
return y
#------------------------------------------------------------#
def linear(self, x, w, b):
return torch.matmul(x, w) + b
#------------------------------------------------------------#
def compute_loss(self, y, z_mu, z_lvar, x):
# reconstruction loss
rec_loss_ = np.log(2*np.pi) + torch.square(x - y)
rec_loss = -0.5 * torch.sum(rec_loss_, axis=1)
# KL loss
latent_loss_ = 1 + z_lvar - z_lvar.exp() - torch.square(z_mu)
latent_loss = 0.5 * torch.sum(latent_loss_, axis=1)
loss = torch.mean(-1. * (rec_loss + latent_loss))
return loss
#------------------------------------------------------------#
def train(self, train_x, nepoch=50, nbatch=128):
N = train_x.shape[0]
train_x_, mm, std = self.normalize(train_x)
stat = {'mm':mm, 'std':std}
self.stat = stat
print('--------------------------------------')
print('%d frame found!'%N)
error = []
for epoch in range(nepoch):
perm = np.random.permutation(N)
error_batch = []
for i in range(0, N-nbatch, nbatch):
batch_x = train_x_[perm[i:i + nbatch]]
x_ = torch.tensor(batch_x.astype(np.float32), requires_grad=False)
self.optimizer.zero_grad()
z, z_mu, z_lvar = self.encode(x_)
y = self.decode(z)
loss = self.compute_loss(y, z_mu, z_lvar, x_)
loss.backward()
self.optimizer.step()
error_batch.append(loss.item())
error.append(np.mean(np.asarray(error_batch)))
self.error = error
#------------------------------------------------------------#
def normalize(self, x, mm=None, std=None):
if mm is None:
mm = np.mean(x, axis=0)
if std is None:
std = np.std(x, axis=0)
return (x-mm)/std, mm, std
#------------------------------------------------------------#
def reconst(self, x):
x_ = (x - self.stat['mm']) / self.stat['std']
x_ = torch.tensor(x_.astype(np.float32), requires_grad=False)
z, mu, lvar = self.encode(x_)
y = self.decode(z)
y = y.detach().numpy()
return y * self.stat['std'] + self.stat['mm']
################################################################
#出力
今回は音声のスペクトル包絡から得られる0次を除く39次元のメルケプストラムをVAEに食わせます。
学習に用いた総フレーム数は80000くらいで、学習条件は先に書いた実装コードのものです。
負の尤度(目的関数の値)です。
21次のメルケプストラムの入力(青線)と再構築値(赤線)がこんな感じです。
生成モデルらしく振幅が小さくなっており(皮肉です)、ちゃんと動いてそうです。
実際に再合成した音声も生成モデルらしくブザーっぽいノイズを乗せながらいい感じでした。
こちらはメルケプストラムの散布図です。左が入力メルケプストラム、右がVAEによる再構築値です。
こちらも生成モデルらしく散布図の散らばりが小さくなってます(皮肉です)。
##################################################################
#おわり
ということで今回はPyTorchでVAEを実装しました。ReLUを活性化関数に使う際は色々悩むポイントがあるというのが本記事で言いたかったメモです。
また、次回?次々回?気が向いたら「再構築値の散らばりが小さくなる」という問題点を解決する方法について書きます。
(みんな大好きなGenerative Adversarial Nets(GANs)です)