Implementing a GAN with pixyz
GANs can apparently be implemented with pixyz as well, so I gave it a try.
Since the tedious parts are abstracted away, it took me 10 minutes to write. (OK, I exaggerated by 3 minutes.)
The full implementation, including visualization, is available here.
I used the official implementation as a reference.
network architecture
The network architecture follows the one here.
from pixyz.distributions import Deterministic
import torch
import torch.nn as nn


class generator(Deterministic):
    """Deterministic generator p(x|z): maps a noise vector z to an image x."""
    def __init__(self, input_dim=100, output_dim=1, input_size=28):
        super(generator, self).__init__(cond_var=["z"], var=["x"])
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
            nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Tanh(),
        )
        initialize_weights(self)

    def forward(self, z):
        x = self.fc(z)
        # reshape the flat features into a 128-channel feature map
        x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
        x = self.deconv(x)
        return {"x": x}
class discriminator(Deterministic):
    """Deterministic discriminator d(t|x): maps an image x to a realness score t."""
    def __init__(self, input_dim=1, output_dim=1, input_size=28):
        super(discriminator, self).__init__(cond_var=["x"], var=["t"], name="d")
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )
        initialize_weights(self)

    def forward(self, x):
        x = self.conv(x)
        # flatten the feature map before the fully connected head
        x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
        t = self.fc(x)
        return {"t": t}
def initialize_weights(model):
    # DCGAN-style initialization: weights ~ N(0, 0.02), biases zeroed
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
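As a quick sanity check (my own addition, not part of the original post), you can instantiate both networks and confirm the tensor shapes line up. This assumes pixyz's sample() API, which for a Deterministic distribution just runs the forward pass:

# Sanity check (not in the original post): confirm the shapes line up.
G = generator(input_dim=100)
D = discriminator()
z = torch.randn(16, 100)              # batch of 16 noise vectors
x = G.sample({"z": z})["x"]           # images: torch.Size([16, 1, 28, 28])
t = D.sample({"x": x})["t"]           # realness scores: torch.Size([16, 1])
print(x.shape, t.shape)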
distribution
The only probability distribution here is the distribution of the noise z; everything else is a deterministic function.
The generator is marginalized over z.
from pixyz.distributions import Normal, DataDistribution

device = "cuda" if torch.cuda.is_available() else "cpu"
z_dim = 64

# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim, name="p_prior")

# generative model: marginalize the latent z out of p(x|z)p_prior(z)
p_g = generator(input_dim=z_dim)
p = (p_g * prior).marginalize_var("z").to(device)
# p(x) = ∫p(x|z)p_prior(z)dz

# data distribution
p_data = DataDistribution(["x"]).to(device)
# p_data(x) (Data distribution)

# discriminator
d = discriminator().to(device)
# d(t|x) (Deterministic)
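Printing the distributions is an easy way to confirm the factorization pixyz has built, and sampling from the marginalized p draws z from the prior internally (my addition; the exact printed form depends on the pixyz version):

# Inspect the distributions (printed form may differ across pixyz versions).
print(prior)   # p_prior(z)
print(p)       # p(x) = ∫p(x|z)p_prior(z)dz
print(p_data)  # p_data(x)

fake = p.sample(batch_n=8)["x"]  # z ~ p_prior is drawn internally; shape (8, 1, 28, 28)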
model
pixyz provides a GAN model class, so all that is left is to plug in the distributions defined above.
optimizer is the optimizer for the generator, and d_optimizer is for the discriminator.
The model minimizes the Jensen-Shannon divergence.
import torch.optim as optim
from pixyz.models import GAN

model = GAN(p_data, p, d,
            optimizer=optim.Adam, optimizer_params={"lr": 0.0002},
            d_optimizer=optim.Adam, d_optimizer_params={"lr": 0.0002})
# Loss function:
#   mean(mean(AdversarialJS[p_data(x)||p(x)]))
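For reference (this note is mine, not the post's), the AdversarialJS term in that loss is the standard minimax GAN objective:

\[
\min_G \max_D \; \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))]
\]

For the optimal discriminator the inner value equals \(2\,\mathrm{JS}(p_{\mathrm{data}} \,\|\, p) - \log 4\), which is why training this pair to equilibrium minimizes the Jensen-Shannon divergence between the data distribution and the generator.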
from tqdm import tqdm

# epoch_num and train_loader (a DataLoader over the image dataset) are set up in the full notebook.
for epoch in range(epoch_num):
    for x, _ in tqdm(train_loader):
        x = x.to(device)
        loss, d_loss = model.train({"x": x})  # one generator update and one discriminator update
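After training, images come from sampling the marginalized p. The following is a minimal generation sketch of my own (the linked notebook does the full visualization, so treat the specifics here as assumptions):

# Minimal generation sketch; the linked notebook contains the actual visualization code.
import torchvision

p_g.eval()  # freeze BatchNorm statistics for inference
with torch.no_grad():
    samples = p.sample(batch_n=64)["x"]  # z ~ p_prior is drawn internally
    samples = (samples + 1) / 2          # map Tanh output from [-1, 1] to [0, 1]
    torchvision.utils.save_image(samples, "samples.png", nrow=8)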
A more fine-grained implementation is also possible, so I will give that a try at some point; see the sketch below for the general idea.
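For example, pixyz also exposes the adversarial loss directly in pixyz.losses, so the GAN model class can be rebuilt by hand. This is a rough, untested sketch of that route; the AdversarialJensenShannon and Model usage here is an assumption that depends on the pixyz version:

# Rough sketch of the lower-level loss API (untested; details vary across pixyz versions).
from pixyz.losses import AdversarialJensenShannon
from pixyz.models import Model

adv_loss = AdversarialJensenShannon(p_data, p, discriminator=d,
                                    optimizer=optim.Adam,
                                    optimizer_params={"lr": 0.0002})
model = Model(loss=adv_loss.mean(), distributions=[p],
              optimizer=optim.Adam, optimizer_params={"lr": 0.0002})

for x, _ in train_loader:
    x = x.to(device)
    d_loss = adv_loss.train({"x": x})  # discriminator update
    g_loss = model.train({"x": x})     # generator update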
