2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

ランダムにGroupKFoldする

Posted at

GroupKFoldには、他のKFold系と違ってrandom_stateを指定する引数がないが、アンサンブルなどでランダムにGroupKFoldしたいことがある。

環境
import numpy as np
from sklearn.model_selection import GroupKFold, KFold

NG例:シャッフルして渡す

さっそくダメな例から。

NG例
gkf = GroupKFold(n_splits=3)
group = np.array([1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,6,6,6,6])
for i in range(3):
    print(i, "回目", group)
    for t,v in gkf.split(group,group, group):
        print(set(group[t]), set(group[v])) 
    np.random.shuffle(group)

0 回目 [1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6]
{1, 2, 4, 5} {3, 6}
{1, 3, 4, 6} {2, 5}
{2, 3, 5, 6} {1, 4}
1 回目 [5 4 3 4 4 6 1 6 1 3 6 5 4 2 1 6 5 2 3 3 2 2 5 1]
{1, 2, 4, 5} {3, 6}
{1, 3, 4, 6} {2, 5}
{2, 3, 5, 6} {1, 4}
2 回目 [1 2 6 3 1 2 3 1 6 3 1 2 4 3 6 5 5 4 5 4 4 2 5 6]
{1, 2, 4, 5} {3, 6}
{1, 3, 4, 6} {2, 5}
{2, 3, 5, 6} {1, 4}

渡すラベルの順序が異なるのに、全部同じ結果が返ってきた。
ソース読んでないが、Foldサイズをなるべく均等に分けようとするので、各ラベルの個数を数えるときに内部でソートされるのだと思われる。
今回はどのラベルも同じサイズにもかかわらず。

OK例1:ランダムにマッピングして渡す

GroupKFoldに渡すときだけ、別の値にマッピングして渡すようにする。
ループごとで毎回そのマッピングを変えることで、毎回異なる結果を得る。

OK例
gkf = GroupKFold(n_splits=3)
group = np.array([1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,6,6,6,6])
mapping = np.arange(group.max()+1)
for i in range(3):
    print(i, "回目", mapping[group])
    for t,v in gkf.split(group,group, mapping[group]):
        print(set(group[t]), set(group[v])) 
    np.random.shuffle(mapping)

0 回目 [1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6]
{1, 2, 4, 5} {3, 6}
{1, 3, 4, 6} {2, 5}
{2, 3, 5, 6} {1, 4}
1 回目 [0 0 0 0 4 4 4 4 3 3 3 3 6 6 6 6 1 1 1 1 5 5 5 5]
{1, 2, 5, 6} {3, 4}
{1, 2, 3, 4} {5, 6}
{3, 4, 5, 6} {1, 2}
2 回目 [1 1 1 1 3 3 3 3 4 4 4 4 2 2 2 2 6 6 6 6 5 5 5 5]
{1, 3, 4, 6} {2, 5}
{1, 2, 3, 5} {4, 6}
{2, 4, 5, 6} {1, 3}

毎回違う結果が返ってきた。
これで十分だが、レベルのサイズに偏りがある場合は、「なるべく均等にしようとする」せいで、結局同じ分け方が返ってくる恐れはある。

OK例2:ラベル集合をKFoldする

ラベルの集合をランダムにKFoldする。

OK例
def my_gkf(group, n_splits, random_state):
    unique_group = np.unique(group)
    kf = KFold(n_splits=n_splits, random_state=random_state, shuffle=True)
    for t,v in kf.split(unique_group, unique_group):
        t_group = unique_group[t]
        v_group = unique_group[v]
        yield np.where(np.isin(group, t_group))[0], np.where(np.isin(group, v_group))[0]

group = np.array([1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,6,6,6,6])
for i in range(3):
    print(i, "回目")
    for t,v in my_gkf(group, 3, i):
        print(set(group[t]), set(group[v])) 

0 回目
{1, 2, 4, 5} {3, 6}
{1, 3, 5, 6} {2, 4}
{2, 3, 4, 6} {1, 5}
1 回目
{1, 4, 5, 6} {2, 3}
{2, 3, 4, 6} {1, 5}
{1, 2, 3, 5} {4, 6}
2 回目
{1, 3, 4, 6} {2, 5}
{1, 2, 5, 6} {3, 4}
{2, 3, 4, 5} {1, 6}

毎回違う結果が返ってきた。
偏りを特に気にせずやる場合はこれがよさそう。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?