機械学習
行列分解
MNIST

MNIST画像をちょっと分解してみる

はじめに

先週、某BISに参加しました。チュートリアルが興味深かったので、
興奮が冷めやらぬうちに、MNIST画像を分解してみたいと思います。

考え方

MNISTの画像は、高さ28ピクセル、幅28ピクセルです。そこで・・・

高さ方向の28ピクセルのうちの$i$番目に(つまり、画像を行列と見たときの$i$行目に)、
$d$次元のベクトル$\boldsymbol{a}_i$, $i = 1, \ldots, 28$を割り当てます。

幅方向の28ピクセルのうちの$j$番目に(つまり、画像を行列と見たときの$j$列目に)、
やはり$d$次元のベクトル$\boldsymbol{b}_j$, $j = 1, \ldots, 28$を割り当てます。

そして、MNISTの画像のうちの$n$番目に、$d \times d$の行列$\boldsymbol{Z}_n$を割り当てます。

$d$は、画像の縦横のサイズである28より小さくします

そして、MNIST画像のうちの$n$番目の画像を現す行列を$\boldsymbol{X}_n$とし、
この行列を、以下の式によって分解して近似することを考えます。

\boldsymbol{X}_{n,i,j} \approx \boldsymbol{a}_i^T \boldsymbol{Z}_n \boldsymbol{b}_j
= \sum_{\beta=1}^d \sum_{\alpha=1}^d \boldsymbol{a}_{i,\alpha} \boldsymbol{Z}_{n,\alpha, \beta} \boldsymbol{b}_{j,\beta}

$\boldsymbol{X}_{n,i,j}$は、行列$\boldsymbol{X}_n$の第$(i,j)$要素を表します。
具体的には、MNISTの$n$番目の画像の第$(i,j)$ピクセルの輝度値です。

重要なのは、$\boldsymbol{a}_i$と$\boldsymbol{b}_j$については、すべての画像で同じものを使うということです。

つまり、これら$\boldsymbol{a}_i$と$\boldsymbol{b}_j$を使って、MNISTの各画像を$d \times d$という小さな行列$\boldsymbol{Z}_n$で表わし直そう、というわけです。
この行列$\boldsymbol{Z}_n$が、もとの画像の潜在的な表現のようなものになります。

なお、$\boldsymbol{a}_i$を列ベクトルとして持つ行列を$\boldsymbol{A}$、$\boldsymbol{b}_j$を列ベクトルとして持つ行列を$\boldsymbol{B}$とすると、
上の式は、すべての$i,j$についてまとめてしまえば、次のように書けます。

\boldsymbol{X}_{n} \approx \boldsymbol{A}^T \boldsymbol{Z}_n \boldsymbol{B}

このように、潜在表現$\boldsymbol{Z}_n$を使って、元の$\boldsymbol{X}_{n}$を分解しています。
繰り返しになりますが、$\boldsymbol{A}$と$\boldsymbol{B}$は、全画像で共通です。

実装

PyTorchで実装しました。
訓練データを使って$\boldsymbol{a}_i$, $i = 1, \ldots, 28$と$\boldsymbol{b}_j$, $j = 1, \ldots, 28$を学習し、
その結果を使って、テストデータから選んだ100の画像を再構成してみます。

main.py
import math
import numpy
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

torch.manual_seed(123)

batch_size = 100
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=batch_size, shuffle=True)

test_batch_size = 100
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, download=True,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=test_batch_size, shuffle=True)
test_data, _ = next(iter(test_loader))
test_data = Variable(test_data.squeeze(), requires_grad=False)

height, width = test_data[0].size()
latent_dim = 10

A = Variable(torch.rand(height, latent_dim), requires_grad=True)
B = Variable(torch.rand(latent_dim, width), requires_grad=True)


train_iterations = 50
test_iterations = 500
learning_rate = 0.01

