LoginSignup
1
1

More than 5 years have passed since last update.

MNIST画像をちょっと分解してみる

Posted at

はじめに

先週、某BISに参加しました。チュートリアルが興味深かったので、
興奮が冷めやらぬうちに、MNIST画像を分解してみたいと思います。

考え方

MNISTの画像は、高さ28ピクセル、幅28ピクセルです。そこで・・・

高さ方向の28ピクセルのうちの$i$番目に(つまり、画像を行列と見たときの$i$行目に)、
$d$次元のベクトル$\boldsymbol{a}_i$, $i = 1, \ldots, 28$を割り当てます。

幅方向の28ピクセルのうちの$j$番目に(つまり、画像を行列と見たときの$j$列目に)、
やはり$d$次元のベクトル$\boldsymbol{b}_j$, $j = 1, \ldots, 28$を割り当てます。

そして、MNISTの画像のうちの$n$番目に、$d \times d$の行列$\boldsymbol{Z}_n$を割り当てます。

$d$は、画像の縦横のサイズである28より小さくします

そして、MNIST画像のうちの$n$番目の画像を現す行列を$\boldsymbol{X}_n$とし、
この行列を、以下の式によって分解して近似することを考えます。

\boldsymbol{X}_{n,i,j} \approx \boldsymbol{a}_i^T \boldsymbol{Z}_n \boldsymbol{b}_j
= \sum_{\beta=1}^d \sum_{\alpha=1}^d \boldsymbol{a}_{i,\alpha} \boldsymbol{Z}_{n,\alpha, \beta} \boldsymbol{b}_{j,\beta}

$\boldsymbol{X}_{n,i,j}$は、行列$\boldsymbol{X}_n$の第$(i,j)$要素を表します。
具体的には、MNISTの$n$番目の画像の第$(i,j)$ピクセルの輝度値です。

重要なのは、$\boldsymbol{a}_i$と$\boldsymbol{b}_j$については、すべての画像で同じものを使うということです。

つまり、これら$\boldsymbol{a}_i$と$\boldsymbol{b}_j$を使って、MNISTの各画像を$d \times d$という小さな行列$\boldsymbol{Z}_n$で表わし直そう、というわけです。
この行列$\boldsymbol{Z}_n$が、もとの画像の潜在的な表現のようなものになります。

なお、$\boldsymbol{a}_i$を列ベクトルとして持つ行列を$\boldsymbol{A}$、$\boldsymbol{b}_j$を列ベクトルとして持つ行列を$\boldsymbol{B}$とすると、
上の式は、すべての$i,j$についてまとめてしまえば、次のように書けます。

\boldsymbol{X}_{n} \approx \boldsymbol{A}^T \boldsymbol{Z}_n \boldsymbol{B}

このように、潜在表現$\boldsymbol{Z}_n$を使って、元の$\boldsymbol{X}_{n}$を分解しています。
繰り返しになりますが、$\boldsymbol{A}$と$\boldsymbol{B}$は、全画像で共通です。

実装

PyTorchで実装しました。
訓練データを使って$\boldsymbol{a}_i$, $i = 1, \ldots, 28$と$\boldsymbol{b}_j$, $j = 1, \ldots, 28$を学習し、
その結果を使って、テストデータから選んだ100の画像を再構成してみます。

main.py
import math
import numpy
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

torch.manual_seed(123)

batch_size = 100
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=batch_size, shuffle=True)

test_batch_size = 100
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, download=True,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=test_batch_size, shuffle=True)
test_data, _ = next(iter(test_loader))
test_data = Variable(test_data.squeeze(), requires_grad=False)

height, width = test_data[0].size()
latent_dim = 10

A = Variable(torch.rand(height, latent_dim), requires_grad=True)
B = Variable(torch.rand(latent_dim, width), requires_grad=True)


train_iterations = 50
test_iterations = 500
learning_rate = 0.01

