LoginSignup
13
15

More than 5 years have passed since last update.

機械学習における交差検証法の実装(備忘録)

Last updated at Posted at 2017-09-21

はじめに

ゼミの後輩たちがライブラリを用いて機械学習を使い出しました.(後輩たちが目を真っ赤にしながら毎晩頑張ってます)
機械学習を実装しモデル化を行い,あとは交差検証を行い,精度を出すだけですが,躓いている人続出だったので,残しておきます.
というか,ライブラリがあるので,私は英語の文献を読んで自分で解決して欲しいですが...
正解は英語のサイトの方が多いです!!

環境

Python3.6.2
MacOS Sierra 10.12.6
Pychram 2017.2

実装

ソース

main.py
# -*- coding: utf-8 -*-

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import LeaveOneOut
import numpy as np
import csv


# CSVから特徴量読み込み
# K近傍法の実装
# 交差検証法にはLOOCV(Leave-One-Out Cross Validation)
def main():
    K = 3  # K-NNのKの値
    correct_answer_count = 0  # 正解数初期化
    x = []  # 初期化
    y = []  # 初期化
    f = open("data/data_a.csv", 'r')  # fileのopen
    data = csv.reader(f)  # ファイルの読み込み(str型)
#------------------------------1---------------------------------
    for row in data:
        x1 = []
        y.append(int(row[0]))  # ラベルを取得
        for i in range(1, len(row)):  # 1列目から最後の列まで
            x1.append(float(row[i]))
        x.append(x1)
    x = np.array(x)  # xをnumpy行列に変換
    y = np.array(y)  # yをnumpy行列に変換
#------------------------------2---------------------------------

    loo = LeaveOneOut()  # LOOCVのインスタンス生成

    entire_count = loo.get_n_splits(x)  # テスト回数取得(csvファイルの行数)

    neigh = KNeighborsClassifier(n_neighbors=K)  # K-NNのインスタンス生成

    for train_index, test_index in loo.split(x):  # loo.split(x)でいい感じに,sliceしてくれる
        x_train, x_test = x[train_index], x[test_index]
        y_train, y_test = y[train_index], y[test_index]
        neigh.fit(x_train, y_train)  # 学習させる
        result = neigh.predict(x_test)  # テストデータからラベルを予測する
        if result == y_test:  # ラベルと元々のラベルが一致していれば+1
            correct_answer_count += 1

    rate = (float(correct_answer_count) / float(entire_count))  # 正解率を計算
    print(str(rate))  # 正解率を出力


if __name__ == '__main__':
    main()

解説

所詮ライブラリ使っているので,解説も何もないですが...
ソースコードの1から2までが,学習させる特徴ベクトルへ変換しています.
具体的には,
- ラベル=[1,1,...,2]
- 特徴量=[[0.1,0.2,...,0,8],[0.1,0.2,...,0,8],...,[0.1,0.2,...,0,8]]
みたいな形にしています.
numpy行列に変換したのは,参考サイトの流用です...

ソースコードの2より下が,学習と予測をLOOCV(Leave-one-out Cross Validation)という交差検証法を用いて精度を求めています.

終わりに

棚に上げますが,ライブラリを使用するときは,中身もちゃんと見て,アルゴリズム見て欲しいですね.
あと入力のイメージとい出力のイメージをしっかりもつこと.

参考

LOOCVのドキュメント
http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.LeaveOneOut.html

13
15
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
15