sklearnで交差検証をする時に使うKFold
,StratifiedKFold
,ShuffleSplit
のそれぞれの動作について簡単にまとめ
KFold(K-分割交差検証)
概要
- データをk個に分け,n個を訓練用,k-n個をテスト用として使う.
- 分けられたn個のデータがテスト用として必ず1回使われるようにn回検定する.
オプション(引数)
- n_split:データの分割数.つまりk.検定はここで指定した数値の回数おこなわれる.
- shuffle:Trueなら連続する数字でグループ分けせず,ランダムにデータを選択する.
- random_state:乱数のシードを指定できる.
例
import numpy as np
from sklearn.model_selection import KFold
x = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4]])
y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1])
kf = KFold(n_splits=5)
for train_index, test_index in kf.split(x, y):
print("train_index:", train_index, "test_index:", test_index)
# output
# train_index: [2 3 4 5 6 7 8 9] test_index: [0 1]
# train_index: [0 1 4 5 6 7 8 9] test_index: [2 3]
# train_index: [0 1 2 3 6 7 8 9] test_index: [4 5]
# train_index: [0 1 2 3 4 5 8 9] test_index: [6 7]
# train_index: [0 1 2 3 4 5 6 7] test_index: [8 9]
StratifiedKFold(層状K分割)
概要
- 分布に大きな不均衡がある場合に用いるKFold.
- 分布の比率を維持したままデータを訓練用とテスト用に分割する.
オプション(引数)
- KFoldと同じ.
- n_splitがデータ数が最も少ないクラスのデータ数よりも多いと怒られる.
例
import numpy as np
from sklearn.model_selection import StratifiedKFold
x = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4]])
y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1])
skf = StratifiedKFold(n_splits=2)
for train_index, test_index in skf.split(x, y):
print("train_index:", train_index, "test_index:", test_index)
# output
# train_index: [1 6 7 8 9] test_index: [0 2 3 4 5]
# train_index: [0 2 3 4 5] test_index: [1 6 7 8 9]
ShuffleSplit(ランダム置換相互検証)
概要
- 独立した訓練用・テスト用のデータ分割セットを指定した数だけ生成する.
- データを最初にシャッフルしてから,訓練用とテスト用にデータを分割する.
オプション(引数)
- n_splits:生成する分割セット数
- test_size:テストに使うデータの割合(0~1の間で指定)
- random_state:シャッフルする時の乱数のシード
例
import numpy as np
from sklearn.model_selection import ShuffleSplit
x = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4]])
y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1])
ss = ShuffleSplit(n_splits=3, test_size=0.25, random_state=0)
for train_index, test_index in ss.split(x, y):
print("train_index:", train_index, "test_index:", test_index)
# output
# train_index: [9 1 6 7 3 0 5] test_index: [2 8 4]
# train_index: [2 9 8 0 6 7 4] test_index: [3 5 1]
# train_index: [4 5 1 0 6 9 7] test_index: [2 3 8]