#概要
機械学習において、作成した学習モデルの汎化性能を調べるために交差検証という手法が用いられる。
交差検証を行うとき、学習曲線と検証曲線を描くことでより多面的に情報が得られる。
Matplotlibでグラフを描くためのサンプルコードを備忘も兼ねて記載する。
##学習曲線とは
評価尺度に正解率を使用した場合、横軸にデータ数・縦軸に正解率をとったグラフ。
訓練データの評価結果と検証データの評価結果を合わせて表示することで、2つの相対比較が行える。
そのモデルが過学習なのか学習不足なのか、さらにデータ数を増やすことが問題解決に役立つかといったことがわかる。
##検証曲線とは
学習曲線と異なり、パラメータ値を横軸にとる。
パラメータチューニングに役立つ。
#Pythonで実装
- 学習アルゴリズム:SVM
- データセット:Kaggleの提供しているタイタニックの訓練データセット
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import learning_curve
from sklearn.model_selection import validation_curve
# あらかじめ作成しておいたタイタニックの学習用データを読み込み
data = pd.read_csv('train.csv')
data.info()
# output
'''
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 PassengerId 891 non-null int64
1 Survived 891 non-null int64
2 Pclass 891 non-null int64
3 Name 891 non-null object
4 Sex 891 non-null object
5 Age 714 non-null float64
6 SibSp 891 non-null int64
7 Parch 891 non-null int64
8 Ticket 891 non-null object
9 Fare 891 non-null float64
10 Cabin 204 non-null object
11 Embarked 889 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
'''
# 前処理
data = data.dropna(subset=['Age','Embarked'])
data['Sex'] = data['Sex'].map({'male':0, 'female':1})
data['Embarked'] = data['Embarked'].map({'S':0, 'C':1, 'Q':2})
# カラム'Pclass', 'Sex', 'Fare', 'Embarked'を説明変数に利用する
X = data[['Pclass','Sex','Fare','Embarked']]
T = data['Survived']
# ここではテストデータは使わないが、形式上学習用とテスト用に分けておく
X_train, X_test, T_train, T_test = train_test_split(X, T, test_size=0.2, stratify=T, random_state=0)
pipe = make_pipeline(StandardScaler(),
SVC()) # パラメータはデフォルト
# learning_curve関数で交差検証(k=10)
train_sizes, train_scores, test_scores = learning_curve(estimator=pipe,
X = X_train, y = T_train,
train_sizes=np.linspace(0.1, 1.0, 10), # 与えたデータセットの何割を使用するかを指定
cv=10, n_jobs=1)
# 学習曲線の描画
train_mean = np.mean(train_scores, axis=1)
train_std = np.std(train_scores, axis=1)
test_mean = np.mean(test_scores, axis=1)
test_std = np.std(test_scores, axis=1)
plt.figure(figsize=(8,6))
plt.plot(train_sizes, train_mean, marker='o', label='Train accuracy')
plt.fill_between(train_sizes, train_mean + train_std, train_mean - train_std, alpha=0.2)
plt.plot(train_sizes, test_mean, marker='s', linestyle='--', label='Validation accuracy')
plt.fill_between(train_sizes, test_mean + test_std, test_mean - test_std, alpha=0.2)
plt.grid()
plt.title('Learning curve', fontsize=16)
plt.xlabel('Number of training data sizes', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.legend(fontsize=12)
plt.ylim([0.5, 1.05])
plt.show()
ここではtrain_test_splitで学習用に分けたデータのうち10分の1を検証用サブセット、残りを訓練用サブセットとし、計10回の交差検証を行っている。
折れ線グラフの意味は下記の通り。
- 青の実線:訓練用サブセットで予測した場合の正解率
- 赤の点線:検証用サブセットで予測した場合の正解率
データセットの数が400個あたりに達すると、訓練データと検証データの正解率が収束し、両方に対して良いスコアが出ている。
→これ以上データを集めることにはあまり意味はなさそうということがわかる
データセットが400未満の時のスコアを見ると、訓練データに対する正解率は高いが検証データに対する正解率が低い。
これは訓練データに過剰に適合してしまい、未知のデータに対する予測精度が落ちていることを示している。(これを過学習、またはバリアンスが高いという)
続いて検証曲線を描画する。
# SVMのパラメータgammaを変化させる
param_range = [1e-5,1e-4,1e-3,1e-2,1e-1,1,1e1,1e2,1e3]
# validation_curve関数で交差検証
train_scores, test_scores = validation_curve(estimator=pipe,
X=X_train, y=T_train,
param_name='svc__gamma',
param_range=param_range, cv=10)
# 検証曲線の描画
train_mean = np.mean(train_scores, axis=1)
train_std = np.std(train_scores, axis=1)
test_mean = np.mean(test_scores, axis=1)
test_std = np.std(test_scores, axis=1)
plt.figure(figsize=(8,6))
plt.plot(param_range, train_mean, marker='o', label='Train accuracy')
plt.fill_between(param_range, train_mean + train_std, train_mean - train_std, alpha=0.2)
plt.plot(param_range, test_mean, marker='s', linestyle='--', label='Validation accuracy')
plt.fill_between(param_range, test_mean + test_std, test_mean - test_std, alpha=0.2)
plt.grid()
plt.xscale('log')
plt.title('Validation curve(gamma)', fontsize=16)
plt.xlabel('Parameter gamma', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.legend(fontsize=12)
plt.ylim([0.5, 1.05])
plt.show()
SVMのパラメータgammaを1e-5(10の-5乗)から10ずつ乗算して1e3まで変化させた場合の正解率の変化を示した。
gamma=1e-2より小さいと正解率が低い(学習不足、バイアスが高いという)
gamma=1より大きいと過学習
という傾向がつかめる。
SVMの代表的なパラメータにはCとgammaがあり、ここではCがデフォルト値の元でという条件付きだが、gammaは1e-2~1のあたりに設定するのがよさそうという判断ができる。