More than 1 year has passed since last update.

@kogepan102

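Implementing a GAN in 3 minutes with pixyz

Implementing a GAN with pixyz

I heard that GANs can also be implemented with pixyz, so I gave it a try.
The tedious parts are hidden away, so I was able to write it in 10 minutes. (The "3 minutes" was an exaggeration.)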

mnist_gif_GAN.gif

The full implementation, including the visualization, is posted here.
I used the official implementation as a reference.

network architecture

The network structure is based on this one.

from pixyz.distributions import Deterministic
import torch
import torch.nn as nn

class generator(Deterministic):
    def __init__(self, input_dim=100, output_dim=1, input_size=28):
        super(generator, self).__init__(cond_var=["z"], var=["x"])
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
            nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Tanh(),
        )
        initialize_weights(self)

    def forward(self, z):
        x = self.fc(z)
        x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
        x = self.deconv(x)

        return {"x": x}

class discriminator(Deterministic):
    def __init__(self, input_dim=1, output_dim=1, input_size=28):
        super(discriminator, self).__init__(cond_var=["x"], var=["t"], name="d")
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )
        initialize_weights(self)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
        t = self.fc(x)

        return {"t": t}

def initialize_weights(model):
    # draw weights from N(0, 0.02^2) and zero the biases of every conv / deconv / linear layer
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.init.normal_(m.weight, 0, 0.02)
            nn.init.zeros_(m.bias)
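
The `input_size // 4` that appears in both networks follows from the layer settings: every convolution and transposed convolution above uses kernel 4, stride 2, padding 1, which halves or doubles the spatial size. The following is a minimal sketch that checks this with the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` (dilation 1, no output padding). The helper names `conv2d_out` and `conv_transpose2d_out` are made up for this illustration.

```python
def conv2d_out(size, kernel=4, stride=2, padding=1):
    """Output length of nn.Conv2d along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(size, kernel=4, stride=2, padding=1):
    """Output length of nn.ConvTranspose2d along one axis (dilation 1, no output_padding)."""
    return (size - 1) * stride - 2 * padding + kernel


input_size = 28

# Discriminator: two stride-2 convolutions shrink the image twice.
d_sizes = [input_size]
for _ in range(2):
    d_sizes.append(conv2d_out(d_sizes[-1]))

# Generator: starts from input_size // 4 and applies two stride-2 deconvolutions.
g_sizes = [input_size // 4]
for _ in range(2):
    g_sizes.append(conv_transpose2d_out(g_sizes[-1]))

# Number of features at the boundary between the conv and fc parts.
flat_features = 128 * (input_size // 4) ** 2

print(d_sizes, g_sizes, flat_features)
```

With `input_size=28` the discriminator goes 28, 14, 7 and the generator goes 7, 14, 28, so the fully connected layers on both sides see 128 * 7 * 7 features.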

distribution

The only probability distribution is the one over the noise z; everything else is a deterministic function.
The generator is marginalized over z.

from pixyz.distributions import Normal, DataDistribution

device = "cuda" if torch.cuda.is_available() else "cpu"
z_dim = 64

# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim, name="p_prior")

# generative model
p_g = generator(input_dim=z_dim)
p = (p_g*prior).marginalize_var("z").to(device)
# p(x) = ∫p(x|z)p_prior(z)dz

# data distribution
p_data = DataDistribution(["x"]).to(device)
# p_data(x) (Data distribution)

# discriminator
d = discriminator().to(device)
# d(t|x) (Deterministic)
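
Since the generator is deterministic, marginalizing over z means in practice that a sample from p(x) is obtained by drawing z from the prior, passing it through the generator, and discarding z. The density p(x) itself is never evaluated. The sketch below illustrates that idea with the standard library only. `toy_generator`, `sample_marginal`, and the weights are hypothetical stand-ins and are not part of pixyz or of the article's model.

```python
import math
import random

random.seed(0)

z_dim = 4  # toy size; the article uses z_dim = 64


def toy_generator(z, weights):
    """Hypothetical stand-in for generator.forward: a deterministic map z -> x."""
    return math.tanh(sum(w * z_i for w, z_i in zip(weights, z)))


def sample_marginal(n, weights):
    """Draw n samples x ~ p(x) by sampling z ~ N(0, I) and keeping only x."""
    samples = []
    for _ in range(n):
        z = [random.gauss(0.0, 1.0) for _ in range(z_dim)]  # z ~ p_prior(z)
        samples.append(toy_generator(z, weights))  # x = g(z); z is discarded
    return samples


weights = [0.5, -0.3, 0.8, 0.1]  # fixed, made-up parameters
xs = sample_marginal(1000, weights)
print(len(xs), min(xs), max(xs))
```

Only samples of x are needed for GAN training, which is why a deterministic generator combined with a simple prior is enough here.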

model

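pixyz provides a GAN class, so all that is left is to pass in the distributions defined above.
optimizer is for the generator and d_optimizer is for the discriminator.
The model minimizes the Jensen-Shannon (JS) divergence.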

from torch import optim
from tqdm import tqdm
from pixyz.models import GAN

model = GAN(p_data, p, d,
            optimizer=optim.Adam, optimizer_params={"lr":0.0002},
            d_optimizer=optim.Adam, d_optimizer_params={"lr":0.0002})

# Loss function: 
#  mean(mean(AdversarialJS[p_data(x)||p(x)])) 
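
The link between an adversarial loss and the JS divergence can be checked numerically. For the standard GAN objective, when the discriminator is optimal, D*(x) = p_data(x) / (p_data(x) + p(x)), the value E_{p_data}[log D*(x)] + E_{p}[log(1 - D*(x))] equals 2 * JS(p_data || p) - log 4. The sketch below verifies this identity on two small made-up discrete distributions. It uses only the standard library and does not reproduce how pixyz scales or estimates its AdversarialJS loss, which is not confirmed here.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def js(p, q):
    """Jensen-Shannon divergence: average KL to the mixture m = (p + q) / 2."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def adversarial_value(p_data, p_gen):
    """E_{p_data}[log D*(x)] + E_{p_gen}[log(1 - D*(x))] at the optimal D*."""
    total = 0.0
    for pd, pg in zip(p_data, p_gen):
        d_star = pd / (pd + pg)  # optimal discriminator output for this x
        if pd > 0:
            total += pd * math.log(d_star)
        if pg > 0:
            total += pg * math.log(1 - d_star)
    return total


# Made-up toy distributions over three outcomes.
p_data = [0.7, 0.2, 0.1]
p_gen = [0.1, 0.3, 0.6]

lhs = adversarial_value(p_data, p_gen)
rhs = 2 * js(p_data, p_gen) - math.log(4)
print(lhs, rhs)
```

The two printed numbers agree up to floating-point error, and the JS term vanishes when the generator distribution matches the data distribution.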


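# epoch_num and train_loader (a DataLoader yielding (image, label) batches) must be defined beforehand
for epoch in range(epoch_num):
    for x, _ in tqdm(train_loader):
        x = x.to(device)
        # returns the generator loss and the discriminator loss
        loss, d_loss = model.train({"x": x})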
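
For comparison, here is a rough sketch of the kind of alternating update that a call like `model.train` takes off your hands, written in plain PyTorch. It is not pixyz's actual implementation. The tiny networks, the sizes, and `train_step` are hypothetical, and the generator uses the common non-saturating loss (maximizing log D(G(z))), which is a swapped-in choice rather than the exact JS form.

```python
import torch
import torch.nn as nn
from torch import optim

torch.manual_seed(0)
z_dim, x_dim, batch = 8, 16, 32  # toy sizes, not the article's

# Tiny stand-ins for the article's generator and discriminator.
g = nn.Sequential(nn.Linear(z_dim, x_dim), nn.Tanh())
d = nn.Sequential(nn.Linear(x_dim, 1), nn.Sigmoid())

g_opt = optim.Adam(g.parameters(), lr=0.0002)
d_opt = optim.Adam(d.parameters(), lr=0.0002)
bce = nn.BCELoss()


def train_step(x_real):
    n = x_real.size(0)
    ones = torch.ones(n, 1)
    zeros = torch.zeros(n, 1)

    # 1) Discriminator update: push real samples toward 1, generated ones toward 0.
    x_fake = g(torch.randn(n, z_dim)).detach()  # no gradient into the generator
    d_loss = bce(d(x_real), ones) + bce(d(x_fake), zeros)
    d_opt.zero_grad()
    d_loss.backward()
    d_opt.step()

    # 2) Generator update: make the discriminator output 1 on generated samples.
    x_fake = g(torch.randn(n, z_dim))
    g_loss = bce(d(x_fake), ones)
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()

    return g_loss.item(), d_loss.item()


# One step on random data scaled to [-1, 1], the range of the Tanh output.
g_loss, d_loss = train_step(torch.rand(batch, x_dim) * 2 - 1)
print(g_loss, d_loss)
```

Keeping two optimizers, as in the `optimizer` / `d_optimizer` arguments above, is what lets each network be updated against its own loss.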

More fine-grained implementations are also possible, so I plan to cover that at some point.
