Implementing a GAN in 3 minutes with pixyz

Posted at 2019-01-20

Implementing a GAN with pixyz

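I heard that GANs can also be implemented with pixyz, so I gave it a try.
Because pixyz hides the tedious parts, I was able to write it in 10 minutes. (The "3 minutes" in the title is an exaggeration.)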

[Animation: mnist_gif_GAN.gif]

The full implementation, including visualization, is posted here.
I based it on the official implementation.

network architecture

I based the network architecture on this reference.

from pixyz.distributions import Deterministic
import torch
import torch.nn as nn

class generator(Deterministic):
    def __init__(self, input_dim=100, output_dim=1, input_size=28):
        super(generator, self).__init__(cond_var=["z"], var=["x"])
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
            nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Tanh(),
        )
        initialize_weights(self)

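    def forward(self, z):
        # z: (batch, input_dim) -> (batch, 128 * (input_size // 4) ** 2)
        x = self.fc(z)
        # reshape into feature maps: (batch, 128, input_size // 4, input_size // 4)
        x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
        # two stride-2 transposed convs upsample to (batch, output_dim, input_size, input_size)
        x = self.deconv(x)

        # a pixyz Deterministic distribution returns a dict keyed by its output variable
        return {"x": x}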

class discriminator(Deterministic):
    def __init__(self, input_dim=1, output_dim=1, input_size=28):
        super(discriminator, self).__init__(cond_var=["x"], var=["t"], name="d")
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )
        initialize_weights(self)

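    def forward(self, x):
        # x: (batch, input_dim, input_size, input_size) -> (batch, 128, input_size // 4, input_size // 4)
        x = self.conv(x)
        # flatten the feature maps for the fully connected layers
        x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
        # t: (batch, output_dim), sigmoid probability that x comes from the data
        t = self.fc(x)

        return {"t": t}
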
def initialize_weights(model):
    # DCGAN-style initialization: weights ~ N(0, 0.02), biases set to zero
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.init.normal_(m.weight, 0.0, 0.02)
            nn.init.zeros_(m.bias)
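Both networks rely on the fact that a stride-2 layer with kernel size 4 and padding 1 exactly halves (`Conv2d`) or doubles (`ConvTranspose2d`) the spatial size, which is why `input_size // 4` appears in the fully connected layers. The standard PyTorch output-size formulas (dilation 1, no `output_padding`) can be checked with a few lines of plain Python; this sketch only reproduces the arithmetic and does not build the networks:

```python
def conv2d_out(size, kernel=4, stride=2, padding=1):
    # nn.Conv2d output size with dilation=1
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose2d_out(size, kernel=4, stride=2, padding=1, output_padding=0):
    # nn.ConvTranspose2d output size with dilation=1
    return (size - 1) * stride - 2 * padding + kernel + output_padding

input_size = 28
base = input_size // 4  # spatial size after the generator's fc output is reshaped

# Generator: base -> 2 * base -> 4 * base
g_sizes = [base]
for _ in range(2):
    g_sizes.append(conv_transpose2d_out(g_sizes[-1]))

# Discriminator: input_size -> input_size / 2 -> input_size / 4
d_sizes = [input_size]
for _ in range(2):
    d_sizes.append(conv2d_out(d_sizes[-1]))

# Number of features flattened before the discriminator's fc layer
fc_features = 128 * base * base

print(g_sizes, d_sizes, fc_features)
```

The generator only reproduces the input resolution when `input_size` is divisible by 4; with, say, 30, `base` would be 7 and the generator would still output 28 x 28 images.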

distribution

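The only probability distribution is the one over the noise z; everything else is a deterministic function.
The generator is marginalized over z, so the model distribution p(x) is obtained by integrating out the noise.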

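import torch
from pixyz.distributions import Normal, DataDistribution

# use the GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"

z_dim = 64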

# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim, name="p_prior")

# generative model
p_g = generator(input_dim=z_dim)
p = (p_g*prior).marginalize_var("z").to(device)
# p(x) = ∫p(x|z)p_prior(z)dz

# data distribution
p_data = DataDistribution(["x"]).to(device)
# p_data(x) (Data distribution)

# discriminator
d = discriminator().to(device)
# d(t|x) (Deterministic)
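The marginal p(x) = ∫ p(x|z) p_prior(z) dz has no closed form, but sampling from it is easy: because p(x|z) is deterministic (a point mass at G(z)), drawing z from the prior and passing it through the generator yields a sample of x. A minimal, dependency-free sketch of this ancestral sampling, where `toy_generator` is a hypothetical stand-in for the neural network above:

```python
import random

z_dim = 64

def sample_prior(batch_size, dim, rng):
    # z ~ N(0, I): independent standard normals, matching loc=0, scale=1 in the article
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(batch_size)]

def toy_generator(z):
    # Hypothetical deterministic map x = G(z); the real generator is a neural network
    return [sum(z) / len(z), max(z)]

def sample_marginal(batch_size, rng):
    # x ~ p(x) = ∫ p(x|z) p(z) dz: sample z, then apply the deterministic generator
    return [toy_generator(z) for z in sample_prior(batch_size, z_dim, rng)]

rng = random.Random(0)
xs = sample_marginal(5, rng)
print(len(xs), len(xs[0]))
```

No density of p(x) is ever evaluated; the GAN only needs samples from it, which is why the marginalized generator is enough.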

model

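pixyz provides a GAN class, so all that's left is to pass in the distributions defined above.
`optimizer` is used for the generator and `d_optimizer` for the discriminator.
Training minimizes the Jensen-Shannon divergence between p_data(x) and p(x).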
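The connection between the GAN objective and the Jensen-Shannon divergence is the classic result of Goodfellow et al. (2014): for a fixed generator the optimal discriminator is D*(x) = p_data(x) / (p_data(x) + p_g(x)), and plugging it into V(D, G) = E_{p_data}[log D(x)] + E_{p_g}[log(1 - D(x))] gives V(D*, G) = 2 JS(p_data || p_g) - log 4. The plain-Python check below verifies this identity on two discrete distributions; the probability values are made up for illustration:

```python
import math

def kl(p, q):
    # KL(p || q) for discrete distributions (terms with p_i = 0 contribute 0)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def js(p, q):
    # Jensen-Shannon divergence: average KL to the mixture m = (p + q) / 2
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def gan_value(p_data, p_g, d):
    # V(D, G) = E_{p_data}[log D(x)] + E_{p_g}[log(1 - D(x))]
    return (sum(p * math.log(dx) for p, dx in zip(p_data, d) if p > 0)
            + sum(q * math.log(1 - dx) for q, dx in zip(p_g, d) if q > 0))

p_data = [0.5, 0.3, 0.2]
p_g = [0.2, 0.3, 0.5]

# Optimal discriminator for this fixed generator
d_star = [p / (p + q) for p, q in zip(p_data, p_g)]

lhs = gan_value(p_data, p_g, d_star)
rhs = 2 * js(p_data, p_g) - math.log(4)
print(lhs, rhs)
```

Since JS is non-negative and zero only when p_g = p_data, the generator's best achievable value is -log 4, reached when its distribution matches the data.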

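import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from pixyz.models import GAN

model = GAN(p_data, p, d,
            optimizer=optim.Adam, optimizer_params={"lr":0.0002},
            d_optimizer=optim.Adam, d_optimizer_params={"lr":0.0002})

# Loss function: 
#  mean(mean(AdversarialJS[p_data(x)||p(x)])) 

# Assumed data pipeline (not shown in the article): MNIST scaled to [-1, 1] to match the generator's Tanh
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_loader = DataLoader(datasets.MNIST("./data", train=True, download=True, transform=transform),
                          batch_size=64, shuffle=True)
epoch_num = 10  # assumed number of epochs

for epoch in range(epoch_num):
    for x, _ in tqdm(train_loader):
        x = x.to(device)
        # one adversarial update; returns the generator loss and the discriminator loss
        loss, d_loss = model.train({"x": x})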
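Each call to `model.train` returns a generator loss and a discriminator loss. How pixyz computes them internally is not shown in this article; as a hedged illustration only, the textbook binary cross-entropy form of the two losses can be written against hypothetical discriminator outputs `d_real = D(x)` and `d_fake = D(G(z))` for one batch. The non-saturating generator loss used here is one common choice and may differ from what pixyz does:

```python
import math

def discriminator_loss(d_real, d_fake):
    # Binary cross-entropy: real samples labelled 1, generated samples labelled 0
    return (-sum(math.log(d) for d in d_real) / len(d_real)
            - sum(math.log(1 - d) for d in d_fake) / len(d_fake))

def generator_loss(d_fake):
    # Non-saturating form: the generator wants D to label its samples as real
    return -sum(math.log(d) for d in d_fake) / len(d_fake)

# Hypothetical sigmoid outputs of the discriminator for one batch
d_real = [0.9, 0.8, 0.7]
d_fake = [0.2, 0.1, 0.3]
print(discriminator_loss(d_real, d_fake), generator_loss(d_fake))
```

When the discriminator cannot tell the two apart and outputs 0.5 everywhere, the discriminator loss is 2 log 2 and the generator loss is log 2.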

pixyz also supports more fine-grained implementations, which I plan to try at some point.
