LoginSignup
6
8

More than 3 years have passed since last update.

TSNE Grid で MNIST を調べてみた

Last updated at Posted at 2019-06-10

はじめに

最近いろんなモデルを MNIST で試してみました。その中でちょっと気にしたものが、MNIST のことが実はあんまに把握してないですね。例えばある数字はどれぐらいの書き方があるのか、書き方の間ではどんな傾向があるのか。これらの印象は実は薄いですから、やっはり感性的に把握したいです。

本記事ではこの素晴らしい記事を参考にしながら、T-SNE gridを利用して MNIST をもっと感性的に調べて行きます。

準備

まずは利用しているモージュルを導入します。

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
from matplotlib import offsetbox, patheffects
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.spatial.distance import cdist
from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.utils import shuffle
from sklearn.preprocessing import minmax_scale
from lapjv import lapjv
from PIL import Image

rc = {
  'font.family': ['sans-serif'],
  'font.sans-serif': ['Open Sans', 'Arial Unicode MS'],
  'font.size': 12,
  'figure.figsize': (8, 6),
  'grid.linewidth': 0.5,
  'legend.fontsize': 10,
  'legend.frameon': True,
  'legend.framealpha': 0.6,
  'legend.handletextpad': 0.2,
  'lines.linewidth': 1,
  'axes.facecolor': '#fafafa',
  'axes.labelsize': 10,
  'axes.titlesize': 14,
  'axes.linewidth': 0.5,
  'xtick.labelsize': 10,
  'xtick.minor.visible': True,
  'ytick.labelsize': 10,
  'figure.titlesize': 14
}
sns.set('notebook', 'whitegrid', rc=rc)

def colorize(d, color, alpha=1.0):
  rgb = np.dstack((d,d,d)) * color
  return np.dstack((rgb, d * alpha)).astype(np.uint8)

colors = sns.color_palette('tab10')

MNIST データ

そして、MNIST も導入して、ランダムのサンプルを検証してみます。

data, target = fetch_openml('mnist_784', version=1, return_X_y=True)
data = data.astype('float32')
target = target.astype('uint8')

fig, ax = plt.subplots(figsize=(16,8))

size = 28
dim = (15,30)
n = dim[0] * dim[1]
rnd = np.random.permutation(data.shape[0])

img = np.zeros((size * dim[0], size * dim[1], 4),dtype='uint8')

for d, t, i in zip(data[rnd[:n]], target[rnd[:n]], range(n)):
  ix, iy = divmod(i, dim[1])
  img[ix*size:(ix+1)*size,iy*size:(iy+1)*size,:] = colorize(d.reshape(size,size), colors[t], 0.9)

ax.imshow(img)
ax.set_axis_off()
plt.show()

mnist_tsne_01.png

悪くないですね。次は各クラスのキャラクターの画像を平均化して、ホットスポットを検証してみます。

fig, ax = plt.subplots(figsize=(10,3))

img = np.zeros((28,28*10))

for i in range(0,10):
  img[:,i*28:(i+1)*28] = np.mean(data[np.argwhere(target==i)],axis=0).reshape((28,28))

ax.imshow(img, cmap='RdBu_r')
ax.set_axis_off()
plt.show()

mnist_tsne_02.png

T-SNE で2次元にマッピング

そして次は、T-SNE を利用して、スタイル特徴を検出し、平面にマッピングする。

size = 50
n = size * size
x_data, y_data = shuffle(data, target, n_samples=n)
x_pca = PCA(n_components=50).fit_transform(x_data/255)
embeddings = TSNE(perplexity=50, random_state=24680, verbose=2).fit_transform(x_pca)
embeddings = minmax_scale(embeddings)

fig, ax = plt.subplots(figsize=(12,12))

for i in range(10):
  ax.scatter(embeddings[y_data==i,0],embeddings[y_data==i,1],cmap='tab10',marker='o',alpha=0.7,label=i)
  x_,y_ = np.median(embeddings[y_data==i,:],axis=0)
  txt = ax.text(x_,y_,str(i),fontsize=30)
  txt.set_path_effects([
    patheffects.Stroke(linewidth=7, foreground="w"),
    patheffects.Normal()
  ])

ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])
ax.legend(loc='lower right')

plt.show()

mnist_tsne_03.png

大体のクラスターがわけられますね。じゃそれぞれのデータポイントの画像も表示してみよう。

fig, ax = plt.subplots(figsize=(16,16))

source = zip(embeddings, x_data.reshape((-1,28,28)), y_data)

for pos, d, i in source:
  img = colorize(d, colors[i], 0.5)
  ab = offsetbox.AnnotationBbox(offsetbox.OffsetImage(img),0.03 + pos * 0.94,frameon=False)
  ax.add_artist(ab)

ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])

handles = [mlines.Line2D([0], [0], label=i,
                         linewidth=0, marker='o', alpha=0.5,
                         markersize=7,markerfacecolor=colors[i],markeredgewidth=0)
           for i in range(10)]
ax.legend(loc='lower right',handles=handles)
plt.show()

mnist_tsne_04.png

なんか風が吹いてるように見えて面白いけど、数字自体は重ねて見にくいですね。
じゃ次はこの記事が示すように、src-d/lapjvを利用して、数字はグリットに排列してみます。

まずは、各データポイントの移動方向検証します。

grid = np.dstack(np.meshgrid(np.linspace(0, 1, size), np.linspace(0, 1, size))).reshape(-1, 2)
cost = cdist(grid, embeddings, 'sqeuclidean').astype('float32')
cost *= 1e7 / cost.max()

_, col_asses, _ = lapjv(cost)
grid_jv = grid[col_asses]

fig, ax = plt.subplots(figsize=(16, 16))

for start, end, i in zip(embeddings, grid_jv, y_data):
  plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
            color=colors[i], linewidth=0.05,
            head_length=0.005, head_width=0.005, alpha=0.7)

ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])
handles = [mlines.Line2D([0], [0], label=i, linewidth=2, alpha=0.5,markersize=7,color=colors[i])
           for i in range(10)]
ax.legend(loc='lower right',handles=handles,framealpha=0.9,handletextpad=0.5,handlelength=1.0)

plt.show()

mnist_tsne_05.png

なんか天気予報図のように見えますね。じゃ次は数字画像も表示しましょう。

fig, ax = plt.subplots(figsize=(16,16))

for pos, d, i in zip(grid_jv, x_data.reshape((-1,28,28)), y_data):
  img = Image.fromarray(colorize(d, colors[i], 0.8), 'RGBA').resize((20, 20), Image.ANTIALIAS)
  ab = offsetbox.AnnotationBbox(offsetbox.OffsetImage(img),0.01+pos*0.98,frameon=False)
  ax.add_artist(ab)

ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])
ax.set_axis_off()

plt.show()

mnist_tsne_06.png

いい感じですね、これで MNIST データセットに何にかあるってことはちょっと把握しました。

単独の数字も可視化してみる

もしある単独の数字だけでしたらどうですか。例えば「6」を見てみましょう。

num = 6
size = 50
n = size ** 2

x_data_n, y_data_n = shuffle(data[target==num], target[target==num], n_samples=n)
x_pca_n = PCA(n_components=50).fit_transform(x_data_n/255)
embeddings_n = TSNE(perplexity=50, random_state=24680, verbose=0).fit_transform(x_pca_n)
embeddings_n = minmax_scale(embeddings_n)
fig, ax = plt.subplots(figsize=(16,16))

for pos, d in zip(embeddings_n, x_data_n.reshape((-1,28,28))):
  img = colorize(d, colors[num], 0.7)
  ab = offsetbox.AnnotationBbox(offsetbox.OffsetImage(img),0.03 + pos * 0.94,frameon=False)
  ax.add_artist(ab)
  ax.xaxis.set_ticks([])
  ax.yaxis.set_ticks([])

plt.show()

mnist_tsne_07.png

grid_n = np.dstack(np.meshgrid(np.linspace(0, 1, size), np.linspace(0, 1, size))).reshape(-1, 2)
cost_n = cdist(grid_n, embeddings_n, 'sqeuclidean').astype('float32')
cost_n *= 1e7 / cost_n.max()
_, col_asses_n, _ = lapjv(cost_n)
grid_jv_n = grid_n[col_asses_n]

fig, ax = plt.subplots(figsize=(16,16))

for pos, d in zip(grid_jv_n, x_data_n.reshape((-1,28,28))):
  img = colorize(d, colors[num], 0.85)
  img = Image.fromarray(img,'RGBA').resize((20, 20), Image.ANTIALIAS)
  ab = offsetbox.AnnotationBbox(offsetbox.OffsetImage(img),0.01 + pos * 0.98,frameon=False)
  ax.add_artist(ab)

ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])
ax.set_axis_off()

plt.show()

mnist_tsne_08.png

全体の数字を単独で

最後は全体の数字をそれぞれのグリットに配置してみましょう。

mnist_tsne_09.png

まとめ

これでようやく MNIST のちょっと感性的なイメージを把握できまして。データ検証には、T-SNE Gridがすごく有力なツールですね、これからももっと活用したいと思います。

本記事に使ったフールソースコードが Colab に載せます

おまけ

ちなみに、lapjv で計算した grid と間違い embedding を組み合わせねば、面白い結果が出ます

mnist_tsne_10.png

どこか generative art の感じがする。

参考文献

6
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
8