Help us understand the problem. What is going on with this article?

K近傍法(多クラス分類)

More than 3 years have passed since last update.

K近傍法とは

KNN(K Nearest Neighbor)。クラス判別用の手法。
学習データをベクトル空間上にプロットしておき、未知のデータが得られたら、そこから距離が近い順に任意のK個を取得し、多数決でデータが属するクラスを推定する。

例えば下図の場合、クラス判別の流れは以下となる。
1 既知のデータ(学習データ)を黄色と紫の丸としてプロットしておく。
2 Kの数を決めておく。K=3とか。
3 未知のデータとして赤い星が得られたら、近い点から3つ取得する。
4 その3つのクラスの多数決で、属するクラスを推定。
今回は、未知の赤い星はClass Bに属すると推定する。

スクリーンショット 2016-05-04 3.33.02.png

※Kの数次第で結果が変わるので注意。K=6にすると、赤い星はClass Aと判定される。

利用データ用意

sklearnでirisのデータセットを用意。

get_iris_dataset.py
from sklearn.datasets import load_iris
iris= load_iris() # irisデータ取得
X = iris.data     # 説明変数(クラス推定用変数)
Y = iris.target   # 目的変数(クラス値)

# irisのデータをDataFrameに変換
iris_data = DataFrame(X, columns=['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'])
iris_target = DataFrame(Y, columns=['Species'])

# iris_targetが0〜2の値で分かりづらいので、あやめの名前に変換
def flower(num):
"""名前変換用関数"""
    if num == 0:
        return 'Setosa'
    elif num == 1:
        return 'Veriscolour'
    else:
        return 'Virginica'

iris_target['Species'] = iris_target['Species'].apply(flower)
iris = pd.concat([iris_data, iris_target], axis=1)

データの概要

describe_iris.py
iris.head()

スクリーンショット 2016-05-04 3.45.38.png

スクリーンショット 2016-05-04 3.50.37.png の長さと幅のデータ

seaboanでpairplotして、クラス別に概要を見てみる

desplay_each_data.py
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
%matplotlib inline

sns.pairplot(iris, hue = 'Species', size =2) # hue:指定したデータで分割

スクリーンショット 2016-05-04 3.54.59.png

Setosa[青の点]は分類しやすそう。Veriscolour[緑の点]とVirginica[赤の点]はPetal Lengthあたりで分類できるかも? くらいの印象。

やってみる

sklearnでKNNを実行。

do_knn.py
from sklearn.neighbors import KNeighborsClassifier
from sklearn.cross_validation import train_test_split # trainとtest分割用

# train用とtest用のデータ用意。test_sizeでテスト用データの割合を指定。random_stateはseed値を適当にセット。
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = 0.4, random_state=3) 

knn = KNeighborsClassifier(n_neighbors=6) # インスタンス生成。n_neighbors:Kの数
knn.fit(X_train, Y_train)                 # モデル作成実行
Y_pred = knn.predict(X_test)              # 予測実行

# 精度確認用のライブラリインポートと実行
from sklearn import metrics
metrics.accuracy_score(Y_test, Y_pred)    # 予測精度計測
> 0.94999999999999996

95%くらいの精度。

Kの数で精度が変わる。→Kをどれにするのがいいのか分からないので、とりあえずKを色々変えて精度の変化グラフを書いてみる。

create_graph_knn_accracy_change_k.py
accuracy = []
for k in range(1, 90):
    knn = KNeighborsClassifier(n_neighbors=k) # インスタンス生成。
    knn.fit(X_train, Y_train)                 # モデル作成実行
    Y_pred = knn.predict(X_test)              # 予測実行
    accuracy.append(metrics.accuracy_score(Y_test, Y_pred)) # 精度格納

plt.plot(k_range, accuracy)

90回回してみた結果

スクリーンショット 2016-05-04 4.10.41.png

K=3?くらいで十分そう。30を超えると精度悪くなってる。
今回、学習用データが90件しかないので、一つのクラスあたり30個ずつくらいしか学習データがない。
Kの数が30を超えると、正解クラスのデータが全部含まれてしまっていたら、あとは異なるクラスしか最近傍で拾えなくなるので、精度はどんどん悪くなっていくとんだと予想。

yshi12
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした