7
2

More than 3 years have passed since last update.

ruby-dnnとディープラーニングでリンゴをオレンジに変換してみた(Cycle-GAN)

Posted at

はじめに

今回はruby-dnnでCycle-GANを動かしてリンゴをオレンジに変換してみたいと思います。
RubyとCPUでCycle-GANという時点でもう無理ゲー感半端ないです😇

コード全文はhttps://github.com/unagiootoro/apple2orange-cycleganにあります。

使用ライブラリ/バージョンなど

Ruby ... v2.6.5
ruby-dnn ... v1.1.4
Numo::NArray ... v0.9.1.5
Numo::Linalg ... v0.1.4
rubyzip ... v2.2.0

Cycle-GANざっくり解説

Cycle-GANとは、二つのドメインの異なる画像間でそれぞれ相互に変換ができるようにするモデルです。
画像の変換というと、前回やったPix2pixも画像変換ですが、Pix2pixは入力画像に対する出力画像が一対一でなければならないのに対して、Cycle-GANでは、一対一のペアがなくても変換元と変換先の画像さえあれば変換できるのが特徴です。

Cycle-GANでは、「DCGAN-A」「DCGAN-B」「Cycle Consistency Loss」の3つによって変換が行われます。イメージとしては、以下の図のような感じです。
(画像はペイントで作った手抜きです🤡)
img1.png

DCGAN-Aでは、リンゴがオレンジになることを学習させます。
逆にDCGAN-Bでは、オレンジがリンゴになることを学習させます。
しかし、これだけだと、リンゴの画像をオレンジに変換したとき、元のリンゴの画像の形状を保つことができないので、Cycle Consistency Lossが必要になります。

Cycle Consistency Loss

Cycle Consistency Lossは、DCGAN-Aでリンゴをオレンジに変換した画像にDCGAN-Bを適用すると元のリンゴに戻ることを学習させるために使用します。同じように、DCGAN-Bでオレンジをリンゴに変換した画像にDCGAN-Aを適用すると元のオレンジに戻ることも学習させます。

これによって、元のリンゴの画像の形状を保ったまま必要な箇所だけオレンジに変換することができるようになります。

ruby-dnnでCycle-GAN

モデル定義

Cycle-GANのモデル定義です。ほぼ前回やったPix2pixのモデルを使いまわしています。 1

各クラスの役割は、以下の通りです。

Generator: 入力画像から変換先の画像を生成します。

Discriminator: Generatorが生成した画像を受け取り、それが本物か生成された画像かを判断できるように学習させるためのモデルです。

DCGAN: Discriminatorを騙せるような画像を生成できるようにGenerator-Aを学習させるためのモデルです。また、Generator-Aが生成した画像にGenerator-Bを適用したとき、元の画像に戻るように学習させます。

長いので折りたたんでいます
class Generator < Model
  def initialize(input_shape, base_num_filters)
    super()
    @input_shape = input_shape
    @cv1 = Conv2D.new(base_num_filters, 4, padding: true)
    @cv2 = Conv2D.new(base_num_filters, 4, strides: 2, padding: true)
    @cv3 = Conv2D.new(base_num_filters * 2, 4, padding: true)
    @cv4 = Conv2D.new(base_num_filters * 2, 4, strides: 2, padding: true)
    @cv5 = Conv2D.new(base_num_filters * 2, 4, padding: true)
    @cv6 = Conv2D.new(base_num_filters, 4, padding: true)
    @cv7 = Conv2D.new(base_num_filters, 4, padding: true)
    @cv8 = Conv2D.new(3, 4, padding: true)
    @cvt1 = Conv2DTranspose.new(base_num_filters * 2, 4, strides: 2, padding: true)
    @cvt2 = Conv2DTranspose.new(base_num_filters, 4, strides: 2, padding: true)
    @bn1 = BatchNormalization.new
    @bn2 = BatchNormalization.new
    @bn3 = BatchNormalization.new
    @bn4 = BatchNormalization.new
    @bn5 = BatchNormalization.new
    @bn6 = BatchNormalization.new
    @bn7 = BatchNormalization.new
    @bn8 = BatchNormalization.new
  end

  def forward(x)
    input = InputLayer.new(@input_shape).(x)
    x = @cv1.(input)
    x = @bn1.(x)
    h1 = LeakyReLU.(x, 0.2)

    x = @cv2.(h1)
    x = @bn2.(x)
    x = LeakyReLU.(x, 0.2)

    x = @cv3.(x)
    x = @bn3.(x)
    h2 = LeakyReLU.(x, 0.2)

    x = @cv4.(h2)
    x = @bn4.(x)
    x = LeakyReLU.(x, 0.2)

    x = @cv5.(x)
    x = @bn5.(x)
    x = LeakyReLU.(x, 0.2)

    x = @cvt1.(x)
    x = @bn6.(x)
    x = LeakyReLU.(x, 0.2)
    x = Concatenate.(x, h2, axis: 3)

    x = @cv6.(x)
    x = @bn7.(x)
    x = LeakyReLU.(x, 0.2)

    x = @cvt2.(x)
    x = @bn8.(x)
    x = LeakyReLU.(x, 0.2)
    x = Concatenate.(x, h1, axis: 3)

    x = @cv7.(x)
    x = LeakyReLU.(x, 0.2)

    x = @cv8.(x)
    x = Tanh.(x)
    x
  end
