21
24

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

k-Nearest Neighbors(k近傍法)で手書き数字の認識

Last updated at Posted at 2018-04-12

機械学習初心者です。東大松尾研のDeep Learning基礎講座をもとに勉強した際のノートです。第3回の内容に当たります。資料ではk-NNそのものの実装は載っていなかったので、自分で実装してみました。

##参考資料
東京大学松尾研究室
Deep Learning基礎講座演習コンテンツ 公開ページ
Qiita内記事
手書き数字をpythonでもてあそぶ その1
手書き数字をpythonでもてあそぶ その2(識別する)
k近傍法とk平均法の違いと詳細.
K近傍法(多クラス分類)

##画像を読んでみる
データの元ネタはこれ。サンプルをシャッフルして読み込みます。

from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from sklearn.datasets import fetch_mldata

import numpy as np
import matplotlib.pyplot as plt

mnist = fetch_mldata('MNIST original')
mnist_X, mnist_y = shuffle(mnist.data, mnist.target, random_state=114514)

# データの件数
print(len(mnist_X))
# 実際の数字
print(mnist_y[0])
# 画像データの形式
print(mnist_X[0].shape)
# 手書きの数字の画像データ(28x28の行列に変形)
print(mnist_X[0].reshape(28,28))
# 画像を表示してみる
plt.imshow(mnist_X[0].reshape((28, 28)), cmap='gray')
plt.show()

7万件のデータで、それぞれの画像データが28×28ピクセルの256段階のグレースケールとして格納されています(実際の形式は28×28=784次元のベクトル)。シードであるrandom_stateを114514としてシャッフルした場合、index=0のデータは「4」を表す画像のようです。行列に変形してみるとどういう画像だったかわかりやすいです。

70000
4.0
(784,)
[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  51 238 254 128   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0  11  83  82   0   0   0   0  54 233 253 253  11   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0  47 195 253 244  16   0   0  30 222 253 253 168   5   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0  17 153 246 253 165  36   0   0  41 234 253 108   8   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0  82 206 253 227  77  35   0   0  35 199 253 163  10   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0  53 211 253 225 104   0   0   0  11 108 253 232  92   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0  81 240 253 225 105   0   0  45  85 167 253 253 175  23   0   0   0   0   0   0]
 [  0   0   0   0   0   0  10 115 242 253 225  79  98  98 213 238 253 253 253 201   9   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0  77 253 253 253 247 238 253 253 253 254 253 253 253  83   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0 128 249 253 253 253 253 253 253 195 195 138 253 253 149   4   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0 186 254 254 255 246 207  23   0   0   6 198 254 254  85   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   6 109 166 109  31   0   0   0   0 133 254 253 163  36   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0  34 244 254 221  37   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0  19 235 253 246  36   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0 184 253 253  98   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0 158 245 253 103  29   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0  21 177 244 253 164  25   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0  61 253 253 253 144   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0  32 225 253 253 144   6   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   3  46 132  28   1   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]]

こんな画像。
knn_1.png

##コサイン類似度
k-NNでは画像間の距離を取って比較しますが、ここでは距離関数としてコサイン類似度を使います。

$$\cos(\vec{a},\vec{b})=\frac{\vec{a}\cdot\vec{b}}{|\vec{a}||\vec{b}|}=\frac{a_1b_1+\cdots+a_nb_n}{\sqrt{a_1^2+\cdots+a_n^2}\sqrt{b_1^2+\cdots+b_n^2}}$$

参考:コサイン類似度

-1~1の間で表される値で、分子は内積、分母はノルムなのでベクトルを用いて簡潔に表せます(あとノルムが標準化の役割も果たしているので、生データを標準化しない場合でのバイアスの問題もクリアしているのかなと思います)。ノルムというと馴染みがないかもしれませんが、2次元ベクトルの場合は三平方の定理です。np.linalg.norm()で計算できます。本当に計算できているか確認してみます。

>>> np.linalg.norm([3,4],ord=2)
5.0
>>> v = np.array([3,4])
>>> np.sum(v**2)**0.5
5.0

