sklearnで交差検証をする時に使うKFoldについての備忘録
import numpy as np
from sklearn.model_selection import KFold
# ###################
# サンプルデータの生成
# ###################
# 乱数の初期値設定
rand = np.random.RandomState(seed=71)
data = np.linspace(0, 1, 10000) # 0~1まで等間隔に10000個の実数を取得
分割器を作る。ここでは、シャッフルした上で4つに分割する分割器を作成
kf = KFold(n_splits=4, random_state=71, shuffle=True)
実際に分割するには、
kf.split(data)
とするが、返却されるのが、イテレータブルオブジェクトの為、for文と組み合わせて取得する。
for i in kf.split(data):
print(i)
# > (array([ 0, 1, 2, ..., 9996, 9998, 9999]), array([ 6, 10, 11, ..., 9994, 9995, 9997]))
# > (array([ 0, 1, 2, ..., 9997, 9998, 9999]), array([ 3, 5, 15, ..., 9981, 9985, 9987]))
# > (array([ 2, 3, 4, ..., 9995, 9996, 9997]), array([ 0, 1, 7, ..., 9992, 9998, 9999]))
# > (array([ 0, 1, 3, ..., 9997, 9998, 9999]), array([ 2, 4, 8, ..., 9990, 9993, 9996]))
全部で、10000件あるので、2500件ずつデータを4つに分割し、3(7500件):1(2500件) の組み合わせを作っている。(4パターン)
注意
splitで返却されるのは、データのインデックであって、データ自体ではないことに注意が必要です!!
なので、こんな感じにする必要があります。
for i, (train_index, test_index) in enumerate(kf.split(data)):
train_data, test_data = data[train_index], data[test_index]
print(train_data)
# output
# > Trainデータ: [0.00000000e+00 1.00010001e-04 2.00020002e-04 ... 9.99699970e-01
# > 9.99899990e-01 1.00000000e+00]
# > Testデータ: [6.00060006e-04 1.00010001e-03 1.10011001e-03 ... 9.99499950e-01
# > 9.99599960e-01 9.99799980e-01]
# > ========================================================
# > Trainデータ: [0.00000000e+00 1.00010001e-04 2.00020002e-04 ... 9.99799980e-01
# > 9.99899990e-01 1.00000000e+00]
# > Testデータ: [3.00030003e-04 5.00050005e-04 1.50015002e-03 ... 9.98199820e-01
# > 9.98599860e-01 9.98799880e-01]
# > ========================================================
# > Trainデータ: [2.00020002e-04 3.00030003e-04 4.00040004e-04 ... 9.99599960e-01
# > 9.99699970e-01 9.99799980e-01]
# > Testデータ: [0.00000000e+00 1.00010001e-04 7.00070007e-04 ... 9.99299930e-01
# > 9.99899990e-01 1.00000000e+00]
# > ========================================================
# > Trainデータ: [0.00000000e+00 1.00010001e-04 3.00030003e-04 ... 9.99799980e-01
# > 9.99899990e-01 1.00000000e+00]
# > Testデータ: [2.00020002e-04 4.00040004e-04 8.00080008e-04 ... 9.99099910e-01
# > 9.99399940e-01 9.99699970e-01]
# > ========================================================
KFoldの備忘録でした。
おしまい