LoginSignup
5

More than 5 years have passed since last update.

pix2pixモデルCPU版を作ってみる

Last updated at Posted at 2017-09-01

はじめに

昨年の12月に発表されて話題をさらったpix2pixモデルだが、既にgithub上に幾つかの実装コードが存在する。
しかしCPUのみで動くようなコードが無いので作ってみた。

環境

OS:Ubuntu14.04
GPU:GTX1070(今回は使わない)
CUDA:8.0 RC(今回は使わない)
cuDNN:5.1(今回は使わない)
python:2.7.6
chainer:1.20.0.1
など

参考文献等

まずpix2pixの原論文は以下。
https://arxiv.org/abs/1611.07004
①今回参考にしたコードは2つ。まずmattyaさんのDCGANsコード。
https://github.com/mattya/chainer-DCGAN
②次にpreferdのpix2pixコード。
https://github.com/pfnet-research/chainer-pix2pix
こちらの画像取得部分のコード(facade_dataset.py)はほぼそのまま使わせてもらってます。

原論文に対する変更点

原論文に対する主な変更点は以下。
1. generator及びdiscriminatorともに層の数、パラメータ等を大幅に削減
2. 入出力画像のサイズは28x28に圧縮
3. L2ノルムなどの正則化は使わない
4. dropoutも使わない
(3,4に関してはbatch normalizationを使っているので、ある程度代替できてるか?)

モデルの概略

Generatorは以下のように変更した。
generator_pix2pix.png
Discriminatorは以下のように変更した。
discriminator_pix2pix.png

コード

コードは以下

train_pix2pix_cpu.py
import numpy as np
from PIL import Image
import os
import pylab

import chainer
from chainer import optimizers
from chainer import serializers
from chainer import Variable
import chainer.functions as F
import chainer.links as L

image_dir = './images_plane_s' #input data file
out_image_dir = './out_images_miniplane' #output image file
out_model_dir = './out_models_miniplane' #output model file

nz2 = 10 #dimention of noise
batchsize = 100
n_epoch = 200

# read all images
fs = os.listdir(image_dir)
print(len(fs))
dataset = []
for fn in fs:
    f = open('%s/%s' % (image_dir, fn), 'rb')
    img_bin = f.read()
    dataset.append(img_bin)
    f.close()
print(len(dataset))

#function for output image
def clip_img(x):
    return np.float32(-1 if x < -1 else (1 if x > 1 else x))
zvis = (np.random.uniform(-1, 1, (100, nz2)).astype(np.float32))
try:
    os.mkdir(out_image_dir)
    os.mkdir(out_model_dir)
except:
    pass


class Generator(chainer.Chain):
    def __init__(self):
        super(Generator, self).__init__(
            l0s=L.Linear(nz2, 7 * 7 * 64),
            dc1s=L.Deconvolution2D(64, 32, 4, stride=2, pad=1),
            dc2s=L.Deconvolution2D(32, 3, 4, stride=2, pad=1),
            bn0s=L.BatchNormalization(7 * 7 * 64),
            bn1s=L.BatchNormalization(32),
        )

    def __call__(self, z, test=False):
        h = self.l0s(z)
        h = self.bn0s(h, test=test)
        h = F.relu(h)
        h = F.reshape(h, (z.data.shape[0], 64, 7, 7))
        h = self.dc1s(h)
        h = self.bn1s(h, test=test)
        h = F.relu(h)
        x = self.dc2s(h)
        return x


class Discriminator(chainer.Chain):
    def __init__(self):
        super(Discriminator, self).__init__(
            c0s=L.Convolution2D(3, 32, 4, stride=2, pad=1),
            c1s=L.Convolution2D(32, 64, 4, stride=2, pad=1),
            l2s=L.Linear(7 * 7 * 64, 2),
            bn0s=L.BatchNormalization(32),
            bn1s=L.BatchNormalization(64),
        )

    def __call__(self, x, test=False):
        h  = self.c0s(x)
        h = F.leaky_relu(h, slope=0.2)
        h = self.c1s(h)
        h = self.bn1s(h, test=test)
        h = F.leaky_relu(h, slope=0.2)
        l = self.l2s(h)
        return l


gen = Generator()
dis = Discriminator()

optimizer_gen = optimizers.Adam(alpha=0.0002, beta1=0.5)
optimizer_dis = optimizers.Adam(alpha=0.0002, beta1=0.5)
optimizer_gen.setup(gen)
optimizer_dis.setup(dis)