参考:ノルムの意味とL1,L2,L∞ノルム

ベクトルの次元を増やしても上の計算で一致します。コサイン類似度というと「ん?」となりましたが、平均で引かない相関係数です。これは目からウロコでした。

参考:相関係数・COS類似度

ちなみに、ノルム割ったベクトルは長さが1になるので単位ベクトルとなります。わかりやすいように2次元で考えますが、片方のベクトルをx軸のプラス方向においたとき、もう片方のベクトルは単位円の円周上にあります。つまり片方のベクトルを

$$\frac{\vec{a}}{|\vec{a}|} = \vec{e_a} = (1,0)$$

とすると、もう片方のベクトルは、

$$\frac{\vec{b}}{|\vec{b}|} = \vec{e_b} = (\cos\theta,\sin\theta)$$

です。この2つのベクトルのユークリッド距離(≠コサイン類似度)は以下のようになります。要はこれを最小化すればいいわけです。

$$L = \sqrt{(\cos\theta-1)^2 + \sin^2\theta}$$

$L>0$なのでルートを2乗しても順序関係は変わりません。

\begin{align}
L^2 &=  (\cos\theta-1)^2 + \sin^2\theta \\
 &= \cos^2\theta -2\cos\theta + 1 + \sin^2\theta \\
 &= 2(1-\cos\theta)\qquad(\because\sin^2\theta+\cos^2\theta=1)
\end{align}

$$\frac{L^2}{2}-1 = -\cos\theta = -\vec{e_a}\cdot\vec{e_b} = -\cos(\vec{a},\vec{b})$$

以上より、ユークリッド距離の最小化=コサイン類似度の最大化と考えることができます。証明は省きますが、3次元以上も同じのはずです。ただし、今回は画像データのベクトルの値が全て正なので特に考える必要はありませんが、負の値があってコサインが-1の場合は、ベクトルが真逆に向いているので、それをもってこの2つの画像が全然似ていないという結論に行くのはちょっと疑問が残るとは思います(画像を白黒反転させた場合はどうなるんだろう)。

###ユークリッド距離とコサイン類似度の補足
さすがに「2次元で成り立つから多次元でも成り立ちます」は乱暴すぎたので、数値計算で確かめてみます。5次元の例。

5次元の極座標によるユークリッド距離を解析的に表すことはできますが、おそらくそこそこ複雑な式になります(ちなみに今回のk-NNの例では784次元です!)。なのでモンテカルロ法で、乱数を用いて単位ベクトルを多数生成し、ユークリッド距離とコサイン類似度の関係をプロットしてみます。

import numpy as np
import matplotlib.pyplot as plt

# 単位ベクトルをn個作成
n = 10000
# ランダムな5次元単位ベクトルを作成
np.random.seed(114514)
v = np.random.uniform(low = -2.0, high = 2.0, size = 5*n).reshape(n,5)
v = v / np.linalg.norm(v, ord=2, axis=1)[:, np.newaxis]
# 比較対象の単位ベクトルu
u = np.array([1,0,0,0,0], dtype=float)
# ユークリッド距離
euclid = np.linalg.norm(u-v, ord=2, axis=1)
# コサイン類似度(内積の順番に注意)
cosine = np.dot(v, u)
# 横軸をユークリッド距離、縦軸をコサイン類似度としてプロット
plt.plot(euclid, cosine, "o")
plt.xlabel("Euclidean distance")
plt.ylabel("Cosine similarity")
plt.show()

横軸がユークリッド距離、縦軸がコサイン類似度です。単位ベクトルは1万個生成しています。
knn_2.png

とてもきれいな関係になりました。ユークリッド距離の最小化=コサイン類似度の最大化と考えて良さそうです。784次元も数値計算にかかれば恐れる必要はありません。

##k-Nearest Neighborsの実装
###コード
少々脱線してしまいましたがここからが本番です。まずはサクッと実装してみます。

from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from sklearn.datasets import fetch_mldata

import numpy as np
import matplotlib.pyplot as pltlt


