K近傍法とは
KNN(K Nearest Neighbor)。クラス判別用の手法。
学習データをベクトル空間上にプロットしておき、未知のデータが得られたら、そこから距離が近い順に任意のK個を取得し、多数決でデータが属するクラスを推定する。
例えば下図の場合、クラス判別の流れは以下となる。
1 既知のデータ(学習データ)を黄色と紫の丸としてプロットしておく。
2 Kの数を決めておく。K=3とか。
3 未知のデータとして赤い星が得られたら、近い点から3つ取得する。
4 その3つのクラスの多数決で、属するクラスを推定。
今回は、未知の赤い星はClass Bに属すると推定する。
※Kの数次第で結果が変わるので注意。K=6にすると、赤い星はClass Aと判定される。
利用データ用意
sklearnでirisのデータセットを用意。
from sklearn.datasets import load_iris
iris= load_iris() # irisデータ取得
X = iris.data # 説明変数(クラス推定用変数)
Y = iris.target # 目的変数(クラス値)
# irisのデータをDataFrameに変換
iris_data = DataFrame(X, columns=['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'])
iris_target = DataFrame(Y, columns=['Species'])
# iris_targetが0〜2の値で分かりづらいので、あやめの名前に変換
def flower(num):
"""名前変換用関数"""
if num == 0:
return 'Setosa'
elif num == 1:
return 'Veriscolour'
else:
return 'Virginica'
iris_target['Species'] = iris_target['Species'].apply(flower)
iris = pd.concat([iris_data, iris_target], axis=1)
データの概要
iris.head()
seaboanでpairplotして、クラス別に概要を見てみる
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
%matplotlib inline
sns.pairplot(iris, hue = 'Species', size =2) # hue:指定したデータで分割
Setosa[青の点]は分類しやすそう。Veriscolour[緑の点]とVirginica[赤の点]はPetal Lengthあたりで分類できるかも? くらいの印象。
やってみる
sklearnでKNNを実行。
from sklearn.neighbors import KNeighborsClassifier
from sklearn.cross_validation import train_test_split # trainとtest分割用
# train用とtest用のデータ用意。test_sizeでテスト用データの割合を指定。random_stateはseed値を適当にセット。
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.4, random_state=3)
knn = KNeighborsClassifier(n_neighbors=6) # インスタンス生成。n_neighbors:Kの数
knn.fit(X_train, Y_train) # モデル作成実行
Y_pred = knn.predict(X_test) # 予測実行
# 精度確認用のライブラリインポートと実行
from sklearn import metrics
metrics.accuracy_score(Y_test, Y_pred) # 予測精度計測
> 0.94999999999999996
95%くらいの精度。
Kの数で精度が変わる。→Kをどれにするのがいいのか分からないので、とりあえずKを色々変えて精度の変化グラフを書いてみる。
accuracy = []
for k in range(1, 90):
knn = KNeighborsClassifier(n_neighbors=k) # インスタンス生成。
knn.fit(X_train, Y_train) # モデル作成実行
Y_pred = knn.predict(X_test) # 予測実行
accuracy.append(metrics.accuracy_score(Y_test, Y_pred)) # 精度格納
plt.plot(k_range, accuracy)
90回回してみた結果
K=3?くらいで十分そう。30を超えると精度悪くなってる。
今回、学習用データが90件しかないので、一つのクラスあたり30個ずつくらいしか学習データがない。
Kの数が30を超えると、正解クラスのデータが全部含まれてしまっていたら、あとは異なるクラスしか最近傍で拾えなくなるので、精度はどんどん悪くなっていくとんだと予想。