LoginSignup
6
8

More than 3 years have passed since last update.

t-SNE & UMAP 試行・分けられた分布の中にある数字の形の傾向を調べる

Last updated at Posted at 2020-04-04

もし間違った認識をしている点がありましたら、バンバン編集リクエスト・コメントなどご指摘お願い致します。m(_ _)m

参照: https://qiita.com/cheerfularge/items/27a55ebde4a671880666

試行:

  • t-SNE、UMAPを試行。MNIST 8x8手書き数字画像1797枚を使用。
  • それぞれ分けられた分布の中にある数字の形の傾向を調べる。

結果:

  • 1の画像は3つの分布に分かれて、それぞれ、かぎがない棒だけの1の形状のタイプ、上にかぎのある1の形状のタイプ、左に傾けて崩して書く1の形状のタイプ、に分かれる様子が見られる。1の中でも傾向によりさらに分割され、良い分別の性能。
  • 3の画像は1つの分布に集まるが、その中でも幅が細く書かれた3は右側へ、幅いっぱいに広げて書かれた3は左側へ、傾向によりさらに分割され、良い分別の性能。その他の数字でも同様の傾向が見られる。

Lib

import numpy as np
import matplotlib.pyplot as plt

Load

# sec: load

from sklearn.datasets import load_digits
digits = load_digits()
print(digits.data.shape)
print(digits.target.shape)
print(digits.images.shape)
(1797, 64)
(1797,)
(1797, 8, 8)
# sec: MNISTの画像をグリッド状に描画

def draw_digits(i_list, n_grid=(10, 10), annosize=10, figsize=(12, 12)):
    # assume: i_listは画像配列の番号をリストに入れたもの、annosize=Noneで画像列番と正解ラベルを非表示
    images = digits.images
    labels =digits.target

    fig = None
    i_ax = 0
    for i_img in i_list:

        if fig is None or i_ax >= n_grid[0] * n_grid[1]:
            fig = plt.figure(figsize=figsize)
            plt.subplots_adjust(hspace=0.02, wspace=0)
            i_ax = 0
        i_ax += 1

        ax = fig.add_subplot(n_grid[0], n_grid[1], i_ax)
        if i_img is None:
            ax.axis('off')
            continue

        ax.imshow(images[i_img], cmap='gray', interpolation='none')
        if annosize is not None: # if: 画像列番と正解ラベルを追記
            ax.annotate("%d" % i_img, 
                xy=(0, 0.98), xycoords='axes fraction', ha='left', va='top', color='y', fontsize=annosize)
            ax.annotate("L:%d" % labels[i_img], 
                xy=(1, 0.98), xycoords='axes fraction', ha='right', va='top', color='c', fontsize=annosize)
        ax.axis('off')

    plt.show()

draw_digits(list(range(24)))

output_4_0.png

t-SNEを実行

from sklearn.manifold import TSNE
# sec: 実行

model = TSNE(n_components=2) # 2軸へ次元圧縮
res = model.fit_transform(digits.data)
print(res.shape)
(1797, 2)
# sec: 結果の描画

import matplotlib.cm as cm
plt.figure(figsize=(12, 12))
plt.scatter(res[:,0], res[:,1], s=3, c=digits.target, cmap=cm.tab10)
plt.colorbar()
plt.grid()
plt.show()

output_8_0.png

結果のグループを調べる

# sec: 結果を保存 毎回実行で変わる為

import pickle
with open("./results/2003 t-SNE/res-tsne-1.pickle", 'wb') as file:
    pickle.dump(res, file)
# sec: 結果を読み込み 前回の続きから

import pickle
with open("./results/2003 t-SNE/res-tsne-1.pickle", 'rb') as file:
    res = pickle.load(file)

7の中の3を調べる

# sec: 7の中の3を調べる →1118番目のデータ ⇒確かに似ているもの同士が集まっている、解像度が低すぎる

import matplotlib.cm as cm
plt.scatter(res[:, 0], res[:, 1], s=10, c=digits.target, cmap=cm.tab10)
plt.axis([-50, -30, -30, -10]); plt.grid(); plt.show()

output_13_0.png

i_list = np.where((-37.5 < res[:, 0]) & (res[:, 0] < -32.5) & (-25 < res[:, 1]) & (res[:, 1] < -20))[0]
i_list
array([  86,  283,  317,  374,  698,  707,  727,  740,  754,  758, 1118,
       1399, 1467], dtype=int64)
draw_digits(i_list)

output_15_0.png

7の集団の下側を調べる

# sec: 7の集団の下側を調べる ⇒似ているか。やや丸みのない7が集まっている模様。

plt.scatter(res[:, 0], res[:, 1], s=10, c=digits.target, cmap=cm.tab10)
plt.axis([-60, -40, -30, -0]); plt.grid(); plt.show()

output_17_0.png

i_list = np.where((-55 < res[:, 0]) & (res[:, 0] < -50) & (-20 < res[:, 1]) & (res[:, 1] < -10))[0]
i_list
array([ 191,  216,  240,  819,  828,  870, 1373, 1587, 1595, 1604, 1627,
       1635, 1657], dtype=int64)