end

class Discriminator < Model
  def initialize(input_shape, base_num_filters)
    super()
    @input_shape = input_shape
    @cv1 = Conv2D.new(base_num_filters, 4, padding: true)
    @cv2 = Conv2D.new(base_num_filters, 4, strides: 2, padding: true)
    @cv3 = Conv2D.new(base_num_filters * 2, 4, padding: true)
    @cv4 = Conv2D.new(base_num_filters * 2, 4, strides: 2, padding: true)
    @d1 = Dense.new(1024)
    @d2 = Dense.new(1)
    @bn1 = BatchNormalization.new
    @bn2 = BatchNormalization.new
    @bn3 = BatchNormalization.new
    @bn4 = BatchNormalization.new
  end

  def forward(x)
    x = InputLayer.new(@input_shape).(x)
    x = @cv1.(x)
    x = @bn1.(x)
    x = LeakyReLU.(x, 0.2)

    x = @cv2.(x)
    x = @bn2.(x)
    x = LeakyReLU.(x, 0.2)

    x = @cv3.(x)
    x = @bn3.(x)
    x = LeakyReLU.(x, 0.2)

    x = @cv4.(x)
    x = @bn4.(x)
    x = LeakyReLU.(x, 0.2)

    x = Flatten.(x)
    x = @d1.(x)
    x = LeakyReLU.(x, 0.2)

    x = @d2.(x)
    x
  end

  # Discriminatorの学習を許可する。
  def enable_training
    trainable_layers.each do |layer|
      layer.trainable = true
    end
  end

  # Discriminatorの学習を禁止する。
  def disable_training
    trainable_layers.each do |layer|
      layer.trainable = false
    end
  end
end

class DCGAN < Model
  attr_reader :gen1
  attr_reader :gen2
  attr_reader :dis

  def initialize(gen1, gen2, dis)
    super()
    @gen1 = gen1
    @gen2 = gen2
    @dis = dis
  end

  def forward(input)
    images = @gen1.(input)
    @dis.disable_training
    # 変換した画像に対するDiscriminatorの出力結果。
    out = @dis.(images)
    # 変換した画像を元の画像に復元した画像。
    cycle_image = @gen2.(images)
    [cycle_image, out]
  end
end

# 学習したモデルを保存するためのモデル
class CycleGANModel < Model
  attr_accessor :dcgan_A
  attr_accessor :dcgan_B

  def initialize(dcgan_A, dcgan_B)
    super()
    @dcgan_A = dcgan_A
    @dcgan_B = dcgan_B
  end
end

モデルの作成

gen_A = Generator.new([64, 64, 3], 64)
gen_B = Generator.new([64, 64, 3], 64)
dis_A = Discriminator.new([64, 64, 3], 64)
dis_B = Discriminator.new([64, 64, 3], 64)
# リンゴからオレンジに変換するDCGAN-A
dcgan_A = DCGAN.new(gen_A, gen_B, dis_A)
# オレンジからリンゴに変換するDCGAN-B
dcgan_B = DCGAN.new(gen_B, gen_A, dis_B)
dis_A.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new)
dis_B.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new)
dcgan_A.setup(Adam.new(alpha: 0.0001, beta1: 0.5),
              [MeanAbsoluteError.new, SigmoidCrossEntropy.new], loss_weights: [10, 1])
dcgan_B.setup(Adam.new(alpha: 0.0001, beta1: 0.5),
              [MeanAbsoluteError.new, SigmoidCrossEntropy.new], loss_weights: [10, 1])
cycle_gan_model = CycleGANModel.new(dcgan_A, dcgan_B)

イテレータ

Cycle-GANでは、変換元画像と変換先画像がペアになっている必要はありません。
むしろペアになっていることで、過学習してしまう可能性があります。
そのため、ペアにならない変換元画像と変換先画像のミニバッチを返すイテレータを作成します。

