はじめに
先週、某BISに参加しました。チュートリアルが興味深かったので、
興奮が冷めやらぬうちに、MNIST画像を分解してみたいと思います。
考え方
MNISTの画像は、高さ28ピクセル、幅28ピクセルです。そこで・・・
高さ方向の28ピクセルのうちの$i$番目に(つまり、画像を行列と見たときの$i$行目に)、
$d$次元のベクトル$\boldsymbol{a}_i$, $i = 1, \ldots, 28$を割り当てます。
幅方向の28ピクセルのうちの$j$番目に(つまり、画像を行列と見たときの$j$列目に)、
やはり$d$次元のベクトル$\boldsymbol{b}_j$, $j = 1, \ldots, 28$を割り当てます。
そして、MNISTの画像のうちの$n$番目に、$d \times d$の行列$\boldsymbol{Z}_n$を割り当てます。
$d$は、画像の縦横のサイズである28より小さくします。
そして、MNIST画像のうちの$n$番目の画像を現す行列を$\boldsymbol{X}_n$とし、
この行列を、以下の式によって分解して近似することを考えます。
\boldsymbol{X}_{n,i,j} \approx \boldsymbol{a}_i^T \boldsymbol{Z}_n \boldsymbol{b}_j
= \sum_{\beta=1}^d \sum_{\alpha=1}^d \boldsymbol{a}_{i,\alpha} \boldsymbol{Z}_{n,\alpha, \beta} \boldsymbol{b}_{j,\beta}
$\boldsymbol{X}_{n,i,j}$は、行列$\boldsymbol{X}_n$の第$(i,j)$要素を表します。
具体的には、MNISTの$n$番目の画像の第$(i,j)$ピクセルの輝度値です。
重要なのは、$\boldsymbol{a}_i$と$\boldsymbol{b}_j$については、すべての画像で同じものを使うということです。
つまり、これら$\boldsymbol{a}_i$と$\boldsymbol{b}_j$を使って、MNISTの各画像を$d \times d$という小さな行列$\boldsymbol{Z}_n$で表わし直そう、というわけです。
この行列$\boldsymbol{Z}_n$が、もとの画像の潜在的な表現のようなものになります。
なお、$\boldsymbol{a}_i$を列ベクトルとして持つ行列を$\boldsymbol{A}$、$\boldsymbol{b}_j$を列ベクトルとして持つ行列を$\boldsymbol{B}$とすると、
上の式は、すべての$i,j$についてまとめてしまえば、次のように書けます。
\boldsymbol{X}_{n} \approx \boldsymbol{A}^T \boldsymbol{Z}_n \boldsymbol{B}
このように、潜在表現$\boldsymbol{Z}_n$を使って、元の$\boldsymbol{X}_{n}$を分解しています。
繰り返しになりますが、$\boldsymbol{A}$と$\boldsymbol{B}$は、全画像で共通です。
実装
PyTorchで実装しました。
訓練データを使って$\boldsymbol{a}_i$, $i = 1, \ldots, 28$と$\boldsymbol{b}_j$, $j = 1, \ldots, 28$を学習し、
その結果を使って、テストデータから選んだ100の画像を再構成してみます。
import math
import numpy
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
torch.manual_seed(123)
batch_size = 100
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=True, download=True,
transform=transforms.Compose([transforms.ToTensor()])),
batch_size=batch_size, shuffle=True)
test_batch_size = 100
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=False, download=True,
transform=transforms.Compose([transforms.ToTensor()])),
batch_size=test_batch_size, shuffle=True)
test_data, _ = next(iter(test_loader))
test_data = Variable(test_data.squeeze(), requires_grad=False)
height, width = test_data[0].size()
latent_dim = 10
A = Variable(torch.rand(height, latent_dim), requires_grad=True)
B = Variable(torch.rand(latent_dim, width), requires_grad=True)
train_iterations = 50
test_iterations = 500
learning_rate = 0.01
for batch_idx, (data, _) in enumerate(train_loader):
batch_idx += 1
data = Variable(data.squeeze(), requires_grad=False)
Z = Variable(torch.rand(latent_dim, batch_size, latent_dim), requires_grad=True)
optimizer = optim.Adam([A, B, Z], lr=learning_rate)
for _ in range(train_iterations):
X = Variable(torch.zeros(data.size()), requires_grad=False)
for k in range(latent_dim):
temp = torch.mm(A, Z[:, :, k])
for h in range(height):
X[:, :, h] = X[:, :, h] + torch.ger(temp[h, :], B[k, :])
optimizer.zero_grad()
loss = (data - X).pow(2).mean()
loss.backward()
optimizer.step()
print('train {} : {:.6f}'.format(batch_idx, loss.data[0]))
# test
if batch_idx % 1 == 0:
Z = Variable(torch.rand(latent_dim, test_batch_size, latent_dim), requires_grad=True)
test_optimizer = optim.Adam([Z], lr=learning_rate)
for _ in range(test_iterations):
X = Variable(torch.zeros(test_data.size()), requires_grad=False)
for k in range(latent_dim):
temp = torch.mm(A, Z[:, :, k])
for h in range(height):
X[:, :, h] = X[:, :, h] + torch.ger(temp[h, :], B[k, :])
test_optimizer.zero_grad()
diff = test_data - X
#diff[:, 10:18, 10:18] = 0.0
loss = diff.pow(2).mean()
loss.backward()
test_optimizer.step()
print('test {} : {:.6f}'.format(batch_idx, loss.data[0]))
X = X.data.numpy()
X[X < 0.0] = 0.0
X[X > 1.0] = 1.0
f, axarr = plt.subplots(2 * math.ceil(test_batch_size / 10), 10)
for i in range(test_batch_size):
ir, ic = i // 10, i % 10
axarr[ir*2, ic].imshow(test_data[i].data.numpy(), norm=colors.NoNorm())
axarr[ir*2, ic].axis('off')
axarr[ir*2+1, ic].imshow(X[i], norm=colors.NoNorm())
axarr[ir*2+1, ic].axis('off')
plt.savefig('e{}.png'.format(str(batch_idx).zfill(4)), bbox_inches='tight')
plt.close()
訓練データについては、元画像と再構成画像の差の2乗を損失関数として、
$\boldsymbol{A}$、$\boldsymbol{B}$、そして、$\boldsymbol{Z}_n$ (これは各画像ごとに別々)、
以上3種類の変数を、普通のNNのように、ミニバッチ学習で更新しています。
その一方、テストデータについては、$\boldsymbol{A}$、$\boldsymbol{B}$は固定し、各画像について$\boldsymbol{Z}_n$だけを最適化しています。
つまり、訓練データで学習した結果は固定し、各テスト画像の潜在的な表現を$d \times d$行列として求めています。
結果
上のコードでは、訓練データ100個ごとに、テスト用画像100個の再構成を試みています。
まず、ひとつのミニバッチ(サイズは100)で学習するごとに、
同じテスト画像100枚の再構成がどのように変化するかをanimated gifで示します。
20個目のミニバッチの学習が終わるまでの変化です。
100枚の各テスト画像のすぐ下に、再構成した画像を示しています。
最初の数ステップはうまく再構成できていませんが、その後は安定します。
それなりにうまく再構成できているようです。
これでは面白くないので、上のコードの「#diff[:, 10:18, 10:18] = 0.0」という行を、
「#」を消してコメントではなくして、動かしてみます。
こうすると、テスト用画面の中央部8×8ピクセルの損失がゼロになり、
この部分からは誤差が伝播されなくなります。
つまり、テスト用画像の中央の8×8ピクセル以外の場所で推定した$\boldsymbol{Z}_n$を使って、
この穴を補完するような再構成を実行できます。結果をanimated gifで示します。
やはり最初の数ステップは再構成がうまくいっていません。これは先ほどと同じです。
そして問題の穴を開けた中央部ですが・・・学習を進めていっても、ふらふらしたままのようです。
せっかくなので、この穴を開けたケースについて、ミニバッチを100個見た状態までの変化を
animated gifにしました。最初の20ステップは、さきほどのanimated gifと全く同じです。
やはり、中央の8×8ピクセルの部分は、ずっとふらふらしたままです。
まとめ
テスト用画像の中央に穴を開けた再構成を見る限り、
こういった問題を解くときに、ミニバッチ学習はあまり良くないのかもしれません。
どう見ても、そのつど使ったミニバッチに、直後の再構成が影響されているように見えます。
どうすればいいか・・・これから考えてみます。
参考文献
Khrulkov et al. Expressive power of recurrent neural networks. arXiv:1711.00811の第2節と第3節