LoginSignup
20
15

More than 5 years have passed since last update.

Cross-validation: KFold と StratifiledKFold の性能の違い

Last updated at Posted at 2019-01-28

目的

本ページの目的は交差検証によって、構築された機械学習モデルの予測精度が変化するかを確かめる。
そのために、以下の内容で話を進める。

  • FIFAのサッカーのデータを利用する
    • 分類対象のラベルは列名Man of the Matchとした。
    • 列名Man of the Matchには、いわゆる MVPの選手がいるかどうかを判定する2値が含まれている
  • 10交差検証(KFodl関数,シャッフル有り) と Ten-fold-stratified-cross-validation (StratifiledKFold関数)を比較する
  • 推定精度はパラメータによって敏感(センシティブ)に変化することがある。
    • パラメーターチューニングにセンシティブなSVMなどのモデルは利用しない。
    • 深層学習は大規模データ向きなので(著者感覚では深層学習はチェーンソーぐらい)、包丁で切れるデータ(FIFAのサッカーのデータ)にはランダムフォレストを利用する。
    • パラメーターチューニングは基本的に行わない。

背景

この記事の前にCross-validation: KFold と StratifiledKFold の違いを参照されたい

本ページを読み終えて理解すること

  • 交差検証による分類対象(クラス)の振り分けによって精度が前後する
  • Ten-fold-stratified-cross-validation の精度が10交差検証よりも若干良い。
  • Ten-fold-stratified-cross-validation の精度の標準偏差が10交差検証よりも若干低い
    • ある意味当たり前だが、クラスの割り振りによる精度のばらつきが減った。

pandas の使い方

データの読み込み

  • sklearn では様々なパッケージを利用する際に pandas が便利である。
  • 重要変数
    • data_setは機械学習の用語である特徴量(もしくは特徴変数) を表す
    • target_setは機械学習の用語であるクラス (分類対象, setosa などはクラスラベル)を表す
  • データ (Kaggle)
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
# 見た目を綺麗にするもの
import matplotlib.pyplot as plt
import seaborn as sns
data = pd.read_csv("FIFA_data.csv")
pd.set_option('display.max_rows', 10)
print(data.columns)
feature_names = [i for i in data.columns if data[i].dtype in [np.int64]]
target_set = data['Man of the Match'] 
print(feature_names)
data_set = data[feature_names]
#data_set =  data.drop('Goal Scored', axis=1)
#data_set =  data_set.drop(('Date',"Team",'Opponent'), axis=1)
Index(['Date', 'Team', 'Opponent', 'Goal Scored', 'Ball Possession %',
       'Attempts', 'On-Target', 'Off-Target', 'Blocked', 'Corners', 'Offsides',
       'Free Kicks', 'Saves', 'Pass Accuracy %', 'Passes',
       'Distance Covered (Kms)', 'Fouls Committed', 'Yellow Card',
       'Yellow & Red', 'Red', 'Man of the Match', '1st Goal', 'Round', 'PSO',
       'Goals in PSO', 'Own goals', 'Own goal Time'],
      dtype='object')
['Goal Scored', 'Ball Possession %', 'Attempts', 'On-Target', 'Off-Target', 'Blocked', 'Corners', 'Offsides', 'Free Kicks', 'Saves', 'Pass Accuracy %', 'Passes', 'Distance Covered (Kms)', 'Fouls Committed', 'Yellow Card', 'Yellow & Red', 'Red', 'Goals in PSO']

確認1

クラスラベルの中身を表示し、サマーリーを確認

display(target_set)
display(target_set.describe())
#target_set.hist()
0      Yes
1       No
2       No
3      Yes
4       No
      ... 
123     No
124    Yes
125     No
126    Yes
127     No
Name: Man of the Match, Length: 128, dtype: object



count     128
unique      2
top       Yes
freq       64
Name: Man of the Match, dtype: object

確認2

特徴変数の中身を表示し、サマーリーを確認

display(data_set)
display(data_set.describe())

各特徴量のデータをチェック

from pylab import rcParams
rcParams['figure.figsize'] = 10, 7
data_set.hist()
plt.tight_layout()

output_13_0.png

モデルの定義

  • ランダムフォレストを利用するため、RandomForestClassifier を利用
  • 交差検証の種類
    • kf は 10交差検証 を行うための準備
    • skf は Ten-fold-stratified-cross-validation を行うための準備
from sklearn.model_selection import StratifiedKFold, cross_validate, KFold
# 利用するモデルの定義
model = RandomForestClassifier(n_estimators = 1000)
# データをどのように分割するか?
np.random.rand(4) 
kf = KFold(n_splits=10,
            shuffle=True,
            random_state=0)