draw_digits(i_list)

output_19_0.png

7の集団の横方向の画像を並べて変化を見る

# sec: 7の集団の横方向の画像を並べて変化を見る

i_list = np.where((-60 < res[:, 0]) & (res[:, 0] < -30) & (-20 < res[:, 1]) & (res[:, 1] < -18))[0]
i_list
array([  17,   61,  118,  137,  147,  174,  182,  191,  216,  240,  300,
        350,  364,  559,  803,  820,  837,  857,  884,  888,  949, 1304,
       1330, 1381, 1422, 1432, 1442, 1476, 1496, 1527, 1748, 1753],
      dtype=int64)
i_list = i_list[np.argsort(res[i_list, 0])] # x方向でソート
i_list
array([ 191,  216,  240, 1748,  364, 1753,  350,  147,  884,  137,  182,
        803,  888,  857,  837,  118,  559,   17,   61, 1381,  820, 1304,
       1330, 1496,  174, 1476,  300, 1422, 1527, 1442, 1432,  949],
      dtype=int64)
draw_digits(i_list)

output_23_0.png

分布の位置に画像を当てはめて格子状に表示

def draw_digits_at_tsne(res, x_min, x_max, y_min, y_max, 
    n_grid=(15, 15), annosize=8, figsize=(12, 12)):

    x_pitch = (x_max - x_min) / n_grid[1]
    y_pitch = (y_max - y_min) / n_grid[0]
    i_draw_list = []
    for i_y in range(n_grid[0]):
        y_i = y_max - y_pitch * i_y - y_pitch/2 # 格子中央点
        for i_x in range(n_grid[1]):
            x_i = x_min + x_pitch * i_x + x_pitch/2 # 格子中央点

            i_list = np.where((x_i-x_pitch/2 < res[:, 0]) & (res[:, 0] < x_i+x_pitch/2) & \
                              (y_i-y_pitch/2 < res[:, 1]) & (res[:, 1] < y_i+y_pitch/2))[0] # 格子内の点を集める
            res_i = res[i_list, :]
            if len(res_i) == 0: # if: 格子内に点なし
                i_draw_list.append(None)
                continue

            r2_i = ((res_i[:, 0] - x_i) / x_pitch)**2 + ((res_i[:, 1] - y_i) / y_pitch)**2 # 格子中央と点との距離
            i_min = i_list[np.argmin(r2_i)] # 格子中央に最も近い点
            i_draw_list.append(i_min)

    plt.figure(figsize=(6, 6))
    plt.scatter(res[:, 0], res[:, 1], s=10, c=digits.target, cmap=cm.tab10) # 指定範囲内の点の分布を描画
    plt.axis([x_min, x_max, y_min, y_max]); plt.grid(); plt.show()

    draw_digits(i_draw_list, n_grid=n_grid, annosize=annosize, figsize=figsize)

draw_digits_at_tsne(res, -55, -20, -30, -5)

output_25_0.png

output_25_1.png

各数字の集団の分布状況を描画

# sec: 各数字の集団の分布状況を描画 

draw_digits_at_tsne(res, -35, -10, -70, -45) # 0周り

output_27_0.png

output_27_1.png

draw_digits_at_tsne(res, 0, 30, -20, 10) # 1周り
draw_digits_at_tsne(res, 40, 60, 10, 30) # 1周り

output_28_0.png

output_28_1.png

output_28_2.png

output_28_3.png

draw_digits_at_tsne(res, 20, 50, 20, 50) # 2周り

output_29_0.png

output_29_1.png

draw_digits_at_tsne(res, -20, 15, 30, 60) # 3周り

output_30_0.png

output_30_1.png

draw_digits_at_tsne(res, -5, 20, -50, -25) # 4周り

output_31_0.png

output_31_1.png

draw_digits_at_tsne(res, -35, -10, -10, 25) # 5周り

output_32_0.png

output_32_1.png

draw_digits_at_tsne(res, 40, 60, -30, 0) # 6周り

output_33_0.png

output_33_1.png

draw_digits_at_tsne(res, -55, -20, -30, -5) # 7周り

output_34_0.png

output_34_1.png

draw_digits_at_tsne(res, -10, 30, 0, 30) # 8周り

output_35_0.png

output_35_1.png

draw_digits_at_tsne(res, -40, -10, 25, 45) # 9周り

output_36_0.png

output_36_1.png

やや広めに分布状況を描画

# sec: やや広めに分布状況を描画 

draw_digits_at_tsne(res, -40, 0, 0, 60) # 左上半面

output_38_0.png

output_38_1.png

draw_digits_at_tsne(res, 0, 60, 0, 60) # 右上半面

output_39_0.png

output_39_1.png

draw_digits_at_tsne(res, -60, 0, -70, 0) # 左下半面

output_40_0.png

output_40_1.png

draw_digits_at_tsne(res, 0, 60, -60, 0) # 右下半面

output_41_0.png

output_41_1.png

全体の分布状況を描画

draw_digits_at_tsne(res, -60, 60, -70, 60, n_grid=(50, 50), annosize=None)

