こんにちは。VASILYデータチームの後藤です。
この記事はVASILY Advent Calendar 2017 15日目の記事になります。
本記事では、GANによる服の着せ替えを試みます。
左列のモデル着用画像に、右列の商品画像を与え、中列の着せ替え結果を得るという課題設定です。
https://arxiv.org/abs/1709.04695, Figure.4の一部
この精度での着せ替えが可能になれば、自分の画像に商品を着せて着用イメージを得ることができ、EC上での服の買い物に有用であると考えられます。
参考
今回は以下の論文を参考にしました。
The Conditional Analogy GAN: Swapping Fashion Articles on People Images
Nikolay Jetchev, Urs Bergmann, 2017
https://arxiv.org/abs/1709.04695
論文では、Conditional Analogy GANというネットワークを提案し、上記の例のようにトップスの画像データで着せ替えを実現しています。
データセット
論文中ではトップスの画像を扱っていましたが、同じデータセットでは面白くないため、今回は別のカテゴリから選びます。様々なカテゴリを調査した結果、水着カテゴリが商品画像とモデル着用イメージの組が特に得やすいことが判明したので、水着で挑戦します。トップスに比べて商品がモデル着用画像の全体に占める面積が小さいことや、モデルさんの肌の面積が大きく、色も様々である点はトップスと異なる点になりそうです。
IQONの商品データベースから水着の商品画像とモデル着用イメージの画像の組を集めます。
以下では水着の商品画像をArticle、モデル着用イメージをHumanと呼びます。
Article
Human
画像を集める際には、以下の点に注意しました。
商品がパンツのみのものは除外する
- 商品画像がパンツ部分のみにも関わらず、モデルはブラトップも身につけているという画像の組が含まれていました
- パンツの画像のみからHumanのブラトップ部分を生成するのは困難だと判断し除外しました
Humanが背中を向けているものは除外する
- 背中の見せ方が売りの水着も多く、モデルが背中を向けている画像が一定数あります
- 商品の表の画像から背面を推定するのは困難と判断し除外しました
Humanが水着以外にシャツやハーフパンツを身につけている画像は除外する
- これらの商品はArticleに含まれないためノイズになり得るため除外しました
重複商品を除外する
- Generatorに渡すデータとして、Humanに対するpositive ArticleとNegative Articleのトリプレットを構成するため、重複する画像は除外しておきます
- 今回はPhotos Duplicate Cleanerというツールを利用し、重複画像を除外しました。
最終的に ArticleとHumanを940組集めることができました。
論文では15000組のデータを用意しているため、データが不十分であることが考えられます。
ネットワーク
訓練するモデルはGeneratorとDiscriminatorの2つです。
通常のGANと異なるところは、GeneratorがHuman, Positive Article, Negative Articleの3枚の画像を入力し、4 channel の画像一枚を出力する点です。4 channelのうち1channelはfilterの役割を果たします。
ミニバッチの中に本物の画像と生成画像が含むため、NormalizationにはBatch Normalizationではなく、Instance Normalizationを利用しています。公式にはないため、実装はcrcrparさんのものを利用させていただきました。[crcrpar/instance_normalization_chainer]
(https://github.com/crcrpar/instance_normalization_chainer)
import numpy as np
import chainer
from chainer import cuda, link, initializers
import chainer.functions as F
import chainer.links as L
from instance_normalization import InstanceNormalization
def add_noise(h, sigma=0.2):
xp = cuda.get_array_module(h.data)
if chainer.config.train:
return h + sigma * xp.random.randn(*h.shape)
else:
return h
class Generator(chainer.Chain):
def __init__(self, bottm_width=8, ch=512, wscale=0.02):
w = chainer.initializers.Normal(wscale)
super(Generator, self).__init__()
with self.init_scope():
self.c0 = L.Convolution2D(9, ch // 8, 4, 2, 1, initialW=w)
self.c1 = L.Convolution2D(ch // 8, ch // 4, 4, 2, 1, initialW=w)
self.c2 = L.Convolution2D(ch // 4, ch // 2, 4, 2, 1, initialW=w)
self.c3 = L.Convolution2D(ch // 2, ch // 1, 4, 2, 1, initialW=w)
self.bn0 = InstanceNormalization(ch // 8)
self.bn1 = InstanceNormalization(ch // 4)
self.bn2 = InstanceNormalization(ch // 2)
self.bn3 = InstanceNormalization(ch // 1)
self.dc0 = L.Deconvolution2D(ch // 1, ch // 2, 4, 2, 1, initialW=w)
self.dc1 = L.Deconvolution2D(ch // 2, ch // 4, 4, 2, 1, initialW=w)
self.dc2 = L.Deconvolution2D(ch // 4, ch // 8, 4, 2, 1, initialW=w)
self.dc3 = L.Deconvolution2D(ch // 8, 4, 4, 2, 1, initialW=w)
self.dbn0 = InstanceNormalization(ch // 2)
self.dbn1 = InstanceNormalization(ch // 4)
self.dbn2 = InstanceNormalization(ch // 8)
def __call__(self, x):
h1 = F.relu(self.c0(x))
h2 = F.relu(self.bn1(self.c1(h1)))
h3 = F.relu(self.bn2(self.c2(h2)))
h = F.relu(self.bn3(self.c3(h3)))
h = F.relu(self.dbn0(self.dc0(h))) + h3
h = F.relu(self.dbn1(self.dc1(h))) + h2
h = F.relu(self.dbn2(self.dc2(h))) + h1
x = F.sigmoid(self.dc3(h))
return x
class Discriminator(chainer.Chain):
def __init__(self, bottom_width=8, ch=512, wscale=0.02):
w = chainer.initializers.Normal(wscale)
super(Discriminator, self).__init__()
with self.init_scope():
self.c0 = L.Convolution2D(6, ch // 8, 4, 2, 1, initialW=w)
self.c1 = L.Convolution2D(ch // 8, ch // 4, 4, 2, 1, initialW=w)
self.c2 = L.Convolution2D(ch // 4, ch // 2, 4, 2, 1, initialW=w)
self.c3 = L.Convolution2D(ch // 2, ch // 1, 4, 2, 1, initialW=w)
self.c4 = L.Convolution2D(ch // 1, ch // 1, 4, 2, 1, initialW=w)
self.l4 = L.Linear(bottom_width * bottom_width * ch, 1, initialW=w)
self.bn0 = InstanceNormalization(ch // 8)
self.bn1 = InstanceNormalization(ch // 4)
self.bn2 = InstanceNormalization(ch // 2)
self.bn3 = InstanceNormalization(ch // 1)
self.bn4 = InstanceNormalization(ch // 1)
def __call__(self, x):
h = add_noise(x)
h = F.leaky_relu(add_noise((self.c0(h))))
h = F.leaky_relu(add_noise(self.bn1(self.c1(h))))
h = F.leaky_relu(add_noise(self.bn2(self.c2(h))))
h = F.leaky_relu(add_noise(self.bn3(self.c3(h))))
h = F.leaky_relu(add_noise(self.bn4(self.c4(h))))
return F.sigmoid(self.l4(h))
損失関数
上記で定義したGenerator(G)とDiscriminator(D)を最適化するために、3つの損失関数を定義して学習に利用します。
\min_{G} \max_{D} \mathcal{L}_{cGAN}(G,D)+ \gamma_i \mathcal{L}_{id}(G) +\gamma_c\mathcal{L}_{cyc}(G)
Adversarial Loss
\mathcal{L}_{cGAN}(G,D) = \mathbb{E}_{x_i,y_i \sim p_{\mathrm{data}}} \sum_{\lambda,\mu} \left[ \log D_{\lambda,\mu}(x_i,y_i) \right] \nonumber \\
+ \mathbb{E}_{x_i,y_i,y_j \sim p_{\mathrm{data}}}\sum_{\lambda,\mu} \left[ \left(1 - \log D_{\lambda,\mu}(G(x_i,y_i,y_j),y_j) \right) \right] \nonumber \\
+ \mathbb{E}_{x_i,y_{j \neq i} \sim p_{\mathrm{data}}}\sum_{\lambda,\mu} \left[ \left(1 - \log D_{\lambda,\mu}(x_i,y_j) \right) \right]
通常のGANで用いられる損失関数と似ています。DiscriminatorがHumanとArticleの組を受け取った際、本物のHumanとArticleの組であるか、生成されたHumanと着せ替え対象のArticleの組なのかを見抜くように設計します。また、本物の画像の組であっても、HumanとArticleが対応関係にない組も偽物であると判断させるよう設計します。
Cycle Loss
\mathcal{L}_{cyc}(G) = \mathbb{E}_{x_i,y_i,y_j \sim p_{\mathrm{data}}} \| x_i - G(G(x_i,y_i,y_j),y_j,y_i) \|
一度着せ替えたHumanをもう一度Generatorで元のArticleに着せ替えた結果が、元のHuman画像とどれくらい近いかを評価します。
Id Loss
\mathcal{L}_{id}(G) = \mathbb{E}_{x_i,y_i,y_j \sim p_{\mathrm{data}}} \|\alpha_i^j\|
Generatorの出力の一つであるフィルターができるだけ小さな範囲に収まるように設計された損失関数です。Human画像の服の部分以外はできるだけ元の画像を保持するという課題設定のため、この損失関数が必要になります。
学習
ミニバッチのサイズが16での学習を50エポック行いました。最適化手法は論文の設定と同じAdamを使用しました。3つの損失関数のバランスは論文の通り、adversarial、cycle、idをそれぞれ1:1:0.1の割合で使用しました。50epoch程度回すと、Generatorが壊れはじめ、おかしな画像を出力するようになりました。
データオーギュメンテーションは、左右反転のみ行っています。左右非対称の水着や模様があるため、左右反転を行う際はHumanとArticleは必ずセットで行っています。
結果
Generatorを使って着せ替えを行った例を紹介します。
以下の結果は、上段のHumanに左のArticleを着せ替えた例です。
ビキニタイプからワンピースへの形の上での着せ替えや、色合いの変化はある程度できていますが、細かな模様が再現できるほどにはGeneratorが発達してないようです。元の水着の色が残っていたり、足の部分を塗り替えようとしていたりとフィルタの学習も上手くいっていないようです。
まとめ
CAGANの論文を参考に、水着の着せ替えを試みました。今回の実験では、ビキニからワンピースや色の変化をさせることはできましたが、論文ほどには綺麗に着せ替えを行うことはできませんでした。また、今回試した範囲ではGeneratorの学習が上手く進みませんでした。データの量が圧倒的に足らない可能性や、そもそもの実装が間違っている可能性もあるため、今後も検証を続けていきたいと思います。