#training loop
for epoch in range(0, n_epoch):
    sum_loss_dis = np.float32(0)
    sum_loss_gen = np.float32(0)

    for i in range(0, len(dataset), batchsize):

        x2 = np.zeros((batchsize, 3, 28, 28), dtype=np.float32)
        for j in range(batchsize):
            # try:
            rnd = np.random.randint(len(dataset))
            img = np.asarray(Image.open(image_dir + '/'+ fs[rnd]).convert('RGB')).astype(np.float32)
            img = img.transpose(2, 0, 1)
            x2[j, :, :, :] = (img[:, :, :] - 128.0) / 128.0

        z = Variable(np.random.uniform(-1, 1, (batchsize, nz2)).astype(np.float32))
        x = gen(z)
        yl = dis(x)
        d_0 = Variable(np.zeros(batchsize, dtype=np.int32))
        d_1 = Variable(np.ones(batchsize, dtype=np.int32))
        loss_gen = F.softmax_cross_entropy(yl, d_0)
        loss_dis = F.softmax_cross_entropy(yl, d_1)

        yl2 = dis(x2)
        t2 = Variable(np.zeros(batchsize, dtype=np.int32))
        loss_dis += F.softmax_cross_entropy(yl2, t2)

        optimizer_gen.zero_grads()
        loss_gen.backward()
        optimizer_gen.update()

        optimizer_dis.zero_grads()
        loss_dis.backward()
        optimizer_dis.update()

        sum_loss_gen += loss_gen.data
        sum_loss_dis += loss_dis.data

    print('epoch end', epoch, sum_loss_gen / len(dataset), sum_loss_dis / len(dataset))

    if epoch % 10 == 0:
        #draw image
        pylab.rcParams['figure.figsize'] = (16.0, 16.0)
        pylab.clf()
        vissize = 100
        z = zvis
        z[50:, :] = (np.random.uniform(-1, 1, (50, nz2)).astype(np.float32))
        z = Variable(z)
        x = gen(z, test=True)
        x = x.data
        for i_ in range(100):
            tmp = ((np.vectorize(clip_img)(x[i_, :, :, :]) + 1) / 2).transpose(1, 2, 0)
            pylab.subplot(10, 10, i_ + 1)
            pylab.imshow(tmp)
            pylab.axis('off')
        pylab.savefig('%s/vis_%d.png' % (out_image_dir, epoch))

        #save model
        serializers.save_hdf5("%s/dcgan_modeloss_dis_%d.h5" % (out_model_dir, epoch), dis)
        serializers.save_hdf5("%s/dcgan_modeloss_gen_%d.h5" % (out_model_dir, epoch), gen)
        serializers.save_hdf5("%s/dcgan_state_dis_%d.h5" % (out_model_dir, epoch), optimizer_dis)
        serializers.save_hdf5("%s/dcgan_state_gen_%d.h5" % (out_model_dir, epoch), optimizer_gen)

また画像の読み込み等は以下のコード。これは上記②のコードを一部変更している。

facade_dataset2.py
import os

import numpy
from PIL import Image
import six

import numpy as np

from io import BytesIO
import os
import pickle
import json
import numpy as np

import skimage.io as io

from chainer.dataset import dataset_mixin

# download `BASE` dataset from http://cmp.felk.cvut.cz/~tylecr1/facade/
class FacadeDataset(dataset_mixin.DatasetMixin):
    def __init__(self, dataDir='./facade/base', data_range=(1,300)):
        print("load dataset start")
        print("    from: %s"%dataDir)
        print("    range: [%d, %d)"%(data_range[0], data_range[1]))
        self.dataDir = dataDir
        self.dataset = []
        for i in range(data_range[0],data_range[1]):
            #print("file_name", dataDir+"/cmp_b%04d.jpg"%i)
            img = Image.open(dataDir+"/cmp_b%04d.jpg"%i)
            label = Image.open(dataDir+"/cmp_b%04d.png"%i)
            w,h = img.size
            #print("w , h = ", w, h)
            #r = 286/min(w,h)
            r = 31 / float(min(w, h))
            # resize images so that min(w, h) == 286
            #print("r = ", r)
            #print("min(w,h) = ", min(w,h))
            img = img.resize((int(r*w), int(r*h)), Image.BILINEAR)
            #print("img.size = ", img.size)
            label = label.resize((int(r*w), int(r*h)), Image.NEAREST)
            #debug
            img = np.asarray(img)
            #print("type(img) = ", type(img))

            #print("img.shape = ", img.shape)

            img = np.asarray(img).astype("f").transpose(2,0,1)/128.0-1.0
            label_ = np.asarray(label)-1  # [0, 12)
            label = np.zeros((12, img.shape[1], img.shape[2])).astype("i")
            for j in range(12):
                label[j,:] = label_==j
            #print("label.shape = ", label.shape)
            #print("label[0][[10][10] = ", label[0][10][10])
            self.dataset.append((img,label))
        print("load dataset done")

    def __len__(self):
        return len(self.dataset)

    # return (label, img)
    def get_example(self, i, crop_width=28):
        _,h,w = self.dataset[i][0].shape
        x_l = np.random.randint(0,w-crop_width)
        x_r = x_l+crop_width
        y_l = np.random.randint(0,h-crop_width)
        y_r = y_l+crop_width
        # print("np.max(self.dataset[i][1][:,y_l:y_r,x_l:x_r]) = ", np.max(self.dataset[i][1][:,y_l:y_r,x_l:x_r]))
        # print("np.max(self.dataset[i][0][:,y_l:y_r,x_l:x_r]) = ", np.max(self.dataset[i][0][:,y_l:y_r,x_l:x_r]))
        # print("np.min(self.dataset[i][1][:,y_l:y_r,x_l:x_r]) = ", np.min(self.dataset[i][1][:,y_l:y_r,x_l:x_r]))
        # print("np.min(self.dataset[i][0][:,y_l:y_r,x_l:x_r]) = ", np.min(self.dataset[i][0][:,y_l:y_r,x_l:x_r]))

        return self.dataset[i][1][:,y_l:y_r,x_l:x_r], self.dataset[i][0][:,y_l:y_r,x_l:x_r]

現在諸事情で個人用GitHubが使えなくなっているが、復活したらそちらにもupします。

結果

学習前の出力画像と500回学習した場合の出力画像は以下。
vis_0.png
学習前(左からイラスト、本物の画像、生成された画像)

vis_500.png
500回学習後(左からイラスト、本物の画像、生成された画像)

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5