211
187

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

K近傍法(多クラス分類)

Last updated at Posted at 2016-05-03

K近傍法とは

KNN(K Nearest Neighbor)。クラス判別用の手法。
学習データをベクトル空間上にプロットしておき、未知のデータが得られたら、そこから距離が近い順に任意のK個を取得し、多数決でデータが属するクラスを推定する。

例えば下図の場合、クラス判別の流れは以下となる。
1 既知のデータ(学習データ)を黄色と紫の丸としてプロットしておく。
2 Kの数を決めておく。K=3とか。
3 未知のデータとして赤い星が得られたら、近い点から3つ取得する。
4 その3つのクラスの多数決で、属するクラスを推定。
今回は、未知の赤い星はClass Bに属すると推定する。

スクリーンショット 2016-05-04 3.33.02.png

※Kの数次第で結果が変わるので注意。K=6にすると、赤い星はClass Aと判定される。

利用データ用意

sklearnでirisのデータセットを用意。

{get_iris_dataset.py}
from sklearn.datasets import load_iris
iris= load_iris() # irisデータ取得
X = iris.data     # 説明変数(クラス推定用変数)
Y = iris.target   # 目的変数(クラス値)

# irisのデータをDataFrameに変換
iris_data = DataFrame(X, columns=['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'])
iris_target = DataFrame(Y, columns=['Species'])

# iris_targetが0〜2の値で分かりづらいので、あやめの名前に変換
def flower(num):
"""名前変換用関数"""
    if num == 0:
        return 'Setosa'
    elif num == 1:
        return 'Veriscolour'
    else:
        return 'Virginica'

iris_target['Species'] = iris_target['Species'].apply(flower)
iris = pd.concat([iris_data, iris_target], axis=1)

データの概要

{describe_iris.py}
iris.head()
スクリーンショット 2016-05-04 3.45.38.png

スクリーンショット 2016-05-04 3.50.37.png の長さと幅のデータ

seaboanでpairplotして、クラス別に概要を見てみる

{desplay_each_data.py}
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
%matplotlib inline

sns.pairplot(iris, hue = 'Species', size =2) # hue:指定したデータで分割
スクリーンショット 2016-05-04 3.54.59.png

Setosa[青の点]は分類しやすそう。Veriscolour[緑の点]とVirginica[赤の点]はPetal Lengthあたりで分類できるかも? くらいの印象。

やってみる

sklearnでKNNを実行。

{do_knn.py}
from sklearn.neighbors import KNeighborsClassifier
from sklearn.cross_validation import train_test_split # trainとtest分割用

# train用とtest用のデータ用意。test_sizeでテスト用データの割合を指定。random_stateはseed値を適当にセット。
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.4, random_state=3) 

knn = KNeighborsClassifier(n_neighbors=6) # インスタンス生成。n_neighbors:Kの数
knn.fit(X_train, Y_train)                 # モデル作成実行
Y_pred = knn.predict(X_test)              # 予測実行

# 精度確認用のライブラリインポートと実行
from sklearn import metrics
metrics.accuracy_score(Y_test, Y_pred)    # 予測精度計測
> 0.94999999999999996

95%くらいの精度。

Kの数で精度が変わる。→Kをどれにするのがいいのか分からないので、とりあえずKを色々変えて精度の変化グラフを書いてみる。

{create_graph_knn_accracy_change_k.py}
accuracy = []
for k in range(1, 90):
    knn = KNeighborsClassifier(n_neighbors=k) # インスタンス生成。
    knn.fit(X_train, Y_train)                 # モデル作成実行
    Y_pred = knn.predict(X_test)              # 予測実行
    accuracy.append(metrics.accuracy_score(Y_test, Y_pred)) # 精度格納

plt.plot(k_range, accuracy)

90回回してみた結果

スクリーンショット 2016-05-04 4.10.41.png

K=3?くらいで十分そう。30を超えると精度悪くなってる。
今回、学習用データが90件しかないので、一つのクラスあたり30個ずつくらいしか学習データがない。
Kの数が30を超えると、正解クラスのデータが全部含まれてしまっていたら、あとは異なるクラスしか最近傍で拾えなくなるので、精度はどんどん悪くなっていくとんだと予想。

211
187
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
211
187

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?