Help us understand the problem. What is going on with this article?

TSNE Grid で MNIST を調べてみた

はじめに

最近いろんなモデルを MNIST で試してみました。その中でちょっと気にしたものが、MNIST のことが実はあんまに把握してないですね。例えばある数字はどれぐらいの書き方があるのか、書き方の間ではどんな傾向があるのか。これらの印象は実は薄いですから、やっはり感性的に把握したいです。

本記事ではこの素晴らしい記事を参考にしながら、T-SNE gridを利用して MNIST をもっと感性的に調べて行きます。

準備

まずは利用しているモージュルを導入します。

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
from matplotlib import offsetbox, patheffects
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.spatial.distance import cdist
from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.utils import shuffle
from sklearn.preprocessing import minmax_scale
from lapjv import lapjv
from PIL import Image

rc = {
  'font.family': ['sans-serif'],
  'font.sans-serif': ['Open Sans', 'Arial Unicode MS'],
  'font.size': 12,
  'figure.figsize': (8, 6),
  'grid.linewidth': 0.5,
  'legend.fontsize': 10,
  'legend.frameon': True,
  'legend.framealpha': 0.6,
  'legend.handletextpad': 0.2,
  'lines.linewidth': 1,
  'axes.facecolor': '#fafafa',
  'axes.labelsize': 10,
  'axes.titlesize': 14,
  'axes.linewidth': 0.5,
  'xtick.labelsize': 10,
  'xtick.minor.visible': True,
  'ytick.labelsize': 10,
  'figure.titlesize': 14
}
sns.set('notebook', 'whitegrid', rc=rc)

def colorize(d, color, alpha=1.0):
  rgb = np.dstack((d,d,d)) * color
  return np.dstack((rgb, d * alpha)).astype(np.uint8)

colors = sns.color_palette('tab10')

MNIST データ

そして、MNIST も導入して、ランダムのサンプルを検証してみます。

data, target = fetch_openml('mnist_784', version=1, return_X_y=True)
data = data.astype('float32')
target = target.astype('uint8')

fig, ax = plt.subplots(figsize=(16,8))

size = 28
dim = (15,30)
n = dim[0] * dim[1]
rnd = np.random.permutation(data.shape[0])

img = np.zeros((size * dim[0], size * dim[1], 4),dtype='uint8')

for d, t, i in zip(data[rnd[:n]], target[rnd[:n]], range(n)):
  ix, iy = divmod(i, dim[1])
  img[ix*size:(ix+1)*size,iy*size:(iy+1)*size,:] = colorize(d.reshape(size,size), colors[t], 0.9)

ax.imshow(img)
ax.set_axis_off()
plt.show()

mnist_tsne_01.png

悪くないですね。次は各クラスのキャラクターの画像を平均化して、ホットスポットを検証してみます。

fig, ax = plt.subplots(figsize=(10,3))

img = np.zeros((28,28*10))

for i in range(0,10):
  img[:,i*28:(i+1)*28] = np.mean(data[np.argwhere(target==i)],axis=0).reshape((28,28))

ax.imshow(img, cmap='RdBu_r')
ax.set_axis_off()
plt.show()

mnist_tsne_02.png

T-SNE で2次元にマッピング

そして次は、T-SNE を利用して、スタイル特徴を検出し、平面にマッピングする。

size = 50
n = size * size
x_data, y_data = shuffle(data, target, n_samples=n)
x_pca = PCA(n_components=50).fit_transform(x_data/255)
embeddings = TSNE(perplexity=50, random_state=24680, verbose=2).fit_transform(x_pca)
embeddings = minmax_scale(embeddings)

fig, ax = plt.subplots(figsize=(12,12))

for i in range(10):
  ax.scatter(embeddings[y_data==i,0],embeddings[y_data==i,1],cmap='tab10',marker='o',alpha=0.7,label=i)
  x_,y_ = np.median(embeddings[y_data==i,:],axis=0)
  txt = ax.text(x_,y_,str(i),fontsize=30)
  txt.set_path_effects([
    patheffects.Stroke(linewidth=7, foreground="w"),
    patheffects.Normal()
  ])

ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])
ax.legend(loc='lower right')

plt.show()

mnist_tsne_03.png

大体のクラスターがわけられますね。じゃそれぞれのデータポイントの画像も表示してみよう。

fig, ax = plt.subplots(figsize=(16,16))

source = zip(embeddings, x_data.reshape((-1,28,28)), y_data)

for pos, d, i in source:
  img = colorize(d, colors[i], 0.5)
  ab = offsetbox.AnnotationBbox(offsetbox.OffsetImage(img),0.03 + pos * 0.94,frameon=False)
  ax.add_artist(ab)

ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])

