A memo from a recent chance I had to write a VAE in PyTorch.
Seeing it as code really drove home what a simple architecture it is.
The images here are grayscale with a single channel, so there are no convolutions: the network just stacks Linear layers, with batch normalization thrown in for good measure. Honestly, I don't know the optimal way to tune this.
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# torch_log is not a PyTorch builtin: a clamped log so that log(0)
# never blows up the loss (any numerically stable log will do).
def torch_log(x):
    return torch.log(torch.clamp(x, min=1e-10))

class VAE(nn.Module):
    def __init__(self, z_dim):
        super(VAE, self).__init__()
        # Encoder: three Linear + BatchNorm1d blocks. z_dim**2 is only the
        # hidden width; the latent dimension itself is hardcoded to 10.
        self.dense_enc1 = nn.Linear(28*28, z_dim**2)
        self.dense_encbn1 = nn.BatchNorm1d(z_dim**2)
        self.dense_enc2 = nn.Linear(z_dim**2, z_dim**2)
        self.dense_encbn2 = nn.BatchNorm1d(z_dim**2)
        self.dense_enc3 = nn.Linear(z_dim**2, z_dim**2)
        self.dense_encbn3 = nn.BatchNorm1d(z_dim**2)
        self.dense_encmean = nn.Linear(z_dim**2, 10)
        self.dense_encvar = nn.Linear(z_dim**2, 10)
        # Decoder: mirror image of the encoder.
        self.dense_dec1 = nn.Linear(10, z_dim**2)
        self.dense_decbn1 = nn.BatchNorm1d(z_dim**2)
        self.dense_dec2 = nn.Linear(z_dim**2, z_dim**2)
        self.dense_decbn2 = nn.BatchNorm1d(z_dim**2)
        self.dense_dec3 = nn.Linear(z_dim**2, 28*28)

    def _encoder(self, x):
        x = F.relu(self.dense_enc1(x))
        x = self.dense_encbn1(x)
        x = F.relu(self.dense_enc2(x))
        x = self.dense_encbn2(x)
        x = F.relu(self.dense_enc3(x))
        x = self.dense_encbn3(x)
        mean = self.dense_encmean(x)
        # softplus keeps the predicted variance positive.
        var = F.softplus(self.dense_encvar(x))
        return mean, var

    def _sample_z(self, mean, var):
        # Reparameterization trick: z = mu + sigma * epsilon keeps the
        # sampling step differentiable w.r.t. mean and var.
        epsilon = torch.randn(mean.shape).to(device)
        return mean + torch.sqrt(var) * epsilon

    def _decoder(self, z):
        x = F.relu(self.dense_dec1(z))
        x = self.dense_decbn1(x)
        x = F.relu(self.dense_dec2(x))
        x = self.dense_decbn2(x)
        # sigmoid maps each pixel into (0, 1).
        x = torch.sigmoid(self.dense_dec3(x))
        return x

    def forward(self, x):
        mean, var = self._encoder(x)
        z = self._sample_z(mean, var)
        x = self._decoder(z)
        return x, z

    def loss(self, x):
        mean, var = self._encoder(x)
        # KL divergence between N(mean, var) and the standard normal prior.
        KL = -0.5 * torch.mean(torch.sum(1 + torch_log(var) - mean**2 - var, dim=1))
        z = self._sample_z(mean, var)
        y = self._decoder(z)
        # Reconstruction term: Bernoulli log-likelihood (binary cross-entropy).
        reconstruction = torch.mean(torch.sum(x * torch_log(y) + (1 - x) * torch_log(1 - y), dim=1))
        return KL, -reconstruction
```
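For reference (this is the standard VAE derivation, not anything specific to this code): the two values returned by `loss` are the negative ELBO split into a closed-form KL term for a diagonal Gaussian posterior against a standard normal prior, plus the Bernoulli reconstruction log-likelihood:

$$
-\mathcal{L}(x) \;=\; \underbrace{-\frac{1}{2}\sum_j \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)}_{\mathrm{KL}\left(q(z|x)\,\|\,\mathcal{N}(0,I)\right)} \;-\; \underbrace{\mathbb{E}_{q(z|x)}\left[\log p(x|z)\right]}_{\text{reconstruction}}
$$

A quick shape check before training (just a sketch; the batch size of 4 is arbitrary):

```python
model = VAE(z_dim=17).to(device)
model.eval()
x = torch.rand(4, 28*28, device=device)
y, z = model(x)
print(y.shape, z.shape)  # torch.Size([4, 784]) torch.Size([4, 10])
```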
```python
z_dim = 17
n_epochs = 2000
init_lr = 5e-04

model = VAE(z_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=init_lr)
scheduler = OneCycleLR(optimizer, max_lr=init_lr, steps_per_epoch=len(dataloader_train),
                       epochs=n_epochs, pct_start=0.2)
```
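With pct_start=0.2, the learning rate warms up over the first 20% of all steps and then anneals back down; since max_lr here equals init_lr, the peak is just init_lr. A tiny trace of the schedule's shape (a throwaway sketch with total_steps=10, purely for illustration):

```python
tmp_opt = optim.Adam(model.parameters(), lr=init_lr)
tmp_sched = OneCycleLR(tmp_opt, max_lr=init_lr, total_steps=10, pct_start=0.2)
for _ in range(10):
    tmp_opt.step()   # dummy step, no gradients involved
    tmp_sched.step()
    print(tmp_sched.get_last_lr())  # ramps up to 5e-4 by step 2, then decays
```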
```python
for epoch in range(n_epochs):
    losses = []
    KL_losses = []
    reconstruction_losses = []

    model.train()
    for x in dataloader_train:
        x = x.to(device)
        model.zero_grad()
        KL_loss, reconstruction_loss = model.loss(x)
        loss = KL_loss + reconstruction_loss  # negative ELBO
        loss.backward()
        optimizer.step()
        scheduler.step()  # OneCycleLR advances every batch, not every epoch
        losses.append(loss.cpu().detach().numpy())
        KL_losses.append(KL_loss.cpu().detach().numpy())
        reconstruction_losses.append(reconstruction_loss.cpu().detach().numpy())

    losses_val = []
    model.eval()
    with torch.no_grad():  # no gradients needed for validation
        for x in dataloader_valid:
            x = x.to(device)
            KL_loss, reconstruction_loss = model.loss(x)
            loss = KL_loss + reconstruction_loss
            losses_val.append(loss.cpu().detach().numpy())

    print('EPOCH:%d, Train Lower Bound:%lf, (%lf, %lf), Valid Lower Bound:%lf' %
          (epoch+1, np.average(losses), np.average(KL_losses), np.average(reconstruction_losses), np.average(losses_val)))
```
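After training, new digits can be generated by sampling latents from the prior and running only the decoder. A minimal sketch (assumes matplotlib is available; the latent size of 10 comes from the model definition above):

```python
import matplotlib.pyplot as plt

model.eval()
with torch.no_grad():
    z = torch.randn(16, 10).to(device)  # z ~ N(0, I); the latent dim is 10
    generated = model._decoder(z).cpu().view(-1, 28, 28).numpy()

fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for img, ax in zip(generated, axes.flatten()):
    ax.imshow(img, cmap='gray')
    ax.axis('off')
plt.show()
```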