class DNN::CycleGANIterator < DNN::Iterator
  def initialize(x_datas, y_datas, random: true, last_round_down: false)
    @x_datas = x_datas
    @y_datas = y_datas
    @random = random
    @last_round_down = last_round_down
    num_datas1 = x_datas.is_a?(Array) ? x_datas[0].shape[0] : x_datas.shape[0]
    num_datas2 = y_datas.is_a?(Array) ? y_datas[0].shape[0] : y_datas.shape[0]
    if num_datas1 < num_datas2
      @num_datas = num_datas1
    else
      @num_datas = num_datas2
    end
    reset
  end

  def next_batch(batch_size)
    raise DNN::DNNError, "This iterator has not next batch. Please call reset." unless has_next?
    if @indexes1.length <= batch_size
      batch_indexes1 = @indexes1
      batch_indexes2 = @indexes2
      @has_next = false
    else
      batch_indexes1 = @indexes1.shift(batch_size)
      batch_indexes2 = @indexes2.shift(batch_size)
    end
    x_batch, _ = get_batch(batch_indexes1)
    _, y_batch = get_batch(batch_indexes2)
    # ランダムにサンプリングした変換元画像と変換先画像を返す。
    [x_batch, y_batch]
  end

  def reset
    @has_next = true
    @indexes1 = @num_datas.times.to_a
    @indexes2 = @num_datas.times.to_a
    if @random
      @indexes1.shuffle!
      @indexes2.shuffle!
    end
  end
end

学習部分

モデルの学習を行う部分です。学習させている内容をざっくりとまとめるとこんな感じです。
リンゴ => オレンジの変換の学習
①Discriminator-Aの学習: Generator-Aが出力したオレンジの画像と本物のオレンジの画像を見分けられるように学習する。
②Generator-Aの学習: Discriminator-Aを騙せるようなオレンジの画像を出力できるように学習する。
③Generator-AとGenerator-Bの学習: Generato-Aが出力したオレンジの画像をGenerator-Bに渡したとき、出力される画像が、元のリンゴの画像に戻ることを学習させる。
※オレンジ => リンゴの場合についても、上記と同様に学習させる。

iter1 = DNN::CycleGANIterator.new(x, y)
iter2 = DNN::CycleGANIterator.new(x, y)
iter3 = DNN::CycleGANIterator.new(x, y)
iter4 = DNN::CycleGANIterator.new(x, y)
num_batchs = iter1.num_datas / batch_size
real = Numo::SFloat.ones(batch_size, 1)
fake = Numo::SFloat.zeros(batch_size, 1)

(initial_epoch..epochs).each do |epoch|
  num_batchs.times do |index|
    x_batch, y_batch = iter1.next_batch(batch_size)
    x_batch2, y_batch2 = iter2.next_batch(batch_size)
    x_batch3, y_batch3 = iter3.next_batch(batch_size)
    x_batch4, y_batch4 = iter4.next_batch(batch_size)

    # DCGAN-Aの学習
    images_A = gen_A.predict(x_batch)
    dis_A.enable_training
    dis_loss = dis_A.train_on_batch(y_batch, real)
    dis_loss += dis_A.train_on_batch(images_A, fake)

    dcgan_loss = dcgan_A.train_on_batch(x_batch2, [x_batch2, real])

    puts "A epoch: #{epoch}, index: #{index}, dis_loss: #{dis_loss}, dcgan_loss: #{dcgan_loss}"

    # DCGAN-Bの学習
    images_B = gen_B.predict(y_batch3)
    dis_B.enable_training
    dis_loss = dis_B.train_on_batch(x_batch3, real)
    dis_loss += dis_B.train_on_batch(images_B, fake)

    dcgan_loss = dcgan_B.train_on_batch(y_batch4, [y_batch4, real])

    puts "B epoch: #{epoch}, index: #{index}, dis_loss: #{dis_loss}, dcgan_loss: #{dcgan_loss}"
  end
  if epoch % 5 == 0
    cycle_gan_model.save("trained/cycle_gan_model_epoch#{epoch}.marshal")
  end
  iter1.reset
  iter2.reset
  iter3.reset
  iter4.reset
end

モデルの実行

今回も学習済みモデルのGeneratorを用意しています。
Githubからリポジトリをクローンして、「imgen.rb」を実行することで、画像を生成することができます。

$ ruby imgen.rb

実行結果

生成した画像のうち、比較的うまくいったものを貼っておきます。
inputが入力した画像で、outputが変換した画像になります。
上2行が、リンゴをオレンジに変換したもので、下2行は、オレンジをリンゴに変換したものです。

オレンジ色のリンゴと赤色のオレンジができただけなので、Cycle-GANに成功したとは言えそうにないですが、CPUでやったにしてはうまくいったんじゃないでしょうか。
cycle-gan.PNG

おわりに

Pix2pixが動いたので、もしかしたらCycle-GANもいけるんじゃね?っていう軽い気持ちでやってみましたが、思ってたよりはうまく動いてくれました。
いつかは、ruby-dnnをGPUに対応させて、ウマをシマウマに変換させられるようになりたいですね!


  1. Pix2pixモデルを使いまわしているので、InstanceNormalizationを使っていないなど、実際のCycle-GANとはいくつか異なる点があります。 

7
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
2