for batch_idx, (data, _) in enumerate(train_loader):
    batch_idx += 1
    data = Variable(data.squeeze(), requires_grad=False)
    Z = Variable(torch.rand(latent_dim, batch_size, latent_dim), requires_grad=True)
    optimizer = optim.Adam([A, B, Z], lr=learning_rate)

    for _ in range(train_iterations):
        X = Variable(torch.zeros(data.size()), requires_grad=False)
        for k in range(latent_dim):
            temp = torch.mm(A, Z[:, :, k])
            for h in range(height):
                X[:, :, h] = X[:, :, h] + torch.ger(temp[h, :], B[k, :])

        optimizer.zero_grad()
        loss = (data - X).pow(2).mean()
        loss.backward()
        optimizer.step()
    print('train {} : {:.6f}'.format(batch_idx, loss.data[0]))

    # test
    if batch_idx % 1 == 0:
        Z = Variable(torch.rand(latent_dim, test_batch_size, latent_dim), requires_grad=True)
        test_optimizer = optim.Adam([Z], lr=learning_rate)

        for _ in range(test_iterations):
            X = Variable(torch.zeros(test_data.size()), requires_grad=False)
            for k in range(latent_dim):
                temp = torch.mm(A, Z[:, :, k])
                for h in range(height):
                    X[:, :, h] = X[:, :, h] + torch.ger(temp[h, :], B[k, :])

            test_optimizer.zero_grad()
            diff = test_data - X
            #diff[:, 10:18, 10:18] = 0.0
            loss = diff.pow(2).mean()
            loss.backward()
            test_optimizer.step()
        print('test {} : {:.6f}'.format(batch_idx, loss.data[0]))

        X = X.data.numpy()
        X[X < 0.0] = 0.0
        X[X > 1.0] = 1.0
        f, axarr = plt.subplots(2 * math.ceil(test_batch_size / 10), 10)
        for i in range(test_batch_size):
            ir, ic = i // 10, i % 10
            axarr[ir*2, ic].imshow(test_data[i].data.numpy(), norm=colors.NoNorm())
            axarr[ir*2, ic].axis('off')
            axarr[ir*2+1, ic].imshow(X[i], norm=colors.NoNorm())
            axarr[ir*2+1, ic].axis('off')
        plt.savefig('e{}.png'.format(str(batch_idx).zfill(4)), bbox_inches='tight')
        plt.close()

訓練データについては、元画像と再構成画像の差の2乗を損失関数として、
$\boldsymbol{A}$、$\boldsymbol{B}$、そして、$\boldsymbol{Z}_n$ (これは各画像ごとに別々)、
以上3種類の変数を、普通のNNのように、ミニバッチ学習で更新しています。

その一方、テストデータについては、$\boldsymbol{A}$、$\boldsymbol{B}$は固定し、各画像について$\boldsymbol{Z}_n$だけを最適化しています。
つまり、訓練データで学習した結果は固定し、各テスト画像の潜在的な表現を$d \times d$行列として求めています。

結果

上のコードでは、訓練データ100個ごとに、テスト用画像100個の再構成を試みています。

まず、ひとつのミニバッチ(サイズは100)で学習するごとに、
同じテスト画像100枚の再構成がどのように変化するかをanimated gifで示します。
20個目のミニバッチの学習が終わるまでの変化です。
100枚の各テスト画像のすぐ下に、再構成した画像を示しています。
tt.gif
最初の数ステップはうまく再構成できていませんが、その後は安定します。
それなりにうまく再構成できているようです。

これでは面白くないので、上のコードの「#diff[:, 10:18, 10:18] = 0.0」という行を、
「#」を消してコメントではなくして、動かしてみます。

こうすると、テスト用画面の中央部8×8ピクセルの損失がゼロになり、
この部分からは誤差が伝播されなくなります。

つまり、テスト用画像の中央の8×8ピクセル以外の場所で推定した$\boldsymbol{Z}_n$を使って、
この穴を補完するような再構成を実行できます。結果をanimated gifで示します。
ett.gif
やはり最初の数ステップは再構成がうまくいっていません。これは先ほどと同じです。
そして問題の穴を開けた中央部ですが・・・学習を進めていっても、ふらふらしたままのようです。

せっかくなので、この穴を開けたケースについて、ミニバッチを100個見た状態までの変化を
animated gifにしました。最初の20ステップは、さきほどのanimated gifと全く同じです。
ett_long.gif
やはり、中央の8×8ピクセルの部分は、ずっとふらふらしたままです。

まとめ

テスト用画像の中央に穴を開けた再構成を見る限り、
こういった問題を解くときに、ミニバッチ学習はあまり良くないのかもしれません。
どう見ても、そのつど使ったミニバッチに、直後の再構成が影響されているように見えます。
どうすればいいか・・・これから考えてみます。

参考文献

Khrulkov et al. Expressive power of recurrent neural networks. arXiv:1711.00811の第2節と第3節

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1