はじめに
前回は物体検出についてでした。
今回は GAN(Generative Adversarial Network) を使った画像の補完です。
機械学習の初級編では、分かりやすい教師あり学習に偏りがちだと思いますが、初級編が終わった(つもり)ところで教師なしについて振り返り、復習と調査を進める中でGANに出会いました。
GANは半教師ありではありますが。
教師あり/なしとも違う強化学習にも出会ったのですが、GANのほうがわかりやすそうであったことと見た目のインパクト(画像生成とか、ぱっと見派手じゃないですか)でGANを選択。。
そして、GANを調査するうちにGLCIC(Globally and Locally Consistent Image Completion)を発見。
GLCIC概要
このデモ動画にやられました。
論文執筆者の飯塚さんが投稿した短縮版はこちらです。
画像の一部分を消すとそれっぽく補完される。写真に写り込んだアレやコレやを無かったことに出来る!
というわけでGLCICです。
初学者のまとめであり、誤りも多くあると思いますので、コメントにてご指摘頂けると大変ありがたいです。
GANの概要
GLCICの前にそのベースとなるGANから。
GANは「敵対的生成ネットワーク」と訳されるように、敵対する2つの(2種類の)ネットワークが競い合いながら学習を進めていく手法です。
2つのネットワークとは、「Generator」と「Discriminator」で、このようなネットワーク構成になっています。
引用元: GAN: A Beginner’s Guide to Generative Adversarial Networks - Deeplearning4j
- Generator
「生成器」であり、何かしらを生成するネットワークで、画像生成を用いる事例が多い。(その事例例しか知りません)
本物に近いものを生成することが目的であり、Discriminatorにも見破れない本物らしいものを生成することが目標。 - Discriminator
「識別器」。
本物と偽物を正しく識別することが目標であり、Generatorが生成したものは偽物と識別することが目標。
GeneratorとDiscriminatorの関係について、以下の例えがとてもわかり易かったです。
この関係は紙幣の偽造者と警察の関係によく例えられます。偽造者は本物の紙幣とできるだけ似ている偽造紙幣を造ります。警察は本物の紙幣と偽造紙幣を見分けようとします。
次第に警察の能力が上がり、本物の紙幣と偽造紙幣をうまく見分けられるようになったとします。すると偽造者は偽造紙幣を使えなくなってしまうため、更に本物に近い偽造紙幣を造るようになります。警察は本物と偽造紙幣を見分けられるようにさらに改善し…という風に繰り返していくと、最終的には偽造者は本物と区別が付かない偽造紙幣を製造できるようになるでしょう。
引用元: はじめてのGAN
GeneratorとDiscriminatorを畳み込み層を利用したニューラルネットワークで構成したGANをDCGAN(Deep Convlutional GAN)と呼び、今回取り上げるGLCICはこのDCGANを発展させた手法になります。(と理解しました。)
GLCICの概要
画像に空けた穴を画像全体の雰囲気に合わせながら自然な形で補完する手法です。
...微妙な一言でしかまとめられないので、本家の説明をどうぞ。
本研究では,畳み込みニューラルネットワークを用いて,シーンの大域的かつ局所的な整合性を考慮した画像補完を行う手法を提案する.提案する補完ネットワークは全層が畳み込み層で構成され,任意のサイズの画像における自由な形状の「穴」を補完できる.この補完ネットワークに,シーンの整合性を考慮した画像補完を学習させるため,本物の画像と補完された画像を識別するための大域識別ネットワークと局所識別ネットワークを構築する.大域識別ネットワークは画像全体が自然な画像になっているかを評価し,局所識別ネットワークは補完領域周辺のより詳細な整合性によって画像を評価する.この2つの識別ネットワーク両方を「だます」ように補完ネットワークを学習させることで,シーン全体で整合性が取れており,かつ局所的にも自然な補完画像を出力することができる.提案手法により,様々なシーンにおいて自然な画像補完が可能となり,さらに従来のパッチベースの手法ではできなかった,入力画像に写っていないテクスチャや物体を新たに生成することもできる.これにより,人間の顔の一部を補完するような,複雑な画像補完を実現した.
引用元: GLCIC概要
既存手法との違い
既存の手法
- Diffusion Based
穴の周りの情報を活用して穴を埋める手法のようです。
小さなキズや穴の補完には対応できるが、大きな穴には対応出来ません。 - Patch Based
画像中の情報を活用して穴を埋める手法のようです。
diffusion basedよりも大きな穴を補完出来ますが、入力画像を利用しているので、その画像内に無いオブジェクトを生成することは出来ません。 - Contxet Encoder
2016提案のCNN+GANを用いた手法です。
Generatorは穴が空いた画像を元に、穴を補完する部分の画像を出力し、Discriminatorはその補完部分のみを判定します。
入力画像のサイズは固定のようです。
既存手法からの改善点
機能的な比較は以下の通りで、Context Encoderの発展系という位置付けになると思います。
Patch Based | Context Encoder | GLCIC | |
---|---|---|---|
入力画像サイズ | 任意 | 固定 | 任意 |
穴周辺の局所的な整合性 | 可 | 不可 | 可 |
入力画像の意味を理解した補正 | 不可 | 可 | 可 |
入力画像に無い新しいオブジェクトの生成 | 不可 | 可 | 可 |
以下、既存手法との出力結果の比較です。
「Pathak et al. 2016」がContext Encoder、「Ours」がGLCICです。
これをみると、既存手法よりもかなり自然に補完されていることがわかります。
引用元: GLCIC論文
ネットワーク構造
引用元: [GLCIC論文]Context Encoderとの大きな違いとしては、以下になるかと思います。
- Generator
- 穴部分だけでなく画像全体を出力している。
- Generatorの出力画像の穴補完部分と入力画像をマージした画像をDiscriminatorの入力としている。
- Discriminator
- Global DiscriminatorとLocal Discriminatorの2つのネットワークで構成される。
以降で、GeneratorとDiscriminatorそれぞれのポイントについて、自分なりの感覚的な理解をまとめてみます。
数式など含めた詳細は参考資料を確認してください。
Generator
拡張畳み込み(Dilated Convolution)
通常の畳み込みはフィルタをそのまま適用しますが、拡張畳み込みではフィルタを畳み込み対象に対して間隔を空けて適用します。。。
うまく説明出来ていないと思いますので、こちらをどうぞ。
引用元: A technical report on convolution arithmetic in the context of deep learning
拡張畳み込みの何が嬉しいかというと、受容野(畳み込み後の特徴マップピクセル1つが関わる畳み込み前の特徴マップのピクセル)が簡単に増やせるから。
普通の畳み込みで受容野を増やそうと思うと、フィルタのサイズを大きくするか層を増やすことになりますが、フィルタサイズをあまりに大きくすると層を局所的な特徴を掴みづらそうですし、単純に層を増やすと勾配消失とか学習コストが高くなるとかの問題がありそうです。
その他、通常の畳込みとの比較については以下を見てください。
参考: Dilated Convolution - ジョイジョイジョイ
ネットワークの中間部分にこの拡張畳み込みを適用し、受容野を広げつつも解像度を落とさないことで、大域的な整合性を確保しているものと理解してみました。
mask領域
学習時は穴周辺の画像も含めてmaskした(切り抜いた)画像をGeneratorの入力とすること。
以下の図の左側のP2のように、maskが穴周辺の領域を含まない場合は穴を補完する情報が求められない。
なので、右側のように、maskは穴よりも大きくする必要がある。
引用元: GLCIC論文
ネットワーク構造
引用元: GLCIC論文
学習安定化のため、出力層以外ではBatch Normalizationを用い、活性化関数にはReluを用います。
出力層では、活性化関数にSigmoidを用います。
と論文にはこのように記載されていましたが、GAN安定化のテクニックと異なるのはなぜでしょうか。
GAN(Generative Adversarial Networks)を学習させる際の14のテクニック - Qiita
私の実装では、GAN安定化のテクニックのうち、活性化関数にLeakyRelu、tanhを用いたほうが結果が良い感じがしました。
私の結果の精度が悪すぎるので参考にならないかもしれませんが。。。
この点はまた改めて検証してみたいと思います。
Discriminator
Global Discriminator、Local Discriminator
Global Discriminatorで大域的な(画像全体の)整合性を判断し、Local Discriminatorで局所的な整合性を判断します。
具体的には、Global Discriminatorには、Generatorが生成した画像のmask部分とGeneratorへの入力画像をマージした画像を入力します。
Local Discriminatorには、Generatorが生成した画像のmask部分のみを入力します。
両ネットワークの評価を総合的に判断します。(といっても損失を加算するのみ)
ネットワーク構造
引用元: GLCIC論文
学習安定化のため、出力層以外ではBatch Normalizationを用い、活性化関数にはReluを用います。
出力層では、活性化関数にSigmoidを用います。
損失
- Generator
ピクセルごとの平均二乗誤差 - Discriminator
GAN損失
D(x,Md): 入力画像の判断結果。正しく判断できれば結果は1(本物)に近いはずなので、logを取ると0に近くなるはず。
D(C(x,Mc), Mc): Generatorが生成した画像の判断結果。正しく判断できれば結果は0(偽物)に近いはずなので、logを取ると1に近くなるはず。
学習
学習方法
3stageに分けて学習します。
- stage1
Generatorのみ鍛える。 - stage2
Generatorの重みは固定して、Discriminatorのみ鍛える。 - stage3
両方並行して鍛える。
学習環境
- データセット
Places2
http://places2.csail.mit.edu/download.html - マシン
4台のK80 GPUを搭載した1台のマシン - ミニバッチサイズ
96画像 - 学習量
- stage1: 90,000反復
- stage2: 10,000反復
- stage3: 400,000反復?(合計50万回とあったので)
- 所要期間
2か月!!
1反復当たりの画像数ですが、以下の「Data of Places365-Standard」で180万枚なので、膨大な回数の学習を行ったことになります。
http://places2.csail.mit.edu/download.html
私の行った学習量では足元にも及ばず。
同じ量学習したらクラウド破産してしまいます。
他
上記のような学習をこなした鍛え上げられたGeneratorの出力であっても、穴周辺との色合いが微妙にズレる事があるようです。
その対策として、出力画像に対して以下の fast marching method → Poisson image blending の順で処理を加えます。
私の実装では、この後処理は入れていません。入れるほどの精度が出ていないので。。。
検証
実装
こちらです。
https://github.com/shtamura/glcic
論文と異なる点は以下です。
- 入力画像サイズを256*256にリサイズ
- 穴は黒(0)で埋める
- 以下に則り、LeakyRelu、tanh、Adamを利用
GAN(Generative Adversarial Networks)を学習させる際の14のテクニック - Qiita - 「他」に記載した後処理は未実装。
こちらの実装を参考にさせて頂きました。
https://github.com/tadax/glcic
学習環境
- データセット
Places2
Small images (256 * 256) with easy directory structure - マシン
1台のK80 GPUを搭載した1台のマシン - ミニバッチサイズ
16画像 - 学習量
- stage1: 16万枚(反復 ではありません!単なる枚数です)
- stage2: 16万枚
- stage3: 16万枚
- 所要期間
2日。。。
結果
前回と同様に芳しくありません。
それっぽい何かが生成されるのみで、とても「無かったこと」になるような出来ではありません。
原因は学習量不足なのか、それ以外に原因があるのか、その判別は出来ていません。
それを見極める程にGPUを振り回す予算を捻出出来ず。。。
比較的まともな結果だけ以下に載せます。
いちおう、それっぽい影が生成されています。
これら以外の結果はこちらにあります。
https://github.com/shtamura/glcic
最後に
また長くなってしまいました。さらに、結果がかなり微妙。
ここまで駄文にお付き合い頂き、ありがとうございました。
#参考資料
- GLCIC概要
- GLCIC論文
- Context Encorder論文
- はじめてのGAN
- A technical report on convolution arithmetic in the context of deep learning
- Dilated Convolution - ジョイジョイジョイ
- [GAN(Generative Adversarial Networks)を学習させる際の14のテクニック - Qiita](https://qiita.com/underfitting/items/a0cbb035568dea33b2d7)
- 参考にした実装