skf = StratifiedKFold(n_splits=10,
                      shuffle=True,
                      random_state=0)

指標の計算

  • 指標
    • scoring に Accuracy と Kappa 係数を指定
  • cross_validate の引数のcvに交差検証の種類を設定
  • score_kf にはランダムフォレストの10交差検証の指標を計算した結果が代入
  • score_skf にはランダムフォレストのTen-fold-stratified-cross-validationの指標を計算した結果が代入
%%time
import pprint
# 指標を計算するため
from sklearn.metrics import accuracy_score, cohen_kappa_score, make_scorer, f1_score, recall_score
scoring = {'accuracy': make_scorer(accuracy_score),
           'kappa': make_scorer(cohen_kappa_score)}


scores_kf = cross_validate(model, 
                        data_set,   # 
                        target_set,
                        cv=kf, 
                        n_jobs = -1,
                        scoring=scoring)
scores_skf = cross_validate(model, 
                        data_set,   # 
                        target_set,
                        cv=skf, 
                        n_jobs = -1,
                        scoring=scoring)

CPU times: user 86.8 ms, sys: 4.44 ms, total: 91.2 ms
Wall time: 23.1 s
# おそらく以下の数値から過学習を引き起こしている
pprint.pprint(scores_kf)
pprint.pprint(scores_skf)
{'fit_time': array([1.57576084, 1.56514001, 1.55024195, 1.53922677, 1.96232772,
       1.94163394, 1.95994902, 2.01066685, 1.51928878, 1.56764865]),
 'score_time': array([0.259655  , 0.26242495, 0.25388193, 0.26137924, 0.2895453 ,
       0.29229116, 0.28499889, 0.30088711, 0.16101408, 0.16763806]),
 'test_accuracy': array([0.61538462, 0.92307692, 0.92307692, 0.69230769, 0.38461538,
       0.61538462, 0.30769231, 0.76923077, 0.75      , 0.66666667]),
 'test_kappa': array([ 0.19753086,  0.84337349,  0.84705882,  0.36585366, -0.15555556,
        0.21686747, -0.09345794,  0.53012048,  0.4375    ,  0.31428571]),
 'train_accuracy': array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
 'train_kappa': array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])}
{'fit_time': array([1.91939378, 1.95008111, 1.92234015, 1.92945504, 2.14256263,
       2.06773615, 2.03872895, 2.1094451 , 0.98216105, 0.98018098]),
 'score_time': array([0.31813312, 0.33372688, 0.32486987, 0.34145498, 0.24893427,
       0.27141714, 0.25123096, 0.27676487, 0.16585112, 0.1651299 ]),
 'test_accuracy': array([0.92857143, 0.71428571, 0.71428571, 0.71428571, 0.83333333,
       0.58333333, 0.75      , 0.33333333, 0.91666667, 0.75      ]),
 'test_kappa': array([ 0.85714286,  0.42857143,  0.42857143,  0.42857143,  0.66666667,
        0.16666667,  0.5       , -0.33333333,  0.83333333,  0.5       ]),
 'train_accuracy': array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
 'train_kappa': array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])}
plt.plot(scores_kf['test_accuracy'],label='accuracy_kf')
plt.plot(scores_skf['test_accuracy'],label='accuracy_skf')
plt.plot(scores_kf['test_kappa'],label='kappa_kf')
plt.plot(scores_skf['test_kappa'],label='kappa_skf')
plt.legend(loc="best")
plt.xlabel("#CV")
plt.ylabel("Index")

Text(0,0.5,'Index')

output_19_1.png

結果

  • Ten-fold-stratified-cross-validationは10交差検証(KFodl関数,シャッフル有り) よりも精度が高い。
    • 精度の指標としてはAccuracy (正解率)とKappa係数を利用した
print("KFoldの正解率__________" + str(scores_kf['test_accuracy'].mean()))
print("StratifiedKFoldの正解率" + str(scores_skf['test_accuracy'].mean()))
print("KFoldのKappa__________" + str(scores_kf['test_kappa'].mean()))
print("StratifiedKFoldのKappa" + str(scores_skf['test_kappa'].mean()))
print("KFoldの正解率からStratifiedKFoldの正解率を引いた数値:" + str(scores_kf['test_accuracy'].mean() - scores_skf['test_accuracy'].mean()))
print("KFoldのKappaからStratifiedKFoldのKappaを引いた数値:" + str(scores_kf['test_kappa'].mean() - scores_skf['test_kappa'].mean()))

