21
34

More than 5 years have passed since last update.

Chainerでモノクロ画像のカラー化を学習してみる

Posted at

概要

  • Chainerでモノクロ画像のカラー化の学習を簡単に実装してみます。
  • 最近のChainerは抽象的に書けるようになってるので、Trainer等を使って実装します。
  • 全結合NNと畳込みNN、及び活性化関数の違いによる仕上がりの違いを比較してみます。

データセット

データセットにはCIFAR-100を利用します。Chainerには簡単に利用できるデータセットが幾つか用意されており、MNISTやCIFAR-100といったものが使えます。今回は手っ取り早く沢山のカラー画像が用意できるCIFAR-100にしましたが、本来はもっとより多くの種類の画像を含むデータセットが適切かと思います。

ChainerでのCIFAR-100の利用は以下のように実装します。

train, test = chainer.datasets.get_cifar100(withlabel=False)

Chainerの教師あり学習のデータセットは(学習データ, 教師データ)の配列をiteratorに渡して実装します。モノクロ画像のカラー化を学習するにあたって必要なのは(モノクロ画像, カラー画像)の配列です。今回利用するchainer.datasets.get_cifar100()からは(画像, ラベル)というタプルの配列(っぽいもの)が得られるのですが、ラベルは不要なのでwithlabelをFalseにし、画像の配列だけを得ます。

次に、得られたカラー画像のデータセットを(モノクロ画像, カラー画像)に改造します。Chainerでデータセットを自作する際はchainer.datasets.TupleDataset(学習データ配列, 教師データ配列)で実装することが多いと思います。が、カラー画像からのモノクロ画像生成は簡単にできるので、今回はchainer.dataset.DatasetMixinを継承したクラスを使い、ミニバッチを作成する直前でモノクロ画像を用意するやり方で実装します。こちらのサンプルが参考になります。画像のクロップやノイズをのせるといった加工もこのやり方で行うのが楽かと思われます。

class PreprocessedDataset(chainer.dataset.DatasetMixin):
    def __init__(self, base_image_dataset):
        self.base = base_image_dataset

    def __len__(self):
        return len(self.base)

    def get_example(self, i):
        color_image = self.base[i]
        gray_image = np.ndarray((32, 32), dtype=np.float32)
        for ch in range(3):
            gray_image = (
                0.298912*color_image[0]
                + 0.586611*color_image[1]
                + 0.114478*color_image[2]
            )
        return gray_image, color_image

ネットワーク構成

今回は2種類のNNを試してみます。一つは単純に全結合層を繋げたもの、もう一つはConvolution層の後ろにDeconvolution層を繋げたものです。

class AIC_FC(chainer.Chain):
    def __init__(self, n_units):
        initializer = chainer.initializers.HeNormal()
        super(AIC_FC, self).__init__(
            fc_in = L.Linear(None, n_units),
            bn1 = L.BatchNormalization(n_units),
            fc2 = L.Linear(None, n_units),
            bn2 = L.BatchNormalization(n_units),
            fc_out = L.Linear(None, 32*32*3)
        )

    def __call__(self, x, t):
        y = self.colorize(x)
        loss = F.mean_squared_error(y, t)
        chainer.reporter.report({
            'loss': loss
        })
        return loss

    def colorize(self, x, test=False):
        h = F.elu(self.bn1(self.fc_in(x), test=test))
        h = F.elu(self.bn2(self.fc2(h), test=test))
        y = F.reshape(self.fc_out(h), (h.shape[0], 3, 32, 32))
        return y

