動機
機械学習関連ではありがちなテーマですが、白黒写真をカラー化してみたいと思います。GitHubにコードもあげて置きました。[GitHub]
もしもこの記事がお役に立てた時は是非Qiitaのイイねボタン、もしくはGitHubのスターを頂けると励みになりますm(_ _)m
先日、自分の父の遺品を改めて見返していた際に、ふと昔の白黒写真が出てきて。亡くなった当時のまま放置していたのですが、今ならこれもカラー化できると思ったのがキッカケです。技術ブログというより、個人の趣味ブログ的な感じです。尚、GANの勉強をする際に参考にさせて頂いてる方のブログ、本があります。その方の記事は非常に勉強になることばかりで、今回の記事も内容被る部分がございます。下記にリンク貼っておきますので、是非ご覧ください。
[参考URL]
Shikoan's ML Blog
モザイク除去から学ぶ 最先端のディープラーニング
ブログの構成
- pix2pixの概要
-
今回のデータセットについて
- [Google PhotoScan](##Google PhotoScan)
- 画像サイズ
- 学習開始
- まとめ。そして肝心の写真の結果
pix2pixの概要
GAN
Generator(生成器)にDiscriminator(分類器)を組み込み、両者を並行して学習していきます。GeneratorのLossにDiscriminatorの予測値を使った動的な損失関数を加えていくイメージです。GANには色々な種類があり、定義によってその分けも様々ですが、以下のような分け方が存在します。
種類 | 概要 |
---|---|
Conditional GAN | Genaratorのinputとoutputに関係性がある |
Non-Conditional GAN | Genaratorのinputとoutputに関係性がない |
今回のpix2pixは典型的なConditional GAN(CGAN)になります。よくGANのチュートリアルで登場するDCGANはノイズデータからoutputを生成するので、Non-Conditional GANになります。
pix2pixはCGANということもあって、inputの情報が非常に重要になってきます。例えばDCGANでは、Discriminatorとの関係で出てくるAdversarial Lossを使って、学習が進みますが、pix2pixの場合、Adversarial Lossに加え、fake画像とreal画像の差(例えばL1 lossやMSE)も使用して学習が進んでいます。そうすることで、他のGANに比べ学習の進みが早く、結果も安定しやすいという特徴があります。また逆にL1-lossだけを使用して学習する様な手法の場合、どうしても損失を少なくするために全体的にぼんやりしたOutputを出力してしまったり、平均的な画素で全体をベタ塗りしてしまったりといった感じになりやすく、Adversarial Lossを入れることで、L1lossはむしろ大きくなる可能性があってもより本物っぽい画像を出力できる様に学習が進んでいく傾向にあります。知覚品質と歪みにはトレードオフがある
という原理はGANを考えていく上で非常に重要な要素の一つです。
参考: The Perception-Distortion Tradeoff (2017) [arXiv]
PatchGAN
下記、pix2pixの元論文になりますが、PatchGANについての詳細はこちらを参照するのが良いかと思います。
Image-to-Image Translation with Conditional Adversarial Networks (2016) [arXiv]
PatchGANでは、Discriminatorが画像の正誤判定を行う際に、いくつかの領域に分け、それぞれの領域で正誤判定を行う形になります。
領域を分割するといっても、実際に画像を分割し各々別個でDiscriminatorに突っ込むわけではありません。理論上はそうなのですが、実装上は1枚の画像をDiscriminatorに突っ込みその出力を2階のテンソルにします。その時、テンソルがもつ各ピクセルの値はinput画像のパッチ領域の情報を元に導き出されているため、結果的にテンソルが持つ、各ピクセルの値と実際のTrue or False(1 or 0)の間のLossをとることで、PatchGANを実現しています。言葉で説明してもよく分からないと思うので、上記のフランス国旗で例を出してみます。
fig, axes = plt.subplots(1,2)
# 画像の読み込み
img = cv2.imread('france.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (600, 300)) # オリジナルサイズ -> (300, 600, 3)
axes[0].imshow(img)
axes[0].set_xticks([])
axes[0].set_yticks([])
img = ToTensor()(img) # (300, 600, 3) -> (3, 300, 600)
img = img.view(1, 3, 300, 600) # (3, 300, 600) -> (1, 3, 300, 600)
img = nn.AvgPool2d(100)(img) # (1, 3, 300, 600) -> (1, 3, 3, 6)
img = nn.Conv2d(3, 1, kernel_size=1)(img) # (1, 3, 3, 6) -> (1, 1, 3, 6)
img = np.squeeze(img.detach().numpy())
axes[1].imshow(img, cmap='gray')
axes[1].set_xticks([])
axes[1].set_yticks([])
[結果]
上記では、inputされた画像を(3, 6)サイズの特徴マップに落とし込んでいますが、これは元の画像のパッチ領域を全て(1, 1)に圧縮してることに他なりません。上記の例では3×6=18
個のピクセルに対して真偽判定を行います。
Unet
pix2pixではGeneratorにU-netを使用します。Segmentationでも同じみのU-netはEncoder-Decoder構造に加え、Skip-Connectionを張ることで、入力データの情報をなるべく失わないよう工夫しています。下記が実験で使用したGeneratorのforwardメソッドですが、元々の情報をtorch.cat
で都度連結させていることが分かります。
def forward(self, x):
x1 = self.enc1(x)
x2 = self.enc2(x1)
x3 = self.enc3(x2)
x4 = self.enc4(x3)
out = self.dec1(x4)
out = self.dec2(torch.cat([out, x3], dim=1))
out = self.dec3(torch.cat([out, x2], dim=1))
out = self.dec4(torch.cat([out, x1], dim=1))
return out
ちなみに下記のU-netのイメージは最近twitterで話題に挙がっていたGoogleが無償で公開している機械学習関連のイメージスライドになります。[ML Visuals]
良い感じの絵が沢山あるので是非一度ご覧ください。
今回のデータセットについて
Google PhotoScan
学習データとしてMIT-Adobe FiveK Datasetを使用します。こちらはimage enhance系の論文でよく使用されるもので、加工前画像とプロの編集者による加工後画像のセットになっていますが、今回はこの加工後の画像を使用します。加工後の写真に関していえば、普通のカラー写真よりも鮮やかな色彩の写真が多く、割と今回のタスクにもあってるかなと思いチョイスしました。データサイズが小さいバージョンであれば、データの容量も大きくなくダウンロードもある程度手頃です。
実際の白黒写真ですが、こちらはiPhoneの『Google PhotoScan』というアプリでデータ化しました。このアプリはググれば詳しく出てくるのですが、かなり優秀で写真屋さんに行かずとも結構綺麗にデータ化してくれます(しかも一瞬で)。元々の写真はかなり古く黄ばんでいたのですが、白黒画像に変換したところ通常の白黒画像とあまり見た目も変わらない感じでした。
画像サイズ
fivekにある画像も、実際にカラー化したい画像も正方形では無く長方形で、しかも縦横比がバラバラな感じです。そこで、以下の4パターンのいずれかを採用しようと考えました。
- 一律正方形にResize
- 元々のアスペクト比を変えず正方形にReisize、空白の領域は黒く塗りつぶす
- 元々のアスペクト比を変えず正方形にReisize、空白の領域に元々の画像の情報を入れ込む。
- 縦長の画像を回転させ横長に変換後、一律同じ長方形のサイズにリサイズ
- 正方形にCrop
とりあえず今回は3の手法でいくことにしました。1と4に関しては、画像の特徴を大きく変えてしまう可能性があること、2と3に関しては、3の方が情報量が多いかなと考えたことが採用の理由です。勿論素直に行くと5番かもしれませんが、本番用写真まで正方形Cropは少し嫌なので。
尚、出力結果は後処理で元々のアスペクト比にCropします。
学習開始
その1
学習その1です。
序盤暴れがちなところからスタートして徐々にDiscriminatorが強くなり、Generatorのロスが大きくなっていく様子が見て取れます。L1lossに関しては減少しているかどうかはっきりとは分からない感じです。(L1lossが小さい=本物に近い色合い≠本物っぽい
というとこがGANの評価の難しさではあります。)
その2
学習その2です。学習1では、Discriminatorが強くなる傾向があったので、次は以下の点を変えてみて、実験をすることにしました。
- D,及びGのAdversarial LossをBCEからHingeLossに変更
- DのBatchNormalizationをInstanceNormalizationに変更
- Dのweight更新頻度を半減 (Gの半分)
Cross entropy loss と Hinge loss
その2ではその1からの変更点として、その1で使用したbinary cross entropy から hinge loss に損失関数を変更します。その狙いはざっくり言うと、弱いロスにすることで片一方(D or G)のネットワークが強くなりすぎるのを防ぐ
ことにあります。
以下は、鉄板のクロスエントロピー(D version)です。
D:loss = -\sum_{i=1}^{n}\bigl(t_ilogy_i - (1-t_i) log(1-y_i) \bigl)
要はtarget=1
の時はその出力をできるだけ高く(最終層にsigmoidかますので、大きければ大きいほど1に近く)。target=0
の時はその出力を負の方向に大きくするように訓練すれば、ロスが小さくなる感じです。
逆にGについては上式を最大化させる様に訓練することになります。更にGの場合自ら生成した偽画像しか評価されないため、上式の左項が消えて、よりシンプルになります。
G:loss = \sum_{i=1}^{n}log\Bigl(1 - D\bigl(G(z_i))\bigl) \Bigl) (最大化)\\
= \sum_{i=1}^{n}log\Bigl(D\bigl(G(z_i)) - 1\bigl) \Bigl) (最小化)\\
= -\sum_{i=1}^{n}log\Bigl(D\bigl(G(z_i))\bigl) \Bigl) (これを最小化させると捉えることも可能)
上記は二パターンの最適化を出していますが、どちらの実装方法もある様です。
一方、ヒンジ損失ですが、このサイトが分かり易いかと思います。[参考]
サイトではSVM以外にはあまり使われないと書いてますが、現在こうして他の手法に活きてくるのが、面白いです。クロスエントロピーではtargetを(0,1)と表現しましたが、ヒンジでは(-1,1)と表現します。
t = ±1 \\
Loss = max(0, 1-t*y)
式のまんまですが、target=-1
のときはより小さな出力を、逆にtarget=1
の時は大きな出力を出せば損失が小さくて済みます。しかしクロスエントロピーと違って、ある程度でロスを0と打ち切ってしまうこともわかります。クロスエントロピーの場合、完全に(0,1)で予測しないといつまでたってもロスが消えませんが、ヒンジではそうではありません。これが弱いロス
と言われる所以です。PyTorchの実装は以下の通りです。ヒンジのところはDとGによるパターン分けが面倒なので、まとめてクラス化するのが良いかと思います。
# ones: 全ての値が1のpatch
# zeros: 全ての値が0のpatch
# Gloss(BCE)
loss = torch.nn.BCEWithLogitsLoss()(d_out_fake, ones)
# Gloss(Hinge)
loss = -1 * torch.mean(d_out_fake)
# Dloss(BCE)
loss_real = torch.nn.BCEWithLogitsLoss()(d_out_real, ones)
loss_fake = torch.nn.BCEWithLogitsLoss()(d_out_fake, zeros)
loss = loss_reak + loss_fake
# Dloss(Hinge)
loss_real = -1 * torch.mean(torch.min(d_out_real-1, zeros))
loss_fake = -1 * torch.mean(torch.min(-d_out_fake-1, zeros))
loss = loss_reak + loss_fake
Instance Normalization
Instance Normalization は Batch Normalizationの派生です。Batch Normalizationも含めその内容は下記の記事に良くまとまっています。
【GIF】初心者のためのCNNからバッチノーマライゼーションとその仲間たちまでの解説
BatchNormalizationとは一つのミニバッチに含まれるデータの同チャンネル同士で標準化処理を行いますが、InstanceNormalizationでは、ミニバッチ全体でなくデータ単体で行われます。要はバッチサイズ1のBatchNormalizationです。
例えばpix2pixの派生であるpix2pixHDなどでも使われていますが、その目的は勾配の増加を抑えることで学習を収束させにくくする
ことにあります。これをDに適用することでDとGのバランスをとることが主たる目的になります。
以下、結果になります。先ほどよりDiscriminatorの収束が明らかに遅くなっているのが分かります。また、L1lossの減少も先ほどより大きい様に感じます。
その3
学習その3は、学習1から以下の点を変更しました。
- 基本は学習1を踏襲
- 学習率を元論文に揃える。 (1e-4 -> 2e-4)
- Schedulerを使った学習率の調整を削除
- 画像サイズを320から256に変更 (論文を踏襲)
- PatchGANの領域数を10x10から4x4に変更
- trainの際のaugmentationにおけるBlurを削除
しかしながら、殆ど1と変わらない結果に終わってしまいました。
その4
今までの結果を見るに、自然の風景は割とカラー化できるものの、人物に関しては全然上手くいかない
という傾向が見えてきました。本来の目的上これでは意味が無いので、改めて自分でデータを取得してくることに。尚、データの集め方については、以前のQiita記事でも書いたのですが、BingImageSearchを使用しました。[Qiita記事のリンク]
尚、その2をベースにしていますが、若干いじっています。それとデータ数がおよそ3倍(13,000枚超)程度になったため、200epochを100epochに減らしています。
# 学習1からの変更点
- 学習データの追加 (fivek -> fivek+人物画像)
- epoch数を200 -> 97
- D,及GのAdversarial LossをBCEからHingeLossに変更 (その2と一緒)
- DのBatchNormalizationをInstanceNormalizationに変更 (その2と一緒)
- Dのweight更新頻度を半減 (その2と一緒)
- Schedulerを使った学習率の調整を削除 (その3と一緒)
- 学習率を 1e-4->2e-4 に変更 (その3と一緒)
# 出力の画素数が少し物足りなかっただけの理由、本当は良くないかも
- 画像サイズの変更(320->352) (<-new)
- 上記に伴い、PatchGANの領域数を(10,10)->(11,11)に変更 (<-new)
以下、結果です。その2と比べ、終盤の上下動が大きいのはschedulerを削除したのが1番の原因だと思われます。
まとめ。そして肝心の写真の結果
色々、試行錯誤しましたが風景画像はかなりの割合で違和感なく着色できるのに対し、人物画像や派手なコントラストの画像(例; 人・服・花・人工物...など)はあまり着色が進まないといった傾向が分かりました。
(左:fake画像 右:real画像)
これはこれで逆に神秘的に感じてくる様な気がします。眺めていると面白い。また肝心の本番用のmy写真もしっかりカラー化はしてる様でした(笑)一安心。。。