LoginSignup
2
1

More than 1 year has passed since last update.

scikit-learn データセット分割の覚え書き

Last updated at Posted at 2022-09-05

この記事の内容

scikit-learnに付属しているデータセット分割法の覚え書きです。

  • ホールドアウト法
  • K-分割交差検証(KFold)
  • 層化K-分割交差検証(Stratified KFold)
  • グループ付き交差検証(Group KFold)
  • 時系列データ分割(Time Series Split)

データセットのインポート

scikit learnのデータセットとして利用できるiris datasetを使用します。

import pandas as pd
from sklearn.datasets import load_iris

iris = load_iris()
df_iris = pd.DataFrame(data=iris.data,columns=iris.feature_names)

ホールドアウト法

単純にデータセットを訓練データとテストデータに分割します。(デフォルト=7:3)

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(df_iris,iris.target)

#第1引数:データセット、第2引数:データセットの結果ラベル
#追加引数(train_size = 0.6, random_state=0)

K-分割交差検証(KFold)

ホールドアウト法では学習に利用されないデータがある。訓練データとテストデータの組み合わせをK個分作成し、それぞれ学習・予測を行い、モデルの精度を包括的に見る。

from sklearn.model_selection import KFold

n_split=3
kf=KFold(n_split,shuffle=True,random_state=0)
split_ndx=[[] for i in range(n_split)]

for fold,(train_ndx,test_ndx) in enumerate(kf.split(iris.data,iris.target)):
    split_ndx[fold]=[train_ndx,test_ndx]

層化K-分割交差検証(Stratified KFold)

特定の説明変数におけるある要素の度数が極端に低い場合や要素の種類が多い場合に使用します。
スクリーンショット 2022-09-05 19.18.30.png
※目的変数→説明変数(画像は間違えています。)

from sklearn.model_selection import StratifiedKFold

n_split=3
kf = StratifiedKFold(n_split,shuffle=True,random_state=0)
split_ndx=[[] for i in range(n_split)]

for fold,(train_index, test_index) in enumerate(kf.split(iris.data,iris.target)):
    split_ndx[fold]=[train_ndx,test_ndx]

グループ付き交差検証(Group KFold)

学習データとテストデータが特定の説明変数において同じ要素を持たないように分割する。irisデータセットに対してこの分割法が有効とは思いませんが、あくまでも使い方として表示します。
スクリーンショット 2022-09-05 19.26.43.png
※目的変数→説明変数(画像は間違えています。)

from sklearn.model_selection import GroupKFold

n_split=22
kf = GroupKFold(n_split)
split_ndx=[[] for i in range(n_split)]

for fold,(train_index, test_index) in enumerate(kf.split(df_iris[['sepal length (cm)','sepal width (cm)']], iris.target,df_iris['petal width (cm)'])):
    split_ndx[fold]=[train_ndx,test_ndx]

時系列データ分割(TimeSeriesSplit)

データが時系列を持っている場合に使用できます。今回のirisデータセットは時系列を持たないため、あくまでも使用法という名目でコードを示します。
スクリーンショット 2022-09-05 19.41.31.png

from sklearn.model_selection import TimeSeriesSplit

n_split=15
tscv = TimeSeriesSplit(n_split)
split_ndx=[[] for i in range(n_split)]

for fold,(train_index, test_index) in enumerate(tscv.split(X)):
    split_ndx[fold]=[train_ndx,test_ndx]

参考

詳しくはこちらのサイトを参照ください。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1