25
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

複数ラベルのstratified CVをやるならiterstratを使おう

Last updated at Posted at 2020-10-16

機械学習モデルは学習データに過学習しがちなので手元のデータを学習データとテストデータに分割して性能評価(validation)するというのをよくやります。この分割のやり方いろいろについてはscikit-learnのvalidationについての解説が分かりやすいです。
https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation

1. scikit-learnのStratifiedKFoldの場合

判別(classification)をやる場合は正解ラベルの分布が同じになるように分割したいので、scikit-learnのStratifiedKFoldをよく使うと思います。

python
import numpy as np
from sklearn.model_selection import StratifiedKFold

X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
y = np.array([0, 0, 1, 1])
skf = StratifiedKFold(n_splits=2)
for train_index, test_index in skf.split(X, y):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
結果
TRAIN: [1 3] TEST: [0 2]
TRAIN: [0 2] TEST: [1 3]

ラベルが1つしかない場合はこれをで充分なんですが、複数ラベルには対応していません。

python
import numpy as np
from sklearn.model_selection import StratifiedKFold

X = np.array([[1,2], [3,4], [1,2], [3,4], [1,2], [3,4], [1,2], [3,4]])
y = np.array([[0,0], [0,0], [0,1], [0,1], [1,1], [1,1], [1,0], [1,0]])
skf = StratifiedKFold(n_splits=2)

for train_index, test_index in skf.split(X, y):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
結果
ValueError: Supported target types are: ('binary', 'multiclass'). Got 'multilabel-indicator' instead.

2. iterstratの場合

iterstratなら複数ラベルに対応してくれます
https://github.com/trent-b/iterative-stratification

インストール

terminal
pip install iterative-stratification

使い方(shuffleしない場合)

python
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
import numpy as np

X = np.array([[1,2], [3,4], [1,2], [3,4], [1,2], [3,4], [1,2], [3,4]])
y = np.array([[0,0], [0,0], [0,1], [0,1], [1,1], [1,1], [1,0], [1,0]])

mskf = MultilabelStratifiedKFold(n_splits=2, shuffle=True, random_state=0)

for train_index, test_index in mskf.split(X, y):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
結果
TRAIN: [0 3 4 6] TEST: [1 2 5 7]
TRAIN: [1 2 5 7] TEST: [0 3 4 6]

使い方(shuffleする場合)

python
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
import numpy as np

X = np.array([[1,2], [3,4], [1,2], [3,4], [1,2], [3,4], [1,2], [3,4]])
y = np.array([[0,0], [0,0], [0,1], [0,1], [1,1], [1,1], [1,0], [1,0]])

msss = MultilabelStratifiedShuffleSplit(n_splits=3, test_size=0.5, random_state=0)

for train_index, test_index in msss.split(X, y):
   print("TRAIN:", train_index, "TEST:", test_index)
   X_train, X_test = X[train_index], X[test_index]
   y_train, y_test = y[train_index], y[test_index]
結果
TRAIN: [1 2 5 7] TEST: [0 3 4 6]
TRAIN: [2 3 6 7] TEST: [0 1 4 5]
TRAIN: [1 2 5 6] TEST: [0 3 4 7]

使い方(bootstrapする場合)

python
from iterstrat.ml_stratifiers import RepeatedMultilabelStratifiedKFold
import numpy as np

X = np.array([[1,2], [3,4], [1,2], [3,4], [1,2], [3,4], [1,2], [3,4]])
y = np.array([[0,0], [0,0], [0,1], [0,1], [1,1], [1,1], [1,0], [1,0]])

rmskf = RepeatedMultilabelStratifiedKFold(n_splits=2, n_repeats=2, random_state=0)

for train_index, test_index in rmskf.split(X, y):
   print("TRAIN:", train_index, "TEST:", test_index)
   X_train, X_test = X[train_index], X[test_index]
   y_train, y_test = y[train_index], y[test_index]
結果
TRAIN: [0 3 4 6] TEST: [1 2 5 7]
TRAIN: [1 2 5 7] TEST: [0 3 4 6]
TRAIN: [0 1 4 5] TEST: [2 3 6 7]
TRAIN: [2 3 6 7] TEST: [0 1 4 5]
25
15
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
25
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?