mnist = fetch_mldata('MNIST original')
mnist_X, mnist_y = shuffle(mnist.data, mnist.target, random_state=114514)

# 学習データ8割、テストデータ2割で分割
train_X, test_X, train_y, test_y = train_test_split(mnist_X, mnist_y, test_size=0.2, random_state=114514)
# 学習データをノルムで標準化
norm = np.linalg.norm(train_X, ord=2, axis=1)
normalized_train_X = train_X / norm[:, np.newaxis]

# 学習データから分類する関数
def classify(traX, tray, targetX):
    #traX, tray:学習データ、targetX:判定する画像データ
    result = {}
    # コサイン類似度
    cosine = np.dot(traX, targetX) / np.linalg.norm(targetX, ord=2)
    # k=学習データの数の平方根
    k = np.ceil(np.sqrt(len(tray))).astype(int)
    result["k"] = k
    # 類似度が上位k個のデータのインデックスを選択
    selected_idx = np.argsort(-cosine)[:k]
    result["selected_idx"] = selected_idx
    result["selected_cosine"] = cosine[selected_idx]
    # 選択されたデータのyを数字ごとに数え上げる
    count = np.bincount(tray[selected_idx].astype(int), minlength=10)
    result["bincount"] = count
    # 多数決で判定
    result["predict"] = np.argmax(count)
    return result

print("真の値:", test_y[0])
h = classify(normalized_train_X, train_y, test_X[0])
print("推定値:", h["predict"])
print()
print(h)

後から気づいたのですが、scikit-learnには組み込みのk-NN分類器があってわざわざコードを書かなくていいそうです。なーんだ。

参考:http://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html
K近傍法(多クラス分類)

###解説
classify関数は「ある画像データ(テストデータ)が与えられたときに、それが何の数字を表すのかを学習データから推定する分類器」です。コサイン類似度が高いものの上位k(=学習データ数の平方根)個を取り、多数決で判定します。例えばk=200で、そのうち1であるものが70、4であるものが100、8であるものが30ならば、多数決により4であると推定されます。k=学習データ数の平方根としたのはこちらの記事を参考にしました(あくまで1つの指標なので平方根にこだわる必要はない)。

参考:k近傍法とk平均法の違いと詳細.

1番目のテストデータを分類器にかけてみた結果はこちら。

真の値: 3.0
推定値: 3

{'k': 237, 'selected_idx': array([44635, 33883, 44541, 27018, 49471, 41585, 1237
6, 52554, 46889,
        9040, 43351, 48433, 22142, 30517, 47699, 44117, 47668, 14076,(中略)
       32678, 38676, 29784, 27183,  7869, 16280, 15185, 30235, 16840,
       42552, 47052, 55205], dtype=int64), 'selected_cosine': array([ 0.78297285
,  0.78003873,  0.77929463,  0.77622351,  0.77575458,
        0.77434365,  0.76963107,  0.76832584,  0.76811599,  0.76795657,(中略)
        0.70350283,  0.70333152,  0.70324435,  0.70322874,  0.7029288 ,
        0.70270324,  0.70263076]), 'bincount': array([  0,   0,   4, 161,   0,
31,   0,   0,  36,   5], dtype=int64), 'predict': 3}

見事分類に成功しました! resultを読み解いていきます。kは7万件あるデータのうち8割を学習データとしたので、56000の平方根を切り上げた237となっています。selected_cosineはコサイン類似度の上位237件、selected_idxはそれの学習データにおけるインデックスを表します。237件のデータを数字別に分類したものがbincountで、多い所をかいつまんで見ると、3と判定されたのが161件、5と判定されたのが31件、8と判定されたのが36件でした。したがって多数決により、推定値は3となります。

###精度を見る
classify関数を使ってテストデータを全て推定してみます。計算に時間がかかる処理を何度も計算したくないので、計算結果をファイルに書き出します。結果保存用にMessagePackを使ってみました。

import umsgpack

