Flow-basedモデルの仲間であるi-RevNet/i-ResNetで遊んでみました。深層生成モデルのFlowが何をやっているのか、大きな特徴である「可逆性」について、イメージを掴むことを目標とします。
Flow-basedモデルとは
生成モデルの一つ。生成モデルにはおもに3種類あります。
GAN(Generative adversarial networks)
DとGの2つのネットワークが敵対的に学習し、ナッシュ均衡に収束するように訓練
- 長所:画像ではとにかく高画質な生成ができる。研究も多い。Dの本物/偽物で損失を定義するので、ダイナミックな損失関数が可能(教師あり学習のような損失関数芸に陥りにくい)。
- 短所:訓練が成功する保証がない。しかし、最近の研究ではかなり安定してきた。
VAE(Variational Auto Encoder)
Evidence lower bound (ELBO) の最大化。Auto EncoderにReparameterization Trickを加えたもの。潜在空間を正規分布で仮定することがほとんど。
- 長所:ニューラルネットワークが1個なので訓練が(ほぼ)必ず成功する。モード崩壊を気にする必要がない
-
短所:GANのような潜在空間滑らかさと画質の両立が困難。画質面で明らかに劣ることが多い。しかし、潜在空間の滑らかさを捨てれば、最先端のGANを凌駕する画質は出せるとの研究もある(VQ-VAE2など)。
VAEだけでなく、ニューラルネットワークが1個の場合は損失関数が静的で、タスクによっては(Image to imageなど)ハイパラ沼になることも多い。
Flow-basedモデル(今回の内容)
正規分布のような簡単な分布を積み重ね、尤度最大化問題を解くことでデータの密度を推定するモデル。GANにもVAEにもない大きな特徴として、ネットワークを全体が可逆であるという制約がある。
- 長所: 可逆であること。画像→特徴量への変換だけでなく、特徴量→画像の変換も行える。そして、可逆の制約により、特徴量→画像の逆変換は、画像→特徴量の逆変換になることが保証される。
- 短所:ヤコビ行列の計算量が重い($O(D^3)$)。計算量を削るために特殊な制約をおいたり、アルゴリズムを工夫したりする必要があり、理論的に難しくなりがち。
理論的な内容は、こちらの記事やこちらのスライドに詳しく書かれているので、ぜひ参照してみてください。
画像はi-ResNetの論文より
もっと簡単に言うと
ニューラルネットワーク全体を1つの関数と考えると、
$$z=\cal{F}(x) = f_n(f_{n-1}(\cdots(f_1(x)))) $$
と書くことができます。ここで、$f_i$はニューラルネットワークの$i$番目の層、$n$は層の数に対応し、$x$は入力画像を表します。ここで行っているのは、「画像→特徴量」への変換です(特徴量とは、どのクラスに属するかの確率の推定値など)。この逆変換を考えましょう。
$$\cal{F}^{-1}(z) = f_1^{-1}(f_2^{-1}(\cdots(f_n^{-1}(z))))$$
ニューラルネットワークが可逆であるFlow-basedモデルでは、
$$x={\cal F}^{-1}(z)={\cal F^{-1}}({\cal F}(x)) $$
が成り立ちます。簡単に言うと、ニューラルネットワークを行って戻ってきたら元の画像に戻ってきているということです。逆変換や逆関数が計算できるニューラルネットワークと考えておくといいでしょう。
「あれ、これAuto Encoderと同じじゃない?」と思うかもしれません。Flow-basedモデルでは、Auto EncoderのようなL2ロスのような損失関数を必要としません。
また、Auto Encoderでは入力画像と復元画像は損失関数により限りなく近づきますが、各層に関しては可逆の保証はありません。Flow-basedモデルでは、
$$x = f_i^{-1}(f_i(x)) \qquad (i=1,2,\cdots,n)$$
のように各層の可逆性も保証されます。訓練済みモデルを分析したり、その逆変換や逆元を求めたい場合に便利ではないでしょうか。
i-RevNet / i-ResNet
ネットワーク構造を可逆にした畳み込みニューラルネットワーク。Flow-basedモデルと関係性が深く、仲間として扱われることが多いです。
「i-ResNetは可逆なResNet」ということで着想がわかりやすく(i-ResNetはi-RevNetの改良)、PyTorchのコードが公開されているので、Flow-basedモデルの中ではとっつきやすい方ではないでしょうか。今回は公開されているコードで遊んでみます。なぜ可逆になるかは今回は省略します。
i-RevNet : https://github.com/jhjacobsen/pytorch-i-revnet
i-ResNet : https://github.com/jhjacobsen/invertible-resnet
今回はi-RevNetの例で解説します。
CIFAR-10を訓練
公開されているコードをcloneしてCIFAR-10を訓練してみます。
git clone https://github.com/jhjacobsen/pytorch-i-revnet .
cloneが終わったら訓練します。
python CIFAR_main.py --nBlocks 18 18 18 --nStrides 1 2 2 --nChannels 16 64 256
デフォルトがバッチサイズ128、学習率0.1で、Validation accuracyが94.5%出るそうです。自分は高速化のためにバッチサイズ512、学習率0.4で訓練しました。それでも93%以上出ているのでだいたいあっているでしょう。
2080Tiが2枚で3時間弱ぐらいだったので、CIFAR-10ならColabでも十分訓練できると思います。
訓練済みモデルは「checkpoint/cifar10」以下に保存されています。
復元画像を見る
画像→特徴量→画像という逆変換を行った復元(reconstruction)画像を見てみましょう。コードは末尾を参照してください。main関数の中にあります。
各画像の左が本物、右が復元画像です。ぱっと見違いがわからないですね。
クラス間の補間
Flowらしいことをやってみましょう。i-RevNetでは$z\to x$の逆変換ができるので、2つの画像のEmbeddingを取り、線形補間した特徴量を作り逆変換することで、潜在空間上の画像補間を行うことができます。
イメージ的にはGANの潜在空間の補間とほぼ一緒です。まずはクラス間(異なるクラス間)の補間を行ってみます。異なる10個のクラスから画像を1枚ずつ抽出し、その中から2枚のペアを作って補間しています。どんな画像が出てくるでしょうか?
(拡大してみてください)。[1~10のクラスの画像]+[1~10のクラスの画像]で潜在空間上で補間したものです。
なんかこんな画像見たことありますよね。Data AugmentationのMixupとほぼ同じような結果が出てくるのです。
クラス内の補間
同様にクラス内で補間をしてみます。
やっぱりMixupっぽい
Mixupの場合
ではMixupの場合はどうでしょうか?Mixupの場合はただ2枚の画像を線形補間すればいいだけで、
$$x' = kx_1 + (1-k)x_2 \qquad (0\leq k\leq 1)$$
クラス内のMixupは次のようになります。
Mixupとの比較
200%拡大してみました。正直Mixupとの違いがわからない……。
逆に見れば、このi-RevNetでの実験を通じてMixupの正当性が確かめられたということが言えます。つまり、入力画像を線形補間して足すと、潜在空間の値も補間されるため、クラス間の境界の値も訓練されるようになる、したがって汎化性能が上がる、ということでしょうか。
その他挫折した内容
i-RevNetを使えば、潜在空間をプロット→次元削減し、分類ミスが多いところをサンプリングし、削減した次元を戻し(PCAのようなinverse_transformができる次元削減が必要)、ニューラルネットワークの逆変換を行えば、分類ミスしやすいハードサンプルを生成することができます。
しかし、次元削減をした際の説明分散が低すぎるという問題に直面してしまいました(潜在特徴量が高次元すぎて、容易にプロット可能な2次元程度では到底説明できない)。次元を復元したときに特徴量がスカスカのエリアに飛ばされてしまい、意味のない画像が生成されてしまいます。
次元削減を工夫するのがポイントかと思いますが、うまい解決方法あったら教えて下さい。
まとめ
i-RevNetで遊んでみることで、Flow-basedモデルが何やっているかイメージを掴むことができました。モデルの定性評価や可視化など、分析向きの手法ではないかなと思います。可逆性の具体的な活用方法がいまいち思い浮かばないので、もしいいアイディアあったらお待ちしております。
コード
import torch
from torchvision import transforms
import torchvision
from models.utils_cifar import std, mean
def dataloader():
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean["cifar10"], std["cifar10"]),
])
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=200, shuffle=False, num_workers=2)
return testloader
def main():
model = torch.load("checkpoint/cifar10/i-revnet-55.t7")["model"]
mean_v, std_v = torch.tensor(mean["cifar10"]).view(1, 3, 1, 1), torch.tensor(std["cifar10"]).view(1, 3, 1, 1)
model.eval()
loader = dataloader()
X, y = next(iter(loader))
with torch.no_grad():
## 入力画像をinverseする(reconstruction)
_, embedding = model(X[:100].cuda())
invert = model.module.inverse(embedding).cpu()
# 本物と結合して表示
join = torch.stack([X[:100], invert], dim=1).view(-1, 3, 32, 32)
join = join * std_v + mean_v
torchvision.utils.save_image(join, "reconstruction.png", nrow=10)
## 潜在空間でクラス間の補間をする
for i in range(10):
result = []
for j in range(10):
X1 = X[y == i][i:(i+1)]
X2 = X[y == j][i:(i+1)]
_, z1 = model(X1.cuda())
_, z2 = model(X2.cuda())
k = torch.arange(11, dtype=torch.float32) / 10.0
k = k.view(-1, 1, 1, 1).cuda()
interpolate = z1 * k + z2 * (1 - k)
invert = model.module.inverse(interpolate)
result.append(invert)
out = torch.cat(result, dim=0).cpu() * std_v + mean_v
torchvision.utils.save_image(out, "interclass_" + str(i) + ".png", nrow=11)
## 潜在空間でクラス内の補間をする
for i in range(10):
X_slice = X[y == i][:12]
_, z_slice = model(X_slice.cuda())
# 潜在空間の補間
z1 = z_slice[:6].view(6, 1, 512, 8, 8)
z2 = z_slice[6:].view(6, 1, 512, 8, 8)
k = torch.arange(11, dtype=torch.float32) / 10.0
k = k.view(1, -1, 1, 1, 1).cuda()
interpolate = z1 * k + z2 * (1 - k)
interpolate = interpolate.view(-1, 512, 8, 8)
out = model.module.inverse(interpolate).cpu()
out = out * std_v + mean_v
torchvision.utils.save_image(out, "innerclass_"+str(i)+".png", nrow=11)
# mixupの場合
def mixup():
mean_v, std_v = torch.tensor(mean["cifar10"]).view(1, 3, 1, 1), torch.tensor(std["cifar10"]).view(1, 3, 1, 1)
loader = dataloader()
X, y = next(iter(loader))
## クラス間のmixup
for i in range(10):
result = []
for j in range(10):
X1 = X[y == i][i:(i+1)]
X2 = X[y == j][i:(i+1)]
k = torch.arange(11, dtype=torch.float32) / 10.0
k = k.view(-1, 1, 1, 1)
interpolate = X1 * k + X2 * (1 - k)
result.append(interpolate)
out = torch.cat(result, dim=0).cpu() * std_v + mean_v
torchvision.utils.save_image(out, "interclass_mixup_" + str(i) + ".png", nrow=11)
## クラス内のmixup
for i in range(10):
X1 = X[y == i][:6].view(6, 1, 3, 32, 32)
X2 = X[y == i][6:12].view(6, 1, 3, 32, 32)
interpolate = X1 * k + X2 * (1 - k)
interpolate = interpolate.view(-1, 3, 32, 32)
interpolate = interpolate * std_v + mean_v
torchvision.utils.save_image(interpolate, "innerclass_mixup_"+str(i)+".png", nrow=11)
参考資料
深層生成モデルを巡る旅(1): Flowベース生成モデル
A First Step to Flow-Based Generative Models(日本語スライド)