class AIC_DC(chainer.Chain):
    def __init__(self, n_ch):
        initializer = chainer.initializers.HeNormal()
        super(AIC_DC, self).__init__(
            cv_in = L.Convolution2D(1, n_ch//4, 4, 2, 1),
            bn1 = L.BatchNormalization(n_ch//4),
            cv1 = L.Convolution2D(n_ch//4, n_ch//2, 4, 2, 1),
            bn2 = L.BatchNormalization(n_ch//2),
            cv2 = L.Convolution2D(n_ch//2, n_ch, 4, 2, 1),
            bn3 = L.BatchNormalization(n_ch),
            cv3 = L.Convolution2D(n_ch, n_ch, 4, 2, 1),
            bn4 = L.BatchNormalization(n_ch),
            dc1 = L.Deconvolution2D(n_ch, n_ch, 4, 2, 1),
            bn5 = L.BatchNormalization(n_ch),
            dc2 = L.Deconvolution2D(n_ch, n_ch//2, 4, 2, 1),
            bn6 = L.BatchNormalization(n_ch//2),
            dc3 = L.Deconvolution2D(n_ch//2, n_ch//4, 4, 2, 1),
            bn7 = L.BatchNormalization(n_ch//4),
            dc_out = L.Deconvolution2D(n_ch//4, 3, 4, 2, 1, outsize=(32, 32))
        )

    def __call__(self, x, t):
        y = self.colorize(x)
        loss = F.mean_squared_error(y, t)
        chainer.reporter.report({
            'loss': loss
        })
        return loss

    def colorize(self, x, test=False):
        h = F.reshape(x, (x.shape[0], 1, 32, 32))
        h = F.elu(self.bn1(self.cv_in(h), test=test))
        h = F.elu(self.bn2(self.cv1(h), test=test))
        h = F.elu(self.bn3(self.cv2(h), test=test))
        h = F.elu(self.bn4(self.cv3(h), test=test))
        h = F.elu(self.bn5(self.dc1(h), test=test))
        h = F.elu(self.bn6(self.dc2(h), test=test))
        h = F.elu(self.bn7(self.dc3(h), test=test))
        y = self.dc_out(h)
        return y

全結合NNを深くしてみたところ、待てど暮らせどぼやけた画像しか生成されない(収束が遅い?)ので、浅めにしました。

ちなみに実際のカラー化NNはもっと複雑な設計になっているようです。(参考

学習の実装

Chainerにおける学習の実装の流れは

  1. モデルの作成
  2. Optimizerの設定
  3. Datasetの用意
  4. DatasetからIteratorを作成
  5. Updaterを設定
  6. Trainerを設定

となります。OptimizerはAdamの他、SGDやMomentumSGDのような基本的なものも用意されており、UpdaterもGANのような複雑な損失計算を行わない問題についてはStandardUpdaterで十分かと思われます。全部自力で書いてた時代に比べると、デフォルトでいい感じのものが用意されているので大分楽できます。改造も必要な部分だけで済むので、他の人が見たときの分かりやすさが改善しており、有り難みを感じます。

モデルのテストを実装

今回のような問題はlossの値だけでは学習の様子が分かりにくいので視覚化したくなります。そこでTrainer Extensionでモデルのテストを実装します。Extensionはchainer.training.extention.Extension等を継承するか、chainer.training.make_extension()を利用して作成します。モデルのテストと画像の保存をchainer.training.make_extension()で以下の通り実装します。(scipy.miscのimsaveをimportしています。)

訓練データとは別にテスト用画像を用意しています。

@chainer.training.make_extension(trigger=(1, 'epoch'))
def test_model(trainer):
    colorized_img = chainer.cuda.to_cpu(F.clipped_relu(model.colorize(test_img, test=True), z=1.0).data)
    imsave(
        'test_colorized{}.png'.format(trainer.updater.epoch),
        colorized_img
        .transpose(0, 2, 3, 1)
        .reshape((8, 8, 32, 32, 3))
        .transpose(1, 2, 0, 3, 4)
        .reshape(8*32, 8*32, 3)
    )
trainer.extend(test_model)

学習結果

学習の結果を紹介します。今回は全結合NNと畳み込みNNの2つを試したので、30 epoch学習させた時点での学習の様子と合わせて比較していきます。全結合NNのn_unitsは2048、畳込みNNのn_chは512で学習しました。

カラー画像

test.png

モノクロ画像

test_gray.png

30 epoch後

全結合NN
30epoch3L2048units.png

畳み込みNN
30epochELU512ch.png

全結合NNは全体的にざらついた感じの画像になりやすいようです。単純なNNでも結構色が乗っていて意外でしたが、畳み込みNNのほうがより鮮やかで綺麗な画像が出力できていると感じます。

一つ一つの画像を見ていくと、空や海はうまく認識して綺麗に色を載せてくれる傾向にありそうです。ですが青空と夕焼けの区別は難しいみたいですね。小動物が写っていると思しき一番右下の画像については、地面を認識して茶色や草の色を再現しています。が、正解と比べるとちょっと草生やしすぎですね。

活性化関数の違いによる変化

上に挙げた画像は活性化関数にeluを使用した際のものです。活性化関数にreluやleaky_reluを使用すると、仕上がりにどのような変化があるかも調べてみました。n_chが512の畳み込みNN, 30 epochでの結果のみ紹介します。

relu
30epochRelu512ch.png

leaky_relu
30epoch512ch.png

elu
30epochELU512ch.png

reluは鮮やかに仕上がっていますが若干ぼやけた印象です。eluやleaky reluは画像の鮮明さでreluより優秀そうです。leaky reluはeluとreluの中間くらいの印象を感じます。

使用したコードの全体

#! /usr/bin/env python
# coding : utf-8

import argparse
import numpy as np
from scipy.misc import imsave
import chainer
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions 


class PreprocessedDataset(chainer.dataset.DatasetMixin):
    def __init__(self, base_image_dataset):
        self.base = base_image_dataset

    def __len__(self):
        return len(self.base)

    def get_example(self, i):
        color_image = self.base[i]
        gray_image = np.ndarray((32, 32), dtype=np.float32)
        for ch in range(3):
            # 輝度を計算し、モノクロ画像を作成
            gray_image = (
                0.298912*color_image[0]
                + 0.586611*color_image[1]
                + 0.114478*color_image[2]
            )
        return gray_image, color_image

class AIC_FC(chainer.Chain):
    def __init__(self, n_units):
        initializer = chainer.initializers.HeNormal()
        super(AIC_FC, self).__init__(
            fc_in = L.Linear(None, n_units),
            bn1 = L.BatchNormalization(n_units),
            fc2 = L.Linear(None, n_units),
            bn2 = L.BatchNormalization(n_units),
            fc_out = L.Linear(None, 32*32*3)
        )

    def __call__(self, x, t):
        y = self.colorize(x)
        loss = F.mean_squared_error(y, t)
        chainer.reporter.report({
            'loss': loss
        })
        return loss

    def colorize(self, x, test=False):
        h = F.elu(self.bn1(self.fc_in(x), test=test))
        h = F.elu(self.bn2(self.fc2(h), test=test))
        y = F.reshape(self.fc_out(h), (h.shape[0], 3, 32, 32))
        return y

class AIC_DC(chainer.Chain):
    def __init__(self, n_ch):
        initializer = chainer.initializers.HeNormal()
        super(AIC_DC, self).__init__(
            cv_in = L.Convolution2D(1, n_ch//4, 4, 2, 1),
            bn1 = L.BatchNormalization(n_ch//4),
            cv1 = L.Convolution2D(n_ch//4, n_ch//2, 4, 2, 1),
            bn2 = L.BatchNormalization(n_ch//2),
            cv2 = L.Convolution2D(n_ch//2, n_ch, 4, 2, 1),
            bn3 = L.BatchNormalization(n_ch),
            cv3 = L.Convolution2D(n_ch, n_ch, 4, 2, 1),
            bn4 = L.BatchNormalization(n_ch),
            dc1 = L.Deconvolution2D(n_ch, n_ch, 4, 2, 1),
            bn5 = L.BatchNormalization(n_ch),
            dc2 = L.Deconvolution2D(n_ch, n_ch//2, 4, 2, 1),
            bn6 = L.BatchNormalization(n_ch//2),
            dc3 = L.Deconvolution2D(n_ch//2, n_ch//4, 4, 2, 1),
            bn7 = L.BatchNormalization(n_ch//4),
            dc_out = L.Deconvolution2D(n_ch//4, 3, 4, 2, 1, outsize=(32, 32))
        )

    def __call__(self, x, t):
        y = self.colorize(x)
        loss = F.mean_squared_error(y, t)
        chainer.reporter.report({
            'loss': loss
        })
        return loss

    def colorize(self, x, test=False):
        # Convolution層に入力するため、ndimが4になるようにreshape
        h = F.reshape(x, (x.shape[0], 1, 32, 32))
        h = F.elu(self.bn1(self.cv_in(h), test=test))
        h = F.elu(self.bn2(self.cv1(h), test=test))
        h = F.elu(self.bn3(self.cv2(h), test=test))
        h = F.elu(self.bn4(self.cv3(h), test=test))
        h = F.elu(self.bn5(self.dc1(h), test=test))
        h = F.elu(self.bn6(self.dc2(h), test=test))
        h = F.elu(self.bn7(self.dc3(h), test=test))
        y = self.dc_out(h)
        return y


def main():
    parser = argparse.ArgumentParser(description='Automatic Image Colorization')
    parser.add_argument('--batchsize', '-b', type=int, default=64,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=30,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--n_ch', '-nc', type=int, default=1024,
                        help='Number of channels')
    parser.add_argument('--n_units', '-nu', type=int, default=0,
                        help='Number of units')
    args = parser.parse_args()
    print('# GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))

    if args.n_units > 0:
        print('# n_units: {}\n'.format(args.n_units))
        model = AIC_FC(args.n_units)
    else:
        print('# n_ch: {}\n'.format(args.n_ch))
        model = AIC_DC(args.n_ch)
    if args.gpu >= 0:
        chainer.cuda.get_device().use()
        model.to_gpu()

    opt = chainer.optimizers.Adam()
    opt.setup(model)

    train, test = chainer.datasets.get_cifar100(withlabel=False)
    test_img = (
        0.298912*test[:64,0]
        + 0.586611*test[:64,1]
        + 0.114478*test[:64,2]
    )
    # 64枚の画像を8x8に並んだ一枚の画像として保存する
    imsave(
        'test.png',
        test[:64]
        .transpose(0, 2, 3, 1)
        .reshape((8, 8, 32, 32, 3))
        .transpose(1, 2, 0, 3, 4)
        .reshape(8*32, 8*32, 3)
    )
    imsave(
        'test_gray.png',
        test_img
        .reshape((8, 8, 32, 32))
        .transpose(1, 2, 0, 3)
        .reshape(8*32, 8*32)
    )
    if args.gpu >= 0:
        test_img = chainer.cuda.to_gpu(test_img)


    dataset = PreprocessedDataset(train)
    iterator = chainer.iterators.MultiprocessIterator(dataset, args.batchsize)

    updater = chainer.training.StandardUpdater(iterator, opt, device=args.gpu)
    trainer = chainer.training.Trainer(updater, (args.epoch, 'epoch'))

    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport([
        'epoch', 'loss', 'elapsed_time'
    ]))
    @chainer.training.make_extension(trigger=(10, 'epoch'))
    def test_model(trainer):
        # 値域を0~1にするため、clipped_reluを通す
        colorized_img = chainer.cuda.to_cpu(F.clipped_relu(model.colorize(test_img, test=True), z=1.0).data)
        imsave(
            'test_colorized{}.png'.format(trainer.updater.epoch),
            colorized_img
            .transpose(0, 2, 3, 1)
            .reshape((8, 8, 32, 32, 3))
            .transpose(1, 2, 0, 3, 4)
            .reshape(8*32, 8*32, 3)
        )
    trainer.extend(test_model)
    trainer.extend(extensions.ProgressBar(update_interval=100))

    trainer.run()

if __name__ == '__main__':
    main()
21
34
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
21
34