Python
sklearn

K近傍法(多クラス分類)

More than 1 year has passed since last update.


K近傍法とは

KNN(K Nearest Neighbor)。クラス判別用の手法。

学習データをベクトル空間上にプロットしておき、未知のデータが得られたら、そこから距離が近い順に任意のK個を取得し、多数決でデータが属するクラスを推定する。

例えば下図の場合、クラス判別の流れは以下となる。

1 既知のデータ(学習データ)を黄色と紫の丸としてプロットしておく。

2 Kの数を決めておく。K=3とか。

3 未知のデータとして赤い星が得られたら、近い点から3つ取得する。

4 その3つのクラスの多数決で、属するクラスを推定。

今回は、未知の赤い星はClass Bに属すると推定する。

スクリーンショット 2016-05-04 3.33.02.png

※Kの数次第で結果が変わるので注意。K=6にすると、赤い星はClass Aと判定される。


利用データ用意

sklearnでirisのデータセットを用意。


get_iris_dataset.py

from sklearn.datasets import load_iris

iris= load_iris() # irisデータ取得
X = iris.data # 説明変数(クラス推定用変数)
Y = iris.target # 目的変数(クラス値)

# irisのデータをDataFrameに変換
iris_data = DataFrame(X, columns=['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'])
iris_target = DataFrame(Y, columns=['Species'])

# iris_targetが0〜2の値で分かりづらいので、あやめの名前に変換
def flower(num):
"""名前変換用関数"""
if num == 0:
return 'Setosa'
elif num == 1:
return 'Veriscolour'
else:
return 'Virginica'

iris_target['Species'] = iris_target['Species'].apply(flower)
iris = pd.concat([iris_data, iris_target], axis=1)



データの概要


describe_iris.py

iris.head()


スクリーンショット 2016-05-04 3.45.38.png

スクリーンショット 2016-05-04 3.50.37.png の長さと幅のデータ

seaboanでpairplotして、クラス別に概要を見てみる


desplay_each_data.py

import matplotlib.pyplot as plt

import seaborn as sns
sns.set_style('whitegrid')
%matplotlib inline

sns.pairplot(iris, hue = 'Species', size =2) # hue:指定したデータで分割


スクリーンショット 2016-05-04 3.54.59.png

Setosa[青の点]は分類しやすそう。Veriscolour[緑の点]とVirginica[赤の点]はPetal Lengthあたりで分類できるかも? くらいの印象。


やってみる

sklearnでKNNを実行。


do_knn.py

from sklearn.neighbors import KNeighborsClassifier

from sklearn.cross_validation import train_test_split # trainとtest分割用

# train用とtest用のデータ用意。test_sizeでテスト用データの割合を指定。random_stateはseed値を適当にセット。
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.4, random_state=3)

knn = KNeighborsClassifier(n_neighbors=6) # インスタンス生成。n_neighbors:Kの数
knn.fit(X_train, Y_train) # モデル作成実行
Y_pred = knn.predict(X_test) # 予測実行

# 精度確認用のライブラリインポートと実行
from sklearn import metrics
metrics.accuracy_score(Y_test, Y_pred) # 予測精度計測
> 0.94999999999999996


95%くらいの精度。

Kの数で精度が変わる。→Kをどれにするのがいいのか分からないので、とりあえずKを色々変えて精度の変化グラフを書いてみる。


create_graph_knn_accracy_change_k.py

accuracy = []

for k in range(1, 90):
knn = KNeighborsClassifier(n_neighbors=k) # インスタンス生成。
knn.fit(X_train, Y_train) # モデル作成実行
Y_pred = knn.predict(X_test) # 予測実行
accuracy.append(metrics.accuracy_score(Y_test, Y_pred)) # 精度格納

plt.plot(k_range, accuracy)


90回回してみた結果

スクリーンショット 2016-05-04 4.10.41.png

K=3?くらいで十分そう。30を超えると精度悪くなってる。

今回、学習用データが90件しかないので、一つのクラスあたり30個ずつくらいしか学習データがない。

Kの数が30を超えると、正解クラスのデータが全部含まれてしまっていたら、あとは異なるクラスしか最近傍で拾えなくなるので、精度はどんどん悪くなっていくとんだと予想。