KFoldの正解率__________0.6878205128205128
StratifiedKFoldの正解率0.7035714285714285
KFoldのKappa__________0.3585261916582403
StratifiedKFoldのKappa0.40714285714285714
KFoldの正解率からStratifiedKFoldの正解率を引いた数値:-0.015750915750915695
KFoldのKappaからStratifiedKFoldのKappaを引いた数値:-0.04861666548461685

議論:妥当性の確認

  • Ten-fold-stratified-cross-validationを100回行った平均は、10交差検証を100回行った平均よりも高かった。
    • 100回平均の正解率は 0.01程度 改善
    • 100回平均のKappaは 0.02程度 改善
  • 100回計算した正解率とKappaの標準偏差からTen-fold-stratified-cross-validationの方が10交差検証よりも安定した性能を出すことができる
%%time
data_scores_kf_accuracy = []
data_scores_skf_accuracy = []
data_scores_kf_kappa = []
data_scores_skf_kappa = []
for i in range(0,100):
    kf = KFold(n_splits=10,
            shuffle=True,
            random_state=i)
    skf = StratifiedKFold(n_splits=10,
                      shuffle=True,
                      random_state=i)
    scores_kf = cross_validate(model, 
                        data_set,   # 
                        target_set,
                        cv=kf, 
                        n_jobs = -1,
                        scoring=scoring)
    scores_skf = cross_validate(model, 
                        data_set,   # 
                        target_set,
                        cv=skf, 
                        n_jobs = -1,
                        scoring=scoring)
    data_scores_kf_accuracy.append(scores_kf['test_accuracy'].mean())
    data_scores_skf_accuracy.append(scores_skf['test_accuracy'].mean())
    data_scores_kf_kappa.append(scores_kf['test_kappa'].mean())
    data_scores_skf_kappa.append(scores_skf['test_kappa'].mean())
CPU times: user 7.74 s, sys: 299 ms, total: 8.03 s
Wall time: 24min 31s
from statistics import mean, median,variance,stdev
print("KFoldの正解率の平均__________" + str(mean(data_scores_kf_accuracy)))
print("StratifiedKFoldの正解率の平均" + str(mean(data_scores_skf_accuracy)))
print("KFoldのKappaの平均__________" + str(mean(data_scores_kf_kappa)))
print("StratifiedKFoldのKappaの平均" + str(mean(data_scores_skf_kappa)))
print("KFoldからStratifiedKFoldの正解率の平均を引いた数値:" + str(mean(data_scores_kf_accuracy) - mean(data_scores_skf_accuracy)))
print("KFoldからStratifiedKFoldのKappaの平均を引いた数値:" + str(mean(data_scores_kf_kappa) - mean(data_scores_skf_kappa)))

KFoldの正解率の平均__________0.6808076923076923
StratifiedKFoldの正解率の平均0.6887023809523809
KFoldのKappaの平均__________0.3547788080897784
StratifiedKFoldのKappaの平均0.3774047619047619
KFoldからStratifiedKFoldの正解率の平均を引いた数値:-0.007894688644688563
KFoldからStratifiedKFoldのKappaの平均を引いた数値:-0.0226259538149835
print("KFoldの正解率の標準偏差__________" + str(stdev(data_scores_kf_accuracy)))
print("StratifiedKFoldの正解率の標準偏差" + str(stdev(data_scores_skf_accuracy)))
print("KFoldのKappaの標準偏差__________" + str(stdev(data_scores_kf_kappa)))
print("StratifiedKFoldのKappaの標準偏差" + str(stdev(data_scores_skf_kappa)))
print("KFoldからStratifiedKFoldの正解率の標準偏差を引いた数値:" + str(stdev(data_scores_kf_accuracy) - stdev(data_scores_skf_accuracy)))
print("KFoldからStratifiedKFoldのKappaの標準偏差を引いた数値:" + str(stdev(data_scores_kf_kappa) - stdev(data_scores_skf_kappa)))
KFoldの正解率の標準偏差__________0.020125482966886734
StratifiedKFoldの正解率の標準偏差0.01910577266840414
KFoldのKappaの標準偏差__________0.04198347226190002
StratifiedKFoldのKappaの標準偏差0.038211545336808345
KFoldからStratifiedKFoldの正解率の標準偏差を引いた数値:0.0010197102984825929
KFoldからStratifiedKFoldのKappaの標準偏差を引いた数値:0.0037719269250916787
20
15
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
20
15