LoginSignup
0
1

More than 3 years have passed since last update.

sklearnで交差検証をする時に使うKFoldの備忘録

Posted at

sklearnで交差検証をする時に使うKFoldについての備忘録

import numpy as np
from sklearn.model_selection import KFold

# ###################
# サンプルデータの生成
# ###################

# 乱数の初期値設定
rand = np.random.RandomState(seed=71)
data = np.linspace(0, 1, 10000) # 0~1まで等間隔に10000個の実数を取得

分割器を作る。ここでは、シャッフルした上で4つに分割する分割器を作成

kf = KFold(n_splits=4, random_state=71, shuffle=True)

実際に分割するには、

kf.split(data)

とするが、返却されるのが、イテレータブルオブジェクトの為、for文と組み合わせて取得する。

for i in kf.split(data):
    print(i)

# > (array([   0,    1,    2, ..., 9996, 9998, 9999]), array([   6,   10,   11, ..., 9994, 9995, 9997]))
# > (array([   0,    1,    2, ..., 9997, 9998, 9999]), array([   3,    5,   15, ..., 9981, 9985, 9987]))
# > (array([   2,    3,    4, ..., 9995, 9996, 9997]), array([   0,    1,    7, ..., 9992, 9998, 9999]))
# > (array([   0,    1,    3, ..., 9997, 9998, 9999]), array([   2,    4,    8, ..., 9990, 9993, 9996]))

全部で、10000件あるので、2500件ずつデータを4つに分割し、3(7500件):1(2500件) の組み合わせを作っている。(4パターン)

注意

splitで返却されるのは、データのインデックであって、データ自体ではないことに注意が必要です!!
なので、こんな感じにする必要があります。

for i, (train_index, test_index) in enumerate(kf.split(data)):
    train_data, test_data = data[train_index], data[test_index]
    print(train_data)

# output
# > Trainデータ: [0.00000000e+00 1.00010001e-04 2.00020002e-04 ... 9.99699970e-01
# >  9.99899990e-01 1.00000000e+00]
# > Testデータ: [6.00060006e-04 1.00010001e-03 1.10011001e-03 ... 9.99499950e-01
# >  9.99599960e-01 9.99799980e-01]
# > ========================================================
# > Trainデータ: [0.00000000e+00 1.00010001e-04 2.00020002e-04 ... 9.99799980e-01
# >  9.99899990e-01 1.00000000e+00]
# > Testデータ: [3.00030003e-04 5.00050005e-04 1.50015002e-03 ... 9.98199820e-01
# >  9.98599860e-01 9.98799880e-01]
# > ========================================================
# > Trainデータ: [2.00020002e-04 3.00030003e-04 4.00040004e-04 ... 9.99599960e-01
# >  9.99699970e-01 9.99799980e-01]
# > Testデータ: [0.00000000e+00 1.00010001e-04 7.00070007e-04 ... 9.99299930e-01
# >  9.99899990e-01 1.00000000e+00]
# > ========================================================
# > Trainデータ: [0.00000000e+00 1.00010001e-04 3.00030003e-04 ... 9.99799980e-01
# >  9.99899990e-01 1.00000000e+00]
# > Testデータ: [2.00020002e-04 4.00040004e-04 8.00080008e-04 ... 9.99099910e-01
# >  9.99399940e-01 9.99699970e-01]
# > ========================================================

KFoldの備忘録でした。

おしまい

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1