こちらの記事でPULSEを動かしてみましたが、やはり中で何をしているのかが気になりますので、ソースコードを参考に、論文を理解してみます。
論文全部を理解するのは大変なので、ポイントだけ抑えて、5分で理解できるようにしてみます。
タイトル、作者
PULSE:生成モデルの潜在空間探査による自己教師あり写真アップサンプリング
Sachit Menon、Alexandru Damian、Shijia Hu、Nikhil Ravi、Cynthia Rudin
デューク大学
ノースカロライナ州ダーラム
{sachit.menon、alexandru.damian、shijia.hu、nikhil.ravi、cynthia.rudin} @ duke.edu
メノンさんとダミアンさんの共著です。
今までの超解像の問題
今までの超解像の手法では、細かい部分(周波数成分が高い)でぼやけるという問題がありました。
これは、ロスの計算で、正解画像と超解像画像の差をMSE(平均二乗誤差)で求めているためでした。
MSEを改良した方法も出てきてますが、ぼやけは改善されていないようです。
PULSEの手法
以下の3つの特徴があります。
- StyleGANを使用して超解像画像を生成し、そこからダウンスケールした画像を作成します
- 「正解画像-超解像画像」での比較ではなく、「入力画像-ダウンスケールした画像」で比較します
- StyleGANに渡す潜在変数を、高次元ガウス分布(球の表面)に従うようにします
StyleGANで画像を生成する際に使用する潜在変数を調整していきます。
超解像画像の生成のところはStyleGANをそのまま使うので、そこで画像がぼやけることはありません。
ロスの計算
2つのロスを使用します。
(難しい式は省略)
- MSE(平均二乗誤差)で、低解像度画像間の差を計算
- GEOCROSSで、潜在変数が高次元ガウス分布(球の表面)に乗るように調整
この2つのロスにそれぞれ重みを掛けて、全体のロス値としています。
(実装コードでは、L1ロスも指定できるようになっています)
なお、GEOCROSSは以下のように計算しています。
X = latent.view(-1, 1, 18, 512)
Y = latent.view(-1, 18, 1, 512)
A = ((X-Y).pow(2).sum(-1)+1e-9).sqrt()
B = ((X+Y).pow(2).sum(-1)+1e-9).sqrt()
D = 2*torch.atan2(A, B)
D = ((D.pow(2)*512).mean((1, 2))/8.).sum()
ここでは「latent」が潜在変数(18x512のランダムな値にガウス関数を掛けたもの)になります。
補足
- 学習時間はかかりませんが、空間探索するため、多少処理時間がかかります
- 使用しているStyleGANは「CelebA-HQ」で学習済みのモデルを使用しています
- ガウシアンの値もすでに計算済みのものを使用しています
- 実装コードのフレームワークはPyTorchを使用しています
最後に
ざっくりと書きすぎているため、間違った記述があるかもしれませんが、もし気づいた方がいらっしゃいましたらご指摘いただけるとありがたいです。