LoginSignup
167
145

More than 5 years have passed since last update.

sklearnの交差検証の種類とその動作

Last updated at Posted at 2018-09-04

sklearnで交差検証をする時に使うKFoldStratifiedKFoldShuffleSplitのそれぞれの動作について簡単にまとめ

KFold(K-分割交差検証)

概要

  • データをk個に分け,n個を訓練用,k-n個をテスト用として使う.
  • 分けられたn個のデータがテスト用として必ず1回使われるようにn回検定する.

オプション(引数)

  • n_split:データの分割数.つまりk.検定はここで指定した数値の回数おこなわれる.
  • shuffle:Trueなら連続する数字でグループ分けせず,ランダムにデータを選択する.
  • random_state:乱数のシードを指定できる.

import numpy as np
from sklearn.model_selection import KFold

x = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4]])
y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1])

kf = KFold(n_splits=5)
for train_index, test_index in kf.split(x, y):
    print("train_index:", train_index, "test_index:", test_index)

# output
# train_index: [2 3 4 5 6 7 8 9] test_index: [0 1]
# train_index: [0 1 4 5 6 7 8 9] test_index: [2 3]
# train_index: [0 1 2 3 6 7 8 9] test_index: [4 5]
# train_index: [0 1 2 3 4 5 8 9] test_index: [6 7]
# train_index: [0 1 2 3 4 5 6 7] test_index: [8 9]

StratifiedKFold(層状K分割)

概要

  • 分布に大きな不均衡がある場合に用いるKFold.
  • 分布の比率を維持したままデータを訓練用とテスト用に分割する.

オプション(引数)

  • KFoldと同じ.
  • n_splitがデータ数が最も少ないクラスのデータ数よりも多いと怒られる.

import numpy as np
from sklearn.model_selection import StratifiedKFold

x = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4]])
y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1])

skf = StratifiedKFold(n_splits=2)
for train_index, test_index in skf.split(x, y):
    print("train_index:", train_index, "test_index:", test_index)

# output
# train_index: [1 6 7 8 9] test_index: [0 2 3 4 5]
# train_index: [0 2 3 4 5] test_index: [1 6 7 8 9]

ShuffleSplit(ランダム置換相互検証)

概要

  • 独立した訓練用・テスト用のデータ分割セットを指定した数だけ生成する.
  • データを最初にシャッフルしてから,訓練用とテスト用にデータを分割する.

オプション(引数)

  • n_splits:生成する分割セット数
  • test_size:テストに使うデータの割合(0~1の間で指定)
  • random_state:シャッフルする時の乱数のシード

import numpy as np
from sklearn.model_selection import ShuffleSplit

x = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4], [3, 4]])
y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1])

ss = ShuffleSplit(n_splits=3, test_size=0.25, random_state=0)
for train_index, test_index in ss.split(x, y):
    print("train_index:", train_index, "test_index:", test_index)

# output
# train_index: [9 1 6 7 3 0 5] test_index: [2 8 4]
# train_index: [2 9 8 0 6 7 4] test_index: [3 5 1]
# train_index: [4 5 1 0 6 9 7] test_index: [2 3 8]
167
145
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
167
145