はじめに
昨年の12月に発表されて話題をさらったpix2pixモデルだが、既にgithub上に幾つかの実装コードが存在する。
しかしCPUのみで動くようなコードが無いので作ってみた。
環境
OS:Ubuntu14.04
GPU:GTX1070(今回は使わない)
CUDA:8.0 RC(今回は使わない)
cuDNN:5.1(今回は使わない)
python:2.7.6
chainer:1.20.0.1
など
参考文献等
まずpix2pixの原論文は以下。
https://arxiv.org/abs/1611.07004
①今回参考にしたコードは2つ。まずmattyaさんのDCGANsコード。
https://github.com/mattya/chainer-DCGAN
②次にpreferdのpix2pixコード。
https://github.com/pfnet-research/chainer-pix2pix
こちらの画像取得部分のコード(facade_dataset.py)はほぼそのまま使わせてもらってます。
原論文に対する変更点
原論文に対する主な変更点は以下。
- generator及びdiscriminatorともに層の数、パラメータ等を大幅に削減
- 入出力画像のサイズは28x28に圧縮
- L2ノルムなどの正則化は使わない
- dropoutも使わない
(3,4に関してはbatch normalizationを使っているので、ある程度代替できてるか?)
モデルの概略
Generatorは以下のように変更した。
Discriminatorは以下のように変更した。
コード
コードは以下
import numpy as np
from PIL import Image
import os
import pylab
import chainer
from chainer import optimizers
from chainer import serializers
from chainer import Variable
import chainer.functions as F
import chainer.links as L
image_dir = './images_plane_s' #input data file
out_image_dir = './out_images_miniplane' #output image file
out_model_dir = './out_models_miniplane' #output model file
nz2 = 10 #dimention of noise
batchsize = 100
n_epoch = 200
# read all images
fs = os.listdir(image_dir)
print(len(fs))
dataset = []
for fn in fs:
f = open('%s/%s' % (image_dir, fn), 'rb')
img_bin = f.read()
dataset.append(img_bin)
f.close()
print(len(dataset))
# function for output image
def clip_img(x):
return np.float32(-1 if x < -1 else (1 if x > 1 else x))
zvis = (np.random.uniform(-1, 1, (100, nz2)).astype(np.float32))
try:
os.mkdir(out_image_dir)
os.mkdir(out_model_dir)
except:
pass
class Generator(chainer.Chain):
def __init__(self):
super(Generator, self).__init__(
l0s=L.Linear(nz2, 7 * 7 * 64),
dc1s=L.Deconvolution2D(64, 32, 4, stride=2, pad=1),
dc2s=L.Deconvolution2D(32, 3, 4, stride=2, pad=1),
bn0s=L.BatchNormalization(7 * 7 * 64),
bn1s=L.BatchNormalization(32),
)
def __call__(self, z, test=False):
h = self.l0s(z)
h = self.bn0s(h, test=test)
h = F.relu(h)
h = F.reshape(h, (z.data.shape[0], 64, 7, 7))
h = self.dc1s(h)
h = self.bn1s(h, test=test)
h = F.relu(h)
x = self.dc2s(h)
return x
class Discriminator(chainer.Chain):
def __init__(self):
super(Discriminator, self).__init__(
c0s=L.Convolution2D(3, 32, 4, stride=2, pad=1),
c1s=L.Convolution2D(32, 64, 4, stride=2, pad=1),
l2s=L.Linear(7 * 7 * 64, 2),
bn0s=L.BatchNormalization(32),
bn1s=L.BatchNormalization(64),
)
def __call__(self, x, test=False):
h = self.c0s(x)
h = F.leaky_relu(h, slope=0.2)
h = self.c1s(h)
h = self.bn1s(h, test=test)
h = F.leaky_relu(h, slope=0.2)
l = self.l2s(h)
return l
gen = Generator()
dis = Discriminator()
optimizer_gen = optimizers.Adam(alpha=0.0002, beta1=0.5)
optimizer_dis = optimizers.Adam(alpha=0.0002, beta1=0.5)
optimizer_gen.setup(gen)
optimizer_dis.setup(dis)
# training loop
for epoch in range(0, n_epoch):
sum_loss_dis = np.float32(0)
sum_loss_gen = np.float32(0)
for i in range(0, len(dataset), batchsize):
x2 = np.zeros((batchsize, 3, 28, 28), dtype=np.float32)
for j in range(batchsize):
# try:
rnd = np.random.randint(len(dataset))
img = np.asarray(Image.open(image_dir + '/'+ fs[rnd]).convert('RGB')).astype(np.float32)
img = img.transpose(2, 0, 1)
x2[j, :, :, :] = (img[:, :, :] - 128.0) / 128.0
z = Variable(np.random.uniform(-1, 1, (batchsize, nz2)).astype(np.float32))
x = gen(z)
yl = dis(x)
d_0 = Variable(np.zeros(batchsize, dtype=np.int32))
d_1 = Variable(np.ones(batchsize, dtype=np.int32))
loss_gen = F.softmax_cross_entropy(yl, d_0)
loss_dis = F.softmax_cross_entropy(yl, d_1)
yl2 = dis(x2)
t2 = Variable(np.zeros(batchsize, dtype=np.int32))
loss_dis += F.softmax_cross_entropy(yl2, t2)
optimizer_gen.zero_grads()
loss_gen.backward()
optimizer_gen.update()
optimizer_dis.zero_grads()
loss_dis.backward()
optimizer_dis.update()
sum_loss_gen += loss_gen.data
sum_loss_dis += loss_dis.data
print('epoch end', epoch, sum_loss_gen / len(dataset), sum_loss_dis / len(dataset))
if epoch % 10 == 0:
#draw image
pylab.rcParams['figure.figsize'] = (16.0, 16.0)
pylab.clf()
vissize = 100
z = zvis
z[50:, :] = (np.random.uniform(-1, 1, (50, nz2)).astype(np.float32))
z = Variable(z)
x = gen(z, test=True)
x = x.data
for i_ in range(100):
tmp = ((np.vectorize(clip_img)(x[i_, :, :, :]) + 1) / 2).transpose(1, 2, 0)
pylab.subplot(10, 10, i_ + 1)
pylab.imshow(tmp)
pylab.axis('off')
pylab.savefig('%s/vis_%d.png' % (out_image_dir, epoch))
#save model
serializers.save_hdf5("%s/dcgan_modeloss_dis_%d.h5" % (out_model_dir, epoch), dis)
serializers.save_hdf5("%s/dcgan_modeloss_gen_%d.h5" % (out_model_dir, epoch), gen)
serializers.save_hdf5("%s/dcgan_state_dis_%d.h5" % (out_model_dir, epoch), optimizer_dis)
serializers.save_hdf5("%s/dcgan_state_gen_%d.h5" % (out_model_dir, epoch), optimizer_gen)
また画像の読み込み等は以下のコード。これは上記②のコードを一部変更している。
import os
import numpy
from PIL import Image
import six
import numpy as np
from io import BytesIO
import os
import pickle
import json
import numpy as np
import skimage.io as io
from chainer.dataset import dataset_mixin
# download `BASE` dataset from http://cmp.felk.cvut.cz/~tylecr1/facade/
class FacadeDataset(dataset_mixin.DatasetMixin):
def __init__(self, dataDir='./facade/base', data_range=(1,300)):
print("load dataset start")
print(" from: %s"%dataDir)
print(" range: [%d, %d)"%(data_range[0], data_range[1]))
self.dataDir = dataDir
self.dataset = []
for i in range(data_range[0],data_range[1]):
#print("file_name", dataDir+"/cmp_b%04d.jpg"%i)
img = Image.open(dataDir+"/cmp_b%04d.jpg"%i)
label = Image.open(dataDir+"/cmp_b%04d.png"%i)
w,h = img.size
#print("w , h = ", w, h)
#r = 286/min(w,h)
r = 31 / float(min(w, h))
# resize images so that min(w, h) == 286
#print("r = ", r)
#print("min(w,h) = ", min(w,h))
img = img.resize((int(r*w), int(r*h)), Image.BILINEAR)
#print("img.size = ", img.size)
label = label.resize((int(r*w), int(r*h)), Image.NEAREST)
#debug
img = np.asarray(img)
#print("type(img) = ", type(img))
#print("img.shape = ", img.shape)
img = np.asarray(img).astype("f").transpose(2,0,1)/128.0-1.0
label_ = np.asarray(label)-1 # [0, 12)
label = np.zeros((12, img.shape[1], img.shape[2])).astype("i")
for j in range(12):
label[j,:] = label_==j
#print("label.shape = ", label.shape)
#print("label[0][[10][10] = ", label[0][10][10])
self.dataset.append((img,label))
print("load dataset done")
def __len__(self):
return len(self.dataset)
# return (label, img)
def get_example(self, i, crop_width=28):
_,h,w = self.dataset[i][0].shape
x_l = np.random.randint(0,w-crop_width)
x_r = x_l+crop_width
y_l = np.random.randint(0,h-crop_width)
y_r = y_l+crop_width
# print("np.max(self.dataset[i][1][:,y_l:y_r,x_l:x_r]) = ", np.max(self.dataset[i][1][:,y_l:y_r,x_l:x_r]))
# print("np.max(self.dataset[i][0][:,y_l:y_r,x_l:x_r]) = ", np.max(self.dataset[i][0][:,y_l:y_r,x_l:x_r]))
# print("np.min(self.dataset[i][1][:,y_l:y_r,x_l:x_r]) = ", np.min(self.dataset[i][1][:,y_l:y_r,x_l:x_r]))
# print("np.min(self.dataset[i][0][:,y_l:y_r,x_l:x_r]) = ", np.min(self.dataset[i][0][:,y_l:y_r,x_l:x_r]))
return self.dataset[i][1][:,y_l:y_r,x_l:x_r], self.dataset[i][0][:,y_l:y_r,x_l:x_r]
現在諸事情で個人用GitHubが使えなくなっているが、復活したらそちらにもupします。