6
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Pythonで学習曲線と検証曲線を描く

Last updated at Posted at 2021-06-28

#概要
機械学習において、作成した学習モデルの汎化性能を調べるために交差検証という手法が用いられる。
交差検証を行うとき、学習曲線と検証曲線を描くことでより多面的に情報が得られる。
Matplotlibでグラフを描くためのサンプルコードを備忘も兼ねて記載する。
##学習曲線とは
評価尺度に正解率を使用した場合、横軸にデータ数・縦軸に正解率をとったグラフ。
訓練データの評価結果と検証データの評価結果を合わせて表示することで、2つの相対比較が行える。
そのモデルが過学習なのか学習不足なのか、さらにデータ数を増やすことが問題解決に役立つかといったことがわかる。
##検証曲線とは
学習曲線と異なり、パラメータ値を横軸にとる。
パラメータチューニングに役立つ。
#Pythonで実装

  • 学習アルゴリズム:SVM
  • データセット:Kaggleの提供しているタイタニックの訓練データセット
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import learning_curve
from sklearn.model_selection import validation_curve

# あらかじめ作成しておいたタイタニックの学習用データを読み込み
data = pd.read_csv('train.csv')

data.info()
# output
'''
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Name         891 non-null    object 
 4   Sex          891 non-null    object 
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64  
 7   Parch        891 non-null    int64  
 8   Ticket       891 non-null    object 
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object 
 11  Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
'''

# 前処理
data = data.dropna(subset=['Age','Embarked'])
data['Sex'] = data['Sex'].map({'male':0, 'female':1})
data['Embarked'] = data['Embarked'].map({'S':0, 'C':1, 'Q':2})

# カラム'Pclass', 'Sex', 'Fare', 'Embarked'を説明変数に利用する
X = data[['Pclass','Sex','Fare','Embarked']]
T = data['Survived']

# ここではテストデータは使わないが、形式上学習用とテスト用に分けておく
X_train, X_test, T_train, T_test = train_test_split(X, T, test_size=0.2, stratify=T, random_state=0)

pipe = make_pipeline(StandardScaler(),
                     SVC()) # パラメータはデフォルト
# learning_curve関数で交差検証(k=10)
train_sizes, train_scores, test_scores = learning_curve(estimator=pipe,
                                                        X = X_train, y = T_train,
                                                        train_sizes=np.linspace(0.1, 1.0, 10), # 与えたデータセットの何割を使用するかを指定
                                                        cv=10, n_jobs=1)
# 学習曲線の描画
train_mean = np.mean(train_scores, axis=1)
train_std = np.std(train_scores, axis=1)
test_mean = np.mean(test_scores, axis=1)
test_std = np.std(test_scores, axis=1)
plt.figure(figsize=(8,6))
plt.plot(train_sizes, train_mean, marker='o', label='Train accuracy')
plt.fill_between(train_sizes, train_mean + train_std, train_mean - train_std, alpha=0.2)
plt.plot(train_sizes, test_mean, marker='s', linestyle='--', label='Validation accuracy')
plt.fill_between(train_sizes, test_mean + test_std, test_mean - test_std, alpha=0.2)
plt.grid()
plt.title('Learning curve', fontsize=16)
plt.xlabel('Number of training data sizes', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.legend(fontsize=12)
plt.ylim([0.5, 1.05])
plt.show()

learning_curve.png
ここではtrain_test_splitで学習用に分けたデータのうち10分の1を検証用サブセット、残りを訓練用サブセットとし、計10回の交差検証を行っている。
折れ線グラフの意味は下記の通り。

  • 青の実線:訓練用サブセットで予測した場合の正解率
  • 赤の点線:検証用サブセットで予測した場合の正解率

データセットの数が400個あたりに達すると、訓練データと検証データの正解率が収束し、両方に対して良いスコアが出ている。
 →これ以上データを集めることにはあまり意味はなさそうということがわかる

データセットが400未満の時のスコアを見ると、訓練データに対する正解率は高いが検証データに対する正解率が低い。
これは訓練データに過剰に適合してしまい、未知のデータに対する予測精度が落ちていることを示している。(これを過学習、またはバリアンスが高いという)

続いて検証曲線を描画する。

# SVMのパラメータgammaを変化させる
param_range = [1e-5,1e-4,1e-3,1e-2,1e-1,1,1e1,1e2,1e3]
# validation_curve関数で交差検証
train_scores, test_scores = validation_curve(estimator=pipe,
                                             X=X_train, y=T_train,
                                             param_name='svc__gamma',
                                             param_range=param_range, cv=10)

# 検証曲線の描画
train_mean = np.mean(train_scores, axis=1)
train_std = np.std(train_scores, axis=1)
test_mean = np.mean(test_scores, axis=1)
test_std = np.std(test_scores, axis=1)
plt.figure(figsize=(8,6))
plt.plot(param_range, train_mean, marker='o', label='Train accuracy')
plt.fill_between(param_range, train_mean + train_std, train_mean - train_std, alpha=0.2)
plt.plot(param_range, test_mean, marker='s', linestyle='--', label='Validation accuracy')
plt.fill_between(param_range, test_mean + test_std, test_mean - test_std, alpha=0.2)
plt.grid()
plt.xscale('log')
plt.title('Validation curve(gamma)', fontsize=16)
plt.xlabel('Parameter gamma', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.legend(fontsize=12)
plt.ylim([0.5, 1.05])
plt.show()

validation_curve.png
SVMのパラメータgammaを1e-5(10の-5乗)から10ずつ乗算して1e3まで変化させた場合の正解率の変化を示した。
 gamma=1e-2より小さいと正解率が低い(学習不足、バイアスが高いという)
 gamma=1より大きいと過学習
という傾向がつかめる。
SVMの代表的なパラメータにはCとgammaがあり、ここではCがデフォルト値の元でという条件付きだが、gammaは1e-2~1のあたりに設定するのがよさそうという判断ができる。

6
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?