0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

VAEを書いてみる【PyTorch】

Posted at

PyTorchでVAEを書く機会があったのでメモ.
コードで見てみると,とってもシンプルなアーキテクチャなのだな,と実感した.
今回は白黒画像でチャネル数が1なので,Convolutionは使わずにLinearだけ単純に繰り返しているだけ.一応,バッチ正規化している.正直,最適なチューニングの方法は分からない.

class VAE(nn.Module):
    """Fully-connected variational autoencoder for flattened 28x28 images.

    The encoder is three Linear + BatchNorm1d blocks of width ``z_dim**2``
    feeding a 10-dimensional diagonal Gaussian (mean, variance); the decoder
    mirrors it and ends in a sigmoid so outputs lie in [0, 1].

    NOTE(review): the latent dimension is hard-coded to 10 — ``z_dim`` only
    controls the hidden width (``z_dim**2``). That matches the original code;
    confirm whether the latent size was meant to be ``z_dim`` itself.

    Args:
        z_dim: controls the hidden layer width (``z_dim**2`` units).
    """

    def __init__(self, z_dim):
        super().__init__()
        hidden = z_dim ** 2

        # --- encoder: 784 -> hidden -> hidden -> hidden ---
        self.dense_enc1 = nn.Linear(28 * 28, hidden)
        self.dense_encbn1 = nn.BatchNorm1d(hidden)
        self.dense_enc2 = nn.Linear(hidden, hidden)
        self.dense_encbn2 = nn.BatchNorm1d(hidden)
        self.dense_enc3 = nn.Linear(hidden, hidden)
        self.dense_encbn3 = nn.BatchNorm1d(hidden)

        # heads for the latent Gaussian parameters (latent size fixed at 10)
        self.dense_encmean = nn.Linear(hidden, 10)
        self.dense_encvar = nn.Linear(hidden, 10)

        # --- decoder: 10 -> hidden -> hidden -> 784 ---
        self.dense_dec1 = nn.Linear(10, hidden)
        self.dense_decbn1 = nn.BatchNorm1d(hidden)
        self.dense_dec2 = nn.Linear(hidden, hidden)
        self.dense_decbn2 = nn.BatchNorm1d(hidden)
        self.dense_dec3 = nn.Linear(hidden, 28 * 28)

    def _encoder(self, x):
        """Map a (batch, 784) input to the latent Gaussian (mean, var)."""
        x = F.relu(self.dense_enc1(x))
        x = self.dense_encbn1(x)
        x = F.relu(self.dense_enc2(x))
        x = self.dense_encbn2(x)
        x = F.relu(self.dense_enc3(x))
        x = self.dense_encbn3(x)
        mean = self.dense_encmean(x)
        # softplus keeps the variance strictly positive
        var = F.softplus(self.dense_encvar(x))
        return mean, var

    def _sample_z(self, mean, var):
        """Reparameterization trick: z = mean + sqrt(var) * eps.

        BUGFIX: ``torch.randn_like(mean)`` creates epsilon on the same
        device *and* dtype as ``mean``, replacing the original
        ``torch.randn(mean.shape).to(device)`` which silently depended on a
        global ``device`` variable and always produced float32 noise.
        """
        epsilon = torch.randn_like(mean)
        return mean + torch.sqrt(var) * epsilon

    def _decoder(self, z):
        """Map a (batch, 10) latent sample back to a (batch, 784) image in [0, 1]."""
        x = F.relu(self.dense_dec1(z))
        x = self.dense_decbn1(x)
        x = F.relu(self.dense_dec2(x))
        x = self.dense_decbn2(x)
        # sigmoid -> per-pixel Bernoulli parameters, required by the BCE loss
        x = torch.sigmoid(self.dense_dec3(x))
        return x

    def forward(self, x):
        """Return (reconstruction, latent sample) for input ``x``."""
        mean, var = self._encoder(x)
        z = self._sample_z(mean, var)
        x = self._decoder(z)
        return x, z

    def loss(self, x):
        """Return (KL divergence, negative reconstruction log-likelihood).

        The sum of the two terms is the negative ELBO to minimize.
        NOTE(review): relies on a module-level ``torch_log`` helper defined
        elsewhere in this file — presumably a log clamped away from zero
        for numerical stability; confirm its definition.
        """
        mean, var = self._encoder(x)
        # KL( q(z|x) || N(0, I) ) for a diagonal Gaussian, averaged over the batch
        KL = -0.5 * torch.mean(torch.sum(1 + torch_log(var) - mean**2 - var, dim=1))

        z = self._sample_z(mean, var)
        y = self._decoder(z)

        # Bernoulli log-likelihood of x under the decoder output y
        reconstruction = torch.mean(torch.sum(x * torch_log(y) + (1 - x) * torch_log(1 - y), dim=1))

        return KL, -reconstruction
# --- Hyperparameters ---
z_dim = 17          # hidden width is z_dim**2 (see VAE); latent size is fixed at 10
n_epochs = 2000
init_lr = 5e-04     # peak learning rate handed to OneCycleLR via max_lr

model = VAE(z_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=init_lr)
# OneCycleLR is stepped once per batch, so it needs batches-per-epoch * epochs.
scheduler = OneCycleLR(optimizer, max_lr=init_lr,
                       steps_per_epoch=len(dataloader_train),
                       epochs=n_epochs, pct_start=0.2)

for epoch in range(n_epochs):
    losses = []
    KL_losses = []
    reconstruction_losses = []

    # --- training pass ---
    model.train()
    for x in dataloader_train:
        x = x.to(device)
        # clears gradients of exactly the parameters the optimizer owns
        optimizer.zero_grad()
        KL_loss, reconstruction_loss = model.loss(x)
        loss = KL_loss + reconstruction_loss  # negative ELBO

        loss.backward()
        optimizer.step()
        scheduler.step()  # per-batch LR update (OneCycleLR contract)

        # .item() detaches and converts in one step, so no tensors (and no
        # autograd graphs) are kept alive inside the history lists
        losses.append(loss.item())
        KL_losses.append(KL_loss.item())
        reconstruction_losses.append(reconstruction_loss.item())

    # --- validation pass ---
    losses_val = []
    model.eval()
    # BUGFIX: the original validation loop ran without no_grad(), building a
    # full autograd graph for every validation batch (wasted memory/compute).
    with torch.no_grad():
        for x in dataloader_valid:
            x = x.to(device)
            KL_loss, reconstruction_loss = model.loss(x)
            losses_val.append((KL_loss + reconstruction_loss).item())

    print('EPOCH:%d, Train Lower Bound:%lf, (%lf, %lf), Valid Lower Bound:%lf' %
          (epoch + 1, np.average(losses), np.average(KL_losses),
           np.average(reconstruction_losses), np.average(losses_val)))
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?