for batch_idx, (data, _) in enumerate(train_loader):
    batch_idx += 1
    data = Variable(data.squeeze(), requires_grad=False)
    Z = Variable(torch.rand(latent_dim, batch_size, latent_dim), requires_grad=True)
    optimizer = optim.Adam([A, B, Z], lr=learning_rate)

    for _ in range(train_iterations):
        X = Variable(torch.zeros(data.size()), requires_grad=False)
        for k in range(latent_dim):
            temp = torch.mm(A, Z[:, :, k])
            for h in range(height):
                X[:, :, h] = X[:, :, h] + torch.ger(temp[h, :], B[k, :])

        optimizer.zero_grad()
        loss = (data - X).pow(2).mean()
        loss.backward()
        optimizer.step()
    print('train {} : {:.6f}'.format(batch_idx, loss.data[0]))

    # test
    if batch_idx % 1 == 0:
        Z = Variable(torch.rand(latent_dim, test_batch_size, latent_dim), requires_grad=True)
        test_optimizer = optim.Adam([Z], lr=learning_rate)

        for _ in range(test_iterations):
            X = Variable(torch.zeros(test_data.size()), requires_grad=False)
            for k in range(latent_dim):
                temp = torch.mm(A, Z[:, :, k])
                for h in range(height):
                    X[:, :, h] = X[:, :, h] + torch.ger(temp[h, :], B[k, :])

            test_optimizer.zero_grad()
            diff = test_data - X
            #diff[:, 10:18, 10:18] = 0.0
            loss = diff.pow(2).mean()
            loss.backward()
            test_optimizer.step()
        print('test {} : {:.6f}'.format(batch_idx, loss.data[0]))

        X = X.data.numpy()
        X[X < 0.0] = 0.0
        X[X > 1.0] = 1.0
        f, axarr = plt.subplots(2 * math.ceil(test_batch_size / 10), 10)
        for i in range(test_batch_size):
            ir, ic = i // 10, i % 10
            axarr[ir*2, ic].imshow(test_data[i].data.numpy(), norm=colors.NoNorm())
            axarr[ir*2, ic].axis('off')
            axarr[ir*2+1, ic].imshow(X[i], norm=colors.NoNorm())
            axarr[ir*2+1, ic].axis('off')
        plt.savefig('e{}.png'.format(str(batch_idx).zfill(4)), bbox_inches='tight')
        plt.close()

訓練データについては、元画像と再構成画像の差の2乗を損失関数として、
$\boldsymbol{A}$、$\boldsymbol{B}$、そして、$\boldsymbol{Z}_n$ (これは各画像ごとに別々)、
以上3種類の変数を、普通のNNのように、ミニバッチ学習で更新しています。

その一方、テストデータについては、$\boldsymbol{A}$、$\boldsymbol{B}$は固定し、各画像について$\boldsymbol{Z}_n$だけを最適化しています。
つまり、訓練データで学習した結果は固定し、各テスト画像の潜在的な表現を$d \times d$行列として求めています。

結果

上のコードでは、訓練データ100個ごとに、テスト用画像100個の再構成を試みています。

まず、ひとつのミニバッチ(サイズは100)で学習するごとに、
同じテスト画像100枚の再構成がどのように変化するかをanimated gifで示します。
20個目のミニバッチの学習が終わるまでの変化です。
100枚の各テスト画像のすぐ下に、再構成した画像を示しています。
tt.gif
最初の数ステップはうまく再構成できていませんが、その後は安定します。
それなりにうまく再構成できているようです。

これでは面白くないので、上のコードの「#diff[:, 10:18, 10:18] = 0.0」という行を、
「#」を消してコメントではなくして、動かしてみます。

こうすると、テスト用画面の中央部8×8ピクセルの損失がゼロになり、
この部分からは誤差が伝播されなくなります。

つまり、テスト用画像の中央の8×8ピクセル以外の場所で推定した$\boldsymbol{Z}_n$を使って、
この穴を補完するような再構成を実行できます。結果をanimated gifで示します。
ett.gif
やはり最初の数ステップは再構成がうまくいっていません。これは先ほどと同じです。
そして問題の穴を開けた中央部ですが・・・学習を進めていっても、ふらふらしたままのようです。

せっかくなので、この穴を開けたケースについて、ミニバッチを100個見た状態までの変化を
animated gifにしました。最初の20ステップは、さきほどのanimated gifと全く同じです。
ett_long.gif
やはり、中央の8×8ピクセルの部分は、ずっとふらふらしたままです。

まとめ

テスト用画像の中央に穴を開けた再構成を見る限り、
こういった問題を解くときに、ミニバッチ学習はあまり良くないのかもしれません。
どう見ても、そのつど使ったミニバッチに、直後の再構成が影響されているように見えます。
どうすればいいか・・・これから考えてみます。

参考文献

Khrulkov et al. Expressive power of recurrent neural networks. arXiv:1711.00811の第2節と第3節