output_43_0.png

output_43_1.png

UMAPを試す

t-SNEの試行の中で作った描画用の関数を使用。

!pip3 install umap-learn
Collecting umap-learn
  Downloading https://files.pythonhosted.org/packages/ad/92/36bac74962b424870026cb0b42cec3d5b6f4afa37d81818475d8762f9255/umap-learn-0.3.10.tar.gz (40kB)
Requirement already satisfied: numpy>=1.13 in ...\software\wpy64-3741\python-3.7.4.amd64\lib\site-packages (from umap-learn) (1.16.5+mkl)
Requirement already satisfied: scikit-learn>=0.16 in ...\software\wpy64-3741\python-3.7.4.amd64\lib\site-packages (from umap-learn) (0.21.3)
Requirement already satisfied: scipy>=0.19 in ...\software\wpy64-3741\python-3.7.4.amd64\lib\site-packages (from umap-learn) (1.3.1)
Requirement already satisfied: numba>=0.37 in ...\software\wpy64-3741\python-3.7.4.amd64\lib\site-packages (from umap-learn) (0.45.1)
Requirement already satisfied: joblib>=0.11 in ...\software\wpy64-3741\python-3.7.4.amd64\lib\site-packages (from scikit-learn>=0.16->umap-learn) (0.13.2)
Requirement already satisfied: llvmlite>=0.29.0 in ...\software\wpy64-3741\python-3.7.4.amd64\lib\site-packages (from numba>=0.37->umap-learn) (0.29.0)
Building wheels for collected packages: umap-learn
  Building wheel for umap-learn (setup.py): started
  Building wheel for umap-learn (setup.py): finished with status 'done'
  Created wheel for umap-learn: filename=umap_learn-0.3.10-cp37-none-any.whl size=38886 sha256=9d57f567f0ec496157a12d90e1212d52ca01cc14510a12335a06eeef5f0766c2
  Stored in directory: ...\AppData\Local\pip\Cache\wheels\d0\f8\d5\8e3af3ee957feb9b403a060ebe72f7561887fef9dea658326e
Successfully built umap-learn
Installing collected packages: umap-learn
Successfully installed umap-learn-0.3.10


WARNING: You are using pip version 19.2.3, however version 20.0.2 is available.
You should consider upgrading via the 'python -m pip install --upgrade pip' command.

UMAPを実行

import umap
from scipy.sparse.csgraph import connected_components
# 公式GitHubには書いてあるのですが、↑を書かないとエラーが出てしまいます。
# sec: 実行

res_umap = umap.UMAP().fit_transform(digits.data)
print(res_umap.shape)
(1797, 2)
# sec: 結果の描画

import matplotlib.cm as cm
plt.figure(figsize=(12, 12))
plt.scatter(res_umap[:,0], res_umap[:,1], s=3, c=digits.target, cmap=cm.tab10)
plt.colorbar()
plt.grid()
plt.show()

output_49_0.png

結果のグループを調べる

# sec: 結果を保存 毎回実行で変わる為

import pickle
with open("./results/2003 t-SNE/res-umap-1.pickle", 'wb') as file:
    pickle.dump(res_umap, file)
# sec: 結果を読み込み 前回の続きから

import pickle
with open("./results/2003 t-SNE/res-umap-1.pickle", 'rb') as file:
    res_umap = pickle.load(file)

各数字の集団の分布状況を描画

# sec: 各場所の数字の分布状況を描画

draw_digits_at_tsne(res_umap, 11.5, 14, 0, 2.5) # 0周り

output_54_0.png

output_54_1.png

draw_digits_at_tsne(res_umap, 0, 1.5, 3, 7) # 1周り
draw_digits_at_tsne(res_umap, -7, -6, -18, -17) # 1周り

output_55_0.png

output_55_1.png

output_55_2.png

output_55_3.png

draw_digits_at_tsne(res_umap, 3, 6.5, -3, -1) # 2周り

output_56_0.png

output_56_1.png

draw_digits_at_tsne(res_umap, -7.5, -5, 0, 4) # 3周り

output_57_0.png

output_57_1.png

draw_digits_at_tsne(res_umap, 3.5, 6.5, 8, 11.5) # 4周り

output_58_0.png

output_58_1.png

draw_digits_at_tsne(res_umap, -9.5, -7.5, -8, -4) # 5周り

output_59_0.png

output_59_1.png

draw_digits_at_tsne(res_umap, 0, 3, -14, -11.5) # 6周り

output_60_0.png

output_60_1.png

draw_digits_at_tsne(res_umap, -6.5, -3, 5, 10) # 7周り

output_61_0.png

output_61_1.png

draw_digits_at_tsne(res_umap, -3.5, 0, 2, 4.5) # 8周り

output_62_0.png

output_62_1.png

draw_digits_at_tsne(res_umap, -7, -4.5, -3, -1) # 9周り

output_63_0.png

output_63_1.png

全体の分布状況を描画

draw_digits_at_tsne(res_umap, -10, 15, -17, 13, n_grid=(50, 50), annosize=None)

output_65_0.png

output_65_1.png

6
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
8