fit = {}
fit["true_value"] = []
fit["pred_value"] = []
for i in range(len(test_y)):
    print(i)
    h = classify(normalized_train_X, train_y, test_X[i])
    fit["true_value"].append(int(test_y[i]))
    fit["pred_value"].append(int(h["predict"]))

# MessagePackで保存
with open("knn.msg", "wb") as fp:
    packed = umsgpack.packb(fit)
    fp.write(packed)

CPU計算で5~10分で終わりました。セルフ実装なので遅いのは仕方ないです。精度を測る尺度としてF値を計算します。

from sklearn.metrics import f1_score
import umsgpack

# 結果の読み込み
with open("knn.msg", "rb") as fp:
    fit = umsgpack.unpack(fp)
# 真の値が0~9ごとに計算したF1スコア
print(f1_score(fit["true_value"], fit["pred_value"], average=None))
# average=Noneのときの単純平均
print(f1_score(fit["true_value"], fit["pred_value"], average="macro"))
# average=Noneの結果を、テストデータの真の値ごとのサンプル数で重みづけた加重平均
print(f1_score(fit["true_value"], fit["pred_value"], average="weighted"))

F値についてはこちら。松尾研の講義資料にも載っていますのでそちらもどうぞ。適合率と正解率の調和平均で、モデルの適合度を測る1つの共通バロメーターなのでしょうね。

[ 0.94877049  0.95082994  0.9442029   0.92775665  0.92828364  0.9354973
  0.9582744   0.93546149  0.92067124  0.89531406]
0.934506210721
0.93464549949

セルフ実装した場合、**F値は93.5%**となりました。数字別に見てみると、**コンピューターが一番識別が得意なのが「6」でF値は95.8%、続いて「1」でF値は95.1%。逆に一番識別が苦手なのは「9」でこれだけ9割切っていてF値は89.5%**となりました。

##組み込みのk-NNの場合
ここまで自分でコードを書いてしまいましたが、k-NNの場合scikit-learnで組み込みの分類器があります。そちらも試してみます。

from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from sklearn.datasets import fetch_mldata
from sklearn.neighbors import KNeighborsClassifier

import numpy as np

mnist = fetch_mldata('MNIST original')
mnist_X, mnist_y = shuffle(mnist.data, mnist.target, random_state=114514)

# 学習データ8割、テストデータ2割で分割
train_X, test_X, train_y, test_y = train_test_split(mnist_X, mnist_y, test_size=0.2, random_state=114514)
# 学習データをノルムで標準化
norm = np.linalg.norm(train_X, ord=2, axis=1)
normalized_train_X = train_X / norm[:, np.newaxis]

# 分類器のインスタンスを作成, k=学習データ数の平方根とする
knn = KNeighborsClassifier(n_neighbors = np.ceil(np.sqrt(len(train_y))).astype(int))
# 学習データをフィット
knn.fit(train_X, train_y)
# 予測実行
pred_y = knn.predict(test_X)

# 精度確認
print(f1_score(test_y, pred_y, average=None))
print(f1_score(test_y, pred_y, average="macro"))
print(f1_score(test_y, pred_y, average="weighted"))

たったこれだけ。今までの苦労は何だったのでしょう。実行したらもっとすごくて、組み込みの分類器はほんの数秒で処理が終わってしまいました。思わず「えっ!?」と声を上げてしまいました。近いNeighborを探す際のアルゴリズムが相当最適化されているんだと思います。セルフ実装はブルートフォース()だったので。F値を確認します。

[ 0.94877049  0.95082994  0.9442029   0.92775665  0.92828364  0.9354973
  0.9582744   0.93546149  0.92067124  0.89531406]
0.934506210721
0.93464549949

上から、average=None、="macro", ="weighted"のケースです。結果はセルフ実装と全く同じっぽい?

##まとめ
scikit-learnの組み込み分類器を使いましょうWikipedia先生曰く「最も単純な」機械学習アルゴリズムとのことですが、機械学習のきの字ぐらいは味わえて楽しかったです。
画像でなくてもベクトルを何らかのラベルに分類できればいいので、意外と応用範囲広いのではないかと思います。

21
24
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
21
24

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?