32
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

半歩ずつ進める機械学習 ~scikit-learn交差検証のデータ分割方法について~

Posted at

##前回まで
前回の投稿で、cross_val_scoreでの交差検証の結果が、それぞれ大幅に異なる原因について
データセットの並びの偏り、または交差検証の際のデータの抜き出し方の偏りに拠る物だ、と仮定しました。
今回はメジャーな3種のデータ分割方法について自分なりに纏めます。

##分割方法
メジャーな方法として以下3種があります

  • Kfold
  • Stratified Kfold
  • Group Kfold

では、それぞれの特徴を個別に書いていきます

##Kfold
特にデータの内容に拘らず、単純に分割する方法です

image.png

実装方法としては以下の様な形です
尚、env_dataには既に506件のデータが入っているとお考えください。

foldMethod.py
from sklearn.model_selection import KFold

kf = KFold(n_splits = 5)
splitter = kf.split(env_data)#env_data分割用ジェネレータ
for train,test in splitter:
    print(test)
"""
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
  90  91  92  93  94  95  96  97  98  99 100 101]...
"""

データそのものではなく、抽出したインデックスを得る事が出来ます。

少し脱線しますが、自分はpythonに慣れていないので、少し困惑したのですが
このsplitterに直接インデックスのデータが入っているワケではありません
これはテスト用・学習用のインデックスを返してくれるジェネレータと呼ばれる関数の様な物が入っています
これはデータではないので、print(splitter[0])とかやっても上手く行きません。(テッキリ配列か何かだと思ってました・・・

ちなみにKFold(n_splits = 5,shuffle = True)と指定すれば
順番ではなく、ランダムにインデックスを抜き出してくれます

foldMethod.py
kf = KFold(n_splits = 5,shuffle = True)
splitter = kf.split(env_data)

for train,test in splitter:
    print(test)
"""
[  0   6  38  48  53  57  63  65  66  69  70  76  83  84  94  98 104 108
 114 116 118 127 136 140 143 146 151 157 159 161 163 173 177 191 201 207
 210 215 218 224 228 231 234 246 250 251 252 254 255 257 258 259 261 267
 268 271 273 276 279 286 288 301 309 320 321 324 329 333 338 354 359 361
 367 369 371 372 374 376 378 386 391 392 396 414 419 420 421 427 433 435
 444 445 461 470 479 480 481 482 484 496 499 500]...
"""

ランダムにインデックスが抽出されていますね。
前述した通り、ランダムでも順番でもデータの内容に拠らず規則的に分割するのがKfoldです

##Stratified Kfold
これは目的変数の偏りがないように、データを分割します
具体的に言いますと、例えば身長体重という説明変数から、男性女性かという目的変数を予測するモデルを作ろうとします。
100件のデータの内、1~50件は男性のデータ、51~100は女性のデータだとします。
これを例えばKfoldで単純に5個に割ると

データ インデックス 男女比
テストデータ 1~20 (男性20:女性0)
学習データ 21~100 男性30:女性50)

といった偏った分け方をされてしまいます。
Stratified Kfoldを使えば、5分割した時に

データ インデックス 男女比
テストデータ 1~10 & 51~60 (男性10:女性10)
学習データ 11~50 & 61~100 男性40:女性40)

といった感じで、目的変数の偏り無くデータを分割する事が出来ます
では実際に実装を見てみます

foldMethod.py
from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits = 5)
splitter = skf.split(x,y)
for train,test in splitter:
    print(test)
#[ 0  1  2  3  4  5  6  7  8  9 50 51 52 53 54 55 56 57 58 59]...

ここではxには説明変数がyには男女を表す1か0が入ってます。
またKFoldと同様にskf = StratifiedKFold(n_splits = 5,shuffle = True)と指定すれば
インデックスは目的変数の比率を保ったまま、ランダムに抽出されます。

ただ、このStratified Kfoldは、分類問題にしか使えません
なので入力される目的変数が**[0 or 1]** や**[0 or 1 or 2]のようなカテゴリカルなデータでなければいけません。
もし
[21.4 , 48.1 , 29.5 ,,,]**みたいなデータを入力すると以下のようなエラーが出ます。
Supported target types are: ('binary', 'multiclass'). Got 'continuous' instead.

##Group Kfold
これはサンプルデータを任意にグルーピングしてデータを分割する方法です。
例えば、説明変数が顔写真、目的変数が男性か女性かという分類問題があったとします。
また同じ人の写真が、表情違い等で複数枚サンプルデータに存在するとします。

:stuck_out_tongue_closed_eyes::grin::frowning2::unamused:

この時にKfoldStratified Kfoldを使用してデータを分割すると

学習データ テストデータ
:stuck_out_tongue_closed_eyes::grin::frowning2: :unamused:

こんな感じで同じ人間の顔写真データが学習用とテスト用のデータに分割されます。
この時、このモデルが全く未知の人間の性別を判別する目的で作られているとすると
学習に使ったデータと同一人物のデータで精度検証を行う事になるので
実際よりも良い結果が出てしまう可能性があります。

そこで、サンプルデータをグループ分けして、目的変数と説明変数に加えてグループを表すデータを追加して
同じグループのデータがバラけないようにするのが、Group Kfoldという分割方法です。
KfoldStratified Kfoldでは説明変数と目的変数だけで良かったですが
今回はデータの種類を表すグループデータが追加で必要になります。

image.png

image.png

foldMethod.py
"""
これがグループを表すデータラベル
Group = [0,0,0,0,・・・9,9,9,9] 0~9が10個ずつ連続で入っています
"""

from sklearn.model_selection import GroupKFold

gkf = GroupKFold(n_splits = 5)
splitter = gkf.split(X,Y,group)

for train,test in splitter:
    print(test)

#抽出されたインデックス
#[40 41 42 43 44 45 46 47 48 49 90 91 92 93 94 95 96 97 98 99]...

40~49はグループラベル4のグループ
90~99はグループラベル9のグループです

ラベル4と9のグループのデータが全てテストデータになっており
同じグループラベルのデータが学習データとテストデータに分かれていない事が分かります。

##まとめ
以上、メジャーだと言われているデータ分割方法3種を学びました。
交差検証はモデルの性能検証なので、より正確にモデルの性能を評価する為にも
サンプルデータに適した分割方法を採用する必要がありそうです。
scikit-learnには他にも複数のデータ分割方法が実装されているようです
公式サイトにリファレンスがあるので、気になる方はどうぞ

##次回へ
今回は交差検証を行う際の、データ分割方法のみにフォーカスしました。
次回は、これを踏まえて現在取り組んでいるのボストン住宅価格を予測するモデル作りに戻ります。

##お願い
機械学習の初心者が、学んだ知識の確認と備忘用に投稿しています。
間違っている部分や、何かお気づきの事がありましたらご指摘頂けますと幸いです
あと、独学者なので「私も今勉強中なんです!」みたいな人がいればコメント貰えると喜びます。

32
18
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
32
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?