はじめに
今回はruby-dnnでCycle-GANを動かしてリンゴをオレンジに変換してみたいと思います。
RubyとCPUでCycle-GANという時点でもう無理ゲー感半端ないです😇
コード全文はhttps://github.com/unagiootoro/apple2orange-cycleganにあります。
使用ライブラリ/バージョンなど
Ruby ... v2.6.5
ruby-dnn ... v1.1.4
Numo::NArray ... v0.9.1.5
Numo::Linalg ... v0.1.4
rubyzip ... v2.2.0
Cycle-GANざっくり解説
Cycle-GANとは、二つのドメインの異なる画像間でそれぞれ相互に変換ができるようにするモデルです。
画像の変換というと、前回やったPix2pixも画像変換ですが、Pix2pixは入力画像に対する出力画像が一対一でなければならないのに対して、Cycle-GANでは、一対一のペアがなくても変換元と変換先の画像さえあれば変換できるのが特徴です。
Cycle-GANでは、「DCGAN-A」「DCGAN-B」「Cycle Consistency Loss」の3つによって変換が行われます。イメージとしては、以下の図のような感じです。
(画像はペイントで作った手抜きです🤡)
DCGAN-Aでは、リンゴがオレンジになることを学習させます。
逆にDCGAN-Bでは、オレンジがリンゴになることを学習させます。
しかし、これだけだと、リンゴの画像をオレンジに変換したとき、元のリンゴの画像の形状を保つことができないので、Cycle Consistency Lossが必要になります。
Cycle Consistency Loss
Cycle Consistency Lossは、DCGAN-Aでリンゴをオレンジに変換した画像にDCGAN-Bを適用すると元のリンゴに戻ることを学習させるために使用します。同じように、DCGAN-Bでオレンジをリンゴに変換した画像にDCGAN-Aを適用すると元のオレンジに戻ることも学習させます。
これによって、元のリンゴの画像の形状を保ったまま必要な箇所だけオレンジに変換することができるようになります。
ruby-dnnでCycle-GAN
モデル定義
Cycle-GANのモデル定義です。ほぼ前回やったPix2pixのモデルを使いまわしています。 1
各クラスの役割は、以下の通りです。
Generator: 入力画像から変換先の画像を生成します。
Discriminator: Generatorが生成した画像を受け取り、それが本物か生成された画像かを判断できるように学習させるためのモデルです。
DCGAN: Discriminatorを騙せるような画像を生成できるようにGenerator-Aを学習させるためのモデルです。また、Generator-Aが生成した画像にGenerator-Bを適用したとき、元の画像に戻るように学習させます。
長いので折りたたんでいます
class Generator < Model
def initialize(input_shape, base_num_filters)
super()
@input_shape = input_shape
@cv1 = Conv2D.new(base_num_filters, 4, padding: true)
@cv2 = Conv2D.new(base_num_filters, 4, strides: 2, padding: true)
@cv3 = Conv2D.new(base_num_filters * 2, 4, padding: true)
@cv4 = Conv2D.new(base_num_filters * 2, 4, strides: 2, padding: true)
@cv5 = Conv2D.new(base_num_filters * 2, 4, padding: true)
@cv6 = Conv2D.new(base_num_filters, 4, padding: true)
@cv7 = Conv2D.new(base_num_filters, 4, padding: true)
@cv8 = Conv2D.new(3, 4, padding: true)
@cvt1 = Conv2DTranspose.new(base_num_filters * 2, 4, strides: 2, padding: true)
@cvt2 = Conv2DTranspose.new(base_num_filters, 4, strides: 2, padding: true)
@bn1 = BatchNormalization.new
@bn2 = BatchNormalization.new
@bn3 = BatchNormalization.new
@bn4 = BatchNormalization.new
@bn5 = BatchNormalization.new
@bn6 = BatchNormalization.new
@bn7 = BatchNormalization.new
@bn8 = BatchNormalization.new
end
def forward(x)
input = InputLayer.new(@input_shape).(x)
x = @cv1.(input)
x = @bn1.(x)
h1 = LeakyReLU.(x, 0.2)
x = @cv2.(h1)
x = @bn2.(x)
x = LeakyReLU.(x, 0.2)
x = @cv3.(x)
x = @bn3.(x)
h2 = LeakyReLU.(x, 0.2)
x = @cv4.(h2)
x = @bn4.(x)
x = LeakyReLU.(x, 0.2)
x = @cv5.(x)
x = @bn5.(x)
x = LeakyReLU.(x, 0.2)
x = @cvt1.(x)
x = @bn6.(x)
x = LeakyReLU.(x, 0.2)
x = Concatenate.(x, h2, axis: 3)
x = @cv6.(x)
x = @bn7.(x)
x = LeakyReLU.(x, 0.2)
x = @cvt2.(x)
x = @bn8.(x)
x = LeakyReLU.(x, 0.2)
x = Concatenate.(x, h1, axis: 3)
x = @cv7.(x)
x = LeakyReLU.(x, 0.2)
x = @cv8.(x)
x = Tanh.(x)
x
end
end
class Discriminator < Model
def initialize(input_shape, base_num_filters)
super()
@input_shape = input_shape
@cv1 = Conv2D.new(base_num_filters, 4, padding: true)
@cv2 = Conv2D.new(base_num_filters, 4, strides: 2, padding: true)
@cv3 = Conv2D.new(base_num_filters * 2, 4, padding: true)
@cv4 = Conv2D.new(base_num_filters * 2, 4, strides: 2, padding: true)
@d1 = Dense.new(1024)
@d2 = Dense.new(1)
@bn1 = BatchNormalization.new
@bn2 = BatchNormalization.new
@bn3 = BatchNormalization.new
@bn4 = BatchNormalization.new
end
def forward(x)
x = InputLayer.new(@input_shape).(x)
x = @cv1.(x)
x = @bn1.(x)
x = LeakyReLU.(x, 0.2)
x = @cv2.(x)
x = @bn2.(x)
x = LeakyReLU.(x, 0.2)
x = @cv3.(x)
x = @bn3.(x)
x = LeakyReLU.(x, 0.2)
x = @cv4.(x)
x = @bn4.(x)
x = LeakyReLU.(x, 0.2)
x = Flatten.(x)
x = @d1.(x)
x = LeakyReLU.(x, 0.2)
x = @d2.(x)
x
end
# Discriminatorの学習を許可する。
def enable_training
trainable_layers.each do |layer|
layer.trainable = true
end
end
# Discriminatorの学習を禁止する。
def disable_training
trainable_layers.each do |layer|
layer.trainable = false
end
end
end
class DCGAN < Model
attr_reader :gen1
attr_reader :gen2
attr_reader :dis
def initialize(gen1, gen2, dis)
super()
@gen1 = gen1
@gen2 = gen2
@dis = dis
end
def forward(input)
images = @gen1.(input)
@dis.disable_training
# 変換した画像に対するDiscriminatorの出力結果。
out = @dis.(images)
# 変換した画像を元の画像に復元した画像。
cycle_image = @gen2.(images)
[cycle_image, out]
end
end
# 学習したモデルを保存するためのモデル
class CycleGANModel < Model
attr_accessor :dcgan_A
attr_accessor :dcgan_B
def initialize(dcgan_A, dcgan_B)
super()
@dcgan_A = dcgan_A
@dcgan_B = dcgan_B
end
end
モデルの作成
gen_A = Generator.new([64, 64, 3], 64)
gen_B = Generator.new([64, 64, 3], 64)
dis_A = Discriminator.new([64, 64, 3], 64)
dis_B = Discriminator.new([64, 64, 3], 64)
# リンゴからオレンジに変換するDCGAN-A
dcgan_A = DCGAN.new(gen_A, gen_B, dis_A)
# オレンジからリンゴに変換するDCGAN-B
dcgan_B = DCGAN.new(gen_B, gen_A, dis_B)
dis_A.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new)
dis_B.setup(Adam.new(alpha: 0.00001, beta1: 0.1), SigmoidCrossEntropy.new)
dcgan_A.setup(Adam.new(alpha: 0.0001, beta1: 0.5),
[MeanAbsoluteError.new, SigmoidCrossEntropy.new], loss_weights: [10, 1])
dcgan_B.setup(Adam.new(alpha: 0.0001, beta1: 0.5),
[MeanAbsoluteError.new, SigmoidCrossEntropy.new], loss_weights: [10, 1])
cycle_gan_model = CycleGANModel.new(dcgan_A, dcgan_B)
イテレータ
Cycle-GANでは、変換元画像と変換先画像がペアになっている必要はありません。
むしろペアになっていることで、過学習してしまう可能性があります。
そのため、ペアにならない変換元画像と変換先画像のミニバッチを返すイテレータを作成します。
class DNN::CycleGANIterator < DNN::Iterator
def initialize(x_datas, y_datas, random: true, last_round_down: false)
@x_datas = x_datas
@y_datas = y_datas
@random = random
@last_round_down = last_round_down
num_datas1 = x_datas.is_a?(Array) ? x_datas[0].shape[0] : x_datas.shape[0]
num_datas2 = y_datas.is_a?(Array) ? y_datas[0].shape[0] : y_datas.shape[0]
if num_datas1 < num_datas2
@num_datas = num_datas1
else
@num_datas = num_datas2
end
reset
end
def next_batch(batch_size)
raise DNN::DNNError, "This iterator has not next batch. Please call reset." unless has_next?
if @indexes1.length <= batch_size
batch_indexes1 = @indexes1
batch_indexes2 = @indexes2
@has_next = false
else
batch_indexes1 = @indexes1.shift(batch_size)
batch_indexes2 = @indexes2.shift(batch_size)
end
x_batch, _ = get_batch(batch_indexes1)
_, y_batch = get_batch(batch_indexes2)
# ランダムにサンプリングした変換元画像と変換先画像を返す。
[x_batch, y_batch]
end
def reset
@has_next = true
@indexes1 = @num_datas.times.to_a
@indexes2 = @num_datas.times.to_a
if @random
@indexes1.shuffle!
@indexes2.shuffle!
end
end
end
学習部分
モデルの学習を行う部分です。学習させている内容をざっくりとまとめるとこんな感じです。
リンゴ => オレンジの変換の学習
①Discriminator-Aの学習: Generator-Aが出力したオレンジの画像と本物のオレンジの画像を見分けられるように学習する。
②Generator-Aの学習: Discriminator-Aを騙せるようなオレンジの画像を出力できるように学習する。
③Generator-AとGenerator-Bの学習: Generato-Aが出力したオレンジの画像をGenerator-Bに渡したとき、出力される画像が、元のリンゴの画像に戻ることを学習させる。
※オレンジ => リンゴの場合についても、上記と同様に学習させる。
iter1 = DNN::CycleGANIterator.new(x, y)
iter2 = DNN::CycleGANIterator.new(x, y)
iter3 = DNN::CycleGANIterator.new(x, y)
iter4 = DNN::CycleGANIterator.new(x, y)
num_batchs = iter1.num_datas / batch_size
real = Numo::SFloat.ones(batch_size, 1)
fake = Numo::SFloat.zeros(batch_size, 1)
(initial_epoch..epochs).each do |epoch|
num_batchs.times do |index|
x_batch, y_batch = iter1.next_batch(batch_size)
x_batch2, y_batch2 = iter2.next_batch(batch_size)
x_batch3, y_batch3 = iter3.next_batch(batch_size)
x_batch4, y_batch4 = iter4.next_batch(batch_size)
# DCGAN-Aの学習
images_A = gen_A.predict(x_batch)
dis_A.enable_training
dis_loss = dis_A.train_on_batch(y_batch, real)
dis_loss += dis_A.train_on_batch(images_A, fake)
dcgan_loss = dcgan_A.train_on_batch(x_batch2, [x_batch2, real])
puts "A epoch: #{epoch}, index: #{index}, dis_loss: #{dis_loss}, dcgan_loss: #{dcgan_loss}"
# DCGAN-Bの学習
images_B = gen_B.predict(y_batch3)
dis_B.enable_training
dis_loss = dis_B.train_on_batch(x_batch3, real)
dis_loss += dis_B.train_on_batch(images_B, fake)
dcgan_loss = dcgan_B.train_on_batch(y_batch4, [y_batch4, real])
puts "B epoch: #{epoch}, index: #{index}, dis_loss: #{dis_loss}, dcgan_loss: #{dcgan_loss}"
end
if epoch % 5 == 0
cycle_gan_model.save("trained/cycle_gan_model_epoch#{epoch}.marshal")
end
iter1.reset
iter2.reset
iter3.reset
iter4.reset
end
モデルの実行
今回も学習済みモデルのGeneratorを用意しています。
Githubからリポジトリをクローンして、「imgen.rb」を実行することで、画像を生成することができます。
$ ruby imgen.rb
実行結果
生成した画像のうち、比較的うまくいったものを貼っておきます。
inputが入力した画像で、outputが変換した画像になります。
上2行が、リンゴをオレンジに変換したもので、下2行は、オレンジをリンゴに変換したものです。
オレンジ色のリンゴと赤色のオレンジができただけなので、Cycle-GANに成功したとは言えそうにないですが、CPUでやったにしてはうまくいったんじゃないでしょうか。
おわりに
Pix2pixが動いたので、もしかしたらCycle-GANもいけるんじゃね?っていう軽い気持ちでやってみましたが、思ってたよりはうまく動いてくれました。
いつかは、ruby-dnnをGPUに対応させて、ウマをシマウマに変換させられるようになりたいですね!
-
Pix2pixモデルを使いまわしているので、InstanceNormalizationを使っていないなど、実際のCycle-GANとはいくつか異なる点があります。 ↩