1
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Pythonで機械学習 - Evaluation

Last updated at Posted at 2018-02-11

構築する予測モデルに用いるアルゴリズムの評価として、交差検証(Cross-Validation)による精度の評価を行います。

準備

必要なライブラリをインポートします。交差検証には、Scikit-LearnのStratifiedKFold(層化k-分割)を使用します。

import pandas as pd
from sklearn.model_selection import StratifiedKFold
import numpy as np
from sklearn.ensemble import RandomForestClassifier

整形済みデータを読み込みます。

train = pd.read_csv("train_prep.csv")

読み込んだデータから、検証に用いる説明変数と目的変数をそれぞれ準備します。

expvars = ["Pclass","SexInt","AgeFillNa","FareFillNa","IsAlone"] # 説明変数のリスト
data = train.copy()[expvars] # 説明変数
label = train["Survived"] # 目的変数

モデルの構築と評価

k-分割は、テストデータを一定数に分割(今回は10)し、9つが学習データ、1つが検証データという組み合わせを順次つくり、予測モデルを構築します。

clf = RandomForestClassifier()
skf = StratifiedKFold(n_splits=10,random_state=1)
scores = []
for train_ix,test_ix in skf.split(data,label): # テストデータを分割し、順次処理
    clf.fit(data.ix[train_ix],label.ix[train_ix]) # 予測モデルの構築
    score = clf.score(data.ix[test_ix],label.ix[test_ix]) # 予測モデルの精度評価を検証データで行う。
    scores.append(score)

精度の確認

精度を確認します。10回分のモデル構築と検証を実施していますが、平均は08.7%です。

print scores
print np.mean(scores) # スコアの平均
[0.69999999999999996, 0.77777777777777779, 0.7640449438202247, 0.8651685393258427, 0.84269662921348309, 0.797752808988764, 0.84269662921348309, 0.7865168539325843, 0.8314606741573034, 0.86363636363636365]
0.807175122007

モデルの選択

同様にして、複数のアルゴリズムで精度を比較し、最も精度が高いものを選択します。

アルゴリズムをインポートします。

from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC, LinearSVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import Perceptron
from sklearn.linear_model import SGDClassifier
algs = [
    RandomForestClassifier,
    DecisionTreeClassifier,
    LogisticRegression,
    LinearSVC,
    KNeighborsClassifier,
    GaussianNB,
    Perceptron,
    SGDClassifier
]

精度を評価する処理を関数化します。

def cross_validate(clf,data,label):
    skf = StratifiedKFold(n_splits=10,random_state=1)
    scores = []
    for train_ix,test_ix in skf.split(data,label): # テストデータを分割し、順次処理
        clf.fit(data.ix[train_ix],label.ix[train_ix]) # 予測モデルの構築
        score = clf.score(data.ix[test_ix],label.ix[test_ix]) # 予測モデルの精度評価を検証データで行う。
        scores.append(score)
    return np.mean(scores)

各アルゴリズムによる交差検証を行い精度を確認してみます。

results = {}
for alg in algs:
    clf = alg() #分類器の作成
    score = cross_validate(clf,data,label) # 交差検証
    alg_name = str(type(clf)).split("'")[1].split(".")[-1] # インスタンスからアルゴリズム名を取得
    results[alg_name] = score # アルゴリズム名と精度を格納

結果を可視化すると、アルゴリズムの中ではランダムフォレストが最も精度が良いと言えます。

print pd.Series(results).sort_values(ascending=False)
RandomForestClassifier    0.813830
LogisticRegression        0.786749
DecisionTreeClassifier    0.785764
GaussianNB                0.762242
LinearSVC                 0.738409
KNeighborsClassifier      0.716363
Perceptron                0.593985
SGDClassifier             0.579110
dtype: float64

ロジスティック回帰

ロジスティック回帰は、各項目がどの程度、結果に影響しているかを定量的に把握できます。

clf = LogisticRegression()
clf.fit(data,label)
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False)
clf.coef_
array([[ -1.03219247e+00,   2.53219867e+00,  -2.91694603e-02,
          1.80351166e-03,   9.32819823e-02]])
pd.DataFrame(columns=data.columns,data=clf.coef_)
Pclass SexInt AgeFillNa FareFillNa IsAlone
0 -1.032192 2.532199 -0.029169 0.001804 0.093282

戻る

1
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?