handles = [mlines.Line2D([0], [0], label=i,
                         linewidth=0, marker='o', alpha=0.5,
                         markersize=7,markerfacecolor=colors[i],markeredgewidth=0)
           for i in range(10)]
ax.legend(loc='lower right',handles=handles)
plt.show()

mnist_tsne_04.png

なんか風が吹いてるように見えて面白いけど、数字自体は重ねて見にくいですね。
じゃ次はこの記事が示すように、src-d/lapjvを利用して、数字はグリットに排列してみます。

まずは、各データポイントの移動方向検証します。

grid = np.dstack(np.meshgrid(np.linspace(0, 1, size), np.linspace(0, 1, size))).reshape(-1, 2)
cost = cdist(grid, embeddings, 'sqeuclidean').astype('float32')
cost *= 1e7 / cost.max()

_, col_asses, _ = lapjv(cost)
grid_jv = grid[col_asses]

fig, ax = plt.subplots(figsize=(16, 16))

for start, end, i in zip(embeddings, grid_jv, y_data):
  plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
            color=colors[i], linewidth=0.05,
            head_length=0.005, head_width=0.005, alpha=0.7)

ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])
handles = [mlines.Line2D([0], [0], label=i, linewidth=2, alpha=0.5,markersize=7,color=colors[i])
           for i in range(10)]
ax.legend(loc='lower right',handles=handles,framealpha=0.9,handletextpad=0.5,handlelength=1.0)

plt.show()

mnist_tsne_05.png

なんか天気予報図のように見えますね。じゃ次は数字画像も表示しましょう。

fig, ax = plt.subplots(figsize=(16,16))

for pos, d, i in zip(grid_jv, x_data.reshape((-1,28,28)), y_data):
  img = Image.fromarray(colorize(d, colors[i], 0.8), 'RGBA').resize((20, 20), Image.ANTIALIAS)
  ab = offsetbox.AnnotationBbox(offsetbox.OffsetImage(img),0.01+pos*0.98,frameon=False)
  ax.add_artist(ab)

ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])
ax.set_axis_off()

plt.show()

mnist_tsne_06.png

いい感じですね、これで MNIST データセットに何にかあるってことはちょっと把握しました。

単独の数字も可視化してみる

もしある単独の数字だけでしたらどうですか。例えば「6」を見てみましょう。

num = 6
size = 50
n = size ** 2

x_data_n, y_data_n = shuffle(data[target==num], target[target==num], n_samples=n)
x_pca_n = PCA(n_components=50).fit_transform(x_data_n/255)
embeddings_n = TSNE(perplexity=50, random_state=24680, verbose=0).fit_transform(x_pca_n)
embeddings_n = minmax_scale(embeddings_n)
fig, ax = plt.subplots(figsize=(16,16))

for pos, d in zip(embeddings_n, x_data_n.reshape((-1,28,28))):
  img = colorize(d, colors[num], 0.7)
  ab = offsetbox.AnnotationBbox(offsetbox.OffsetImage(img),0.03 + pos * 0.94,frameon=False)
  ax.add_artist(ab)
  ax.xaxis.set_ticks([])
  ax.yaxis.set_ticks([])

plt.show()

mnist_tsne_07.png

grid_n = np.dstack(np.meshgrid(np.linspace(0, 1, size), np.linspace(0, 1, size))).reshape(-1, 2)
cost_n = cdist(grid_n, embeddings_n, 'sqeuclidean').astype('float32')
cost_n *= 1e7 / cost_n.max()
_, col_asses_n, _ = lapjv(cost_n)
grid_jv_n = grid_n[col_asses_n]

fig, ax = plt.subplots(figsize=(16,16))

for pos, d in zip(grid_jv_n, x_data_n.reshape((-1,28,28))):
  img = colorize(d, colors[num], 0.85)
  img = Image.fromarray(img,'RGBA').resize((20, 20), Image.ANTIALIAS)
  ab = offsetbox.AnnotationBbox(offsetbox.OffsetImage(img),0.01 + pos * 0.98,frameon=False)
  ax.add_artist(ab)

ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])
ax.set_axis_off()

plt.show()

mnist_tsne_08.png

全体の数字を単独で

最後は全体の数字をそれぞれのグリットに配置してみましょう。

mnist_tsne_09.png

まとめ

これでようやく MNIST のちょっと感性的なイメージを把握できまして。データ検証には、T-SNE Gridがすごく有力なツールですね、これからももっと活用したいと思います。

本記事に使ったフールソースコードが Colab に載せます

https://colab.research.google.com/github/stwind/notebooks/blob/master/mnist-tsne.ipynb

おまけ

ちなみに、lapjv で計算した grid と間違い embedding を組み合わせねば、面白い結果が出ます

mnist_tsne_10.png

どこか generative art の感じがする。

参考文献

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした