はじめに
最近いろんなモデルを MNIST で試してみました。その中でちょっと気にしたものが、MNIST のことが実はあんまに把握してないですね。例えばある数字はどれぐらいの書き方があるのか、書き方の間ではどんな傾向があるのか。これらの印象は実は薄いですから、やっはり感性的に把握したいです。
本記事ではこの素晴らしい記事を参考にしながら、T-SNE gridを利用して MNIST をもっと感性的に調べて行きます。
準備
まずは利用しているモージュルを導入します。
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
from matplotlib import offsetbox, patheffects
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.spatial.distance import cdist
from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.utils import shuffle
from sklearn.preprocessing import minmax_scale
from lapjv import lapjv
from PIL import Image
rc = {
'font.family': ['sans-serif'],
'font.sans-serif': ['Open Sans', 'Arial Unicode MS'],
'font.size': 12,
'figure.figsize': (8, 6),
'grid.linewidth': 0.5,
'legend.fontsize': 10,
'legend.frameon': True,
'legend.framealpha': 0.6,
'legend.handletextpad': 0.2,
'lines.linewidth': 1,
'axes.facecolor': '#fafafa',
'axes.labelsize': 10,
'axes.titlesize': 14,
'axes.linewidth': 0.5,
'xtick.labelsize': 10,
'xtick.minor.visible': True,
'ytick.labelsize': 10,
'figure.titlesize': 14
}
sns.set('notebook', 'whitegrid', rc=rc)
def colorize(d, color, alpha=1.0):
rgb = np.dstack((d,d,d)) * color
return np.dstack((rgb, d * alpha)).astype(np.uint8)
colors = sns.color_palette('tab10')
MNIST データ
そして、MNIST も導入して、ランダムのサンプルを検証してみます。
data, target = fetch_openml('mnist_784', version=1, return_X_y=True)
data = data.astype('float32')
target = target.astype('uint8')
fig, ax = plt.subplots(figsize=(16,8))
size = 28
dim = (15,30)
n = dim[0] * dim[1]
rnd = np.random.permutation(data.shape[0])
img = np.zeros((size * dim[0], size * dim[1], 4),dtype='uint8')
for d, t, i in zip(data[rnd[:n]], target[rnd[:n]], range(n)):
ix, iy = divmod(i, dim[1])
img[ix*size:(ix+1)*size,iy*size:(iy+1)*size,:] = colorize(d.reshape(size,size), colors[t], 0.9)
ax.imshow(img)
ax.set_axis_off()
plt.show()
悪くないですね。次は各クラスのキャラクターの画像を平均化して、ホットスポットを検証してみます。
fig, ax = plt.subplots(figsize=(10,3))
img = np.zeros((28,28*10))
for i in range(0,10):
img[:,i*28:(i+1)*28] = np.mean(data[np.argwhere(target==i)],axis=0).reshape((28,28))
ax.imshow(img, cmap='RdBu_r')
ax.set_axis_off()
plt.show()
T-SNE で2次元にマッピング
そして次は、T-SNE を利用して、スタイル特徴を検出し、平面にマッピングする。
size = 50
n = size * size
x_data, y_data = shuffle(data, target, n_samples=n)
x_pca = PCA(n_components=50).fit_transform(x_data/255)
embeddings = TSNE(perplexity=50, random_state=24680, verbose=2).fit_transform(x_pca)
embeddings = minmax_scale(embeddings)
fig, ax = plt.subplots(figsize=(12,12))
for i in range(10):
ax.scatter(embeddings[y_data==i,0],embeddings[y_data==i,1],cmap='tab10',marker='o',alpha=0.7,label=i)
x_,y_ = np.median(embeddings[y_data==i,:],axis=0)
txt = ax.text(x_,y_,str(i),fontsize=30)
txt.set_path_effects([
patheffects.Stroke(linewidth=7, foreground="w"),
patheffects.Normal()
])
ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])
ax.legend(loc='lower right')
plt.show()
大体のクラスターがわけられますね。じゃそれぞれのデータポイントの画像も表示してみよう。
fig, ax = plt.subplots(figsize=(16,16))
source = zip(embeddings, x_data.reshape((-1,28,28)), y_data)
for pos, d, i in source:
img = colorize(d, colors[i], 0.5)
ab = offsetbox.AnnotationBbox(offsetbox.OffsetImage(img),0.03 + pos * 0.94,frameon=False)
ax.add_artist(ab)
ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])
handles = [mlines.Line2D([0], [0], label=i,
linewidth=0, marker='o', alpha=0.5,
markersize=7,markerfacecolor=colors[i],markeredgewidth=0)
for i in range(10)]
ax.legend(loc='lower right',handles=handles)
plt.show()
なんか風が吹いてるように見えて面白いけど、数字自体は重ねて見にくいですね。
じゃ次はこの記事が示すように、src-d/lapjvを利用して、数字はグリットに排列してみます。
まずは、各データポイントの移動方向検証します。
grid = np.dstack(np.meshgrid(np.linspace(0, 1, size), np.linspace(0, 1, size))).reshape(-1, 2)
cost = cdist(grid, embeddings, 'sqeuclidean').astype('float32')
cost *= 1e7 / cost.max()
_, col_asses, _ = lapjv(cost)
grid_jv = grid[col_asses]
fig, ax = plt.subplots(figsize=(16, 16))
for start, end, i in zip(embeddings, grid_jv, y_data):
plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
color=colors[i], linewidth=0.05,
head_length=0.005, head_width=0.005, alpha=0.7)
ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])
handles = [mlines.Line2D([0], [0], label=i, linewidth=2, alpha=0.5,markersize=7,color=colors[i])
for i in range(10)]
ax.legend(loc='lower right',handles=handles,framealpha=0.9,handletextpad=0.5,handlelength=1.0)
plt.show()
なんか天気予報図のように見えますね。じゃ次は数字画像も表示しましょう。
fig, ax = plt.subplots(figsize=(16,16))
for pos, d, i in zip(grid_jv, x_data.reshape((-1,28,28)), y_data):
img = Image.fromarray(colorize(d, colors[i], 0.8), 'RGBA').resize((20, 20), Image.ANTIALIAS)
ab = offsetbox.AnnotationBbox(offsetbox.OffsetImage(img),0.01+pos*0.98,frameon=False)
ax.add_artist(ab)
ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])
ax.set_axis_off()
plt.show()
いい感じですね、これで MNIST データセットに何にかあるってことはちょっと把握しました。
単独の数字も可視化してみる
もしある単独の数字だけでしたらどうですか。例えば「6」を見てみましょう。
num = 6
size = 50
n = size ** 2
x_data_n, y_data_n = shuffle(data[target==num], target[target==num], n_samples=n)
x_pca_n = PCA(n_components=50).fit_transform(x_data_n/255)
embeddings_n = TSNE(perplexity=50, random_state=24680, verbose=0).fit_transform(x_pca_n)
embeddings_n = minmax_scale(embeddings_n)
fig, ax = plt.subplots(figsize=(16,16))
for pos, d in zip(embeddings_n, x_data_n.reshape((-1,28,28))):
img = colorize(d, colors[num], 0.7)
ab = offsetbox.AnnotationBbox(offsetbox.OffsetImage(img),0.03 + pos * 0.94,frameon=False)
ax.add_artist(ab)
ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])
plt.show()
grid_n = np.dstack(np.meshgrid(np.linspace(0, 1, size), np.linspace(0, 1, size))).reshape(-1, 2)
cost_n = cdist(grid_n, embeddings_n, 'sqeuclidean').astype('float32')
cost_n *= 1e7 / cost_n.max()
_, col_asses_n, _ = lapjv(cost_n)
grid_jv_n = grid_n[col_asses_n]
fig, ax = plt.subplots(figsize=(16,16))
for pos, d in zip(grid_jv_n, x_data_n.reshape((-1,28,28))):
img = colorize(d, colors[num], 0.85)
img = Image.fromarray(img,'RGBA').resize((20, 20), Image.ANTIALIAS)
ab = offsetbox.AnnotationBbox(offsetbox.OffsetImage(img),0.01 + pos * 0.98,frameon=False)
ax.add_artist(ab)
ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])
ax.set_axis_off()
plt.show()
全体の数字を単独で
最後は全体の数字をそれぞれのグリットに配置してみましょう。
まとめ
これでようやく MNIST のちょっと感性的なイメージを把握できまして。データ検証には、T-SNE Gridがすごく有力なツールですね、これからももっと活用したいと思います。
本記事に使ったフールソースコードが Colab に載せます
おまけ
ちなみに、lapjv で計算した grid と間違い embedding を組み合わせねば、面白い結果が出ます
どこか generative art の感じがする。