3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

train_test_split() で正解ラベルの割合を変えたくない場合は stratify パラメータを使用する

Posted at

結論

train_test_split()を使用する際、stratifyパラメータを使用すると学習データと評価データで正解ラベルの割合が均一となる。

X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)

背景

不均衡データを扱う機会があり、学習データと評価データの正解ラベルの割合を揃えたかった。

ライブラリ情報

項目 情報
Python 3.9.7
sklearn 1.1.3

ソースコード

事前準備

import time
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split


# irisのデータセットを準備
data_iris = load_iris()
X = pd.DataFrame(
    data_iris.data, 
    columns=data_iris.feature_names
)

y = data_iris.target

正解ラベルを確認

u, counts = np.unique(y, return_counts=True)

# ラベルの一覧
print(u)
# [0 1 2]

# 各ラベルの件数
print(counts)
# [50 50 50]

各ラベルの割合を確認する関数を定義

def calc(i, y):
    print(str(i+1) + '回目 (' + str(len(y_train)) + '件)')
    print('・Label 0 : ' + str(np.round((y == 0).sum()/len(y), 2)) + ' %')
    print('・Label 1 : ' + str(np.round((y == 1).sum()/len(y), 2)) + ' %')
    print('・Label 2 : ' + str(np.round((y == 2).sum()/len(y), 2)) + ' %')
    print('---')

train_test_split()デフォルトの場合

for i in range(5):
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    calc(i, y_train)

出力

1回目 (112件)
・Label 0 : 0.36 %
・Label 1 : 0.32 %
・Label 2 : 0.32 %
---
2回目 (112件)
・Label 0 : 0.35 %
・Label 1 : 0.32 %
・Label 2 : 0.33 %
---
3回目 (112件)
・Label 0 : 0.38 %
・Label 1 : 0.34 %
・Label 2 : 0.29 %
---
4回目 (112件)
・Label 0 : 0.37 %
・Label 1 : 0.31 %
・Label 2 : 0.32 %
---
5回目 (112件)
・Label 0 : 0.34 %
・Label 1 : 0.34 %
・Label 2 : 0.32 %
---

正解ラベルの割合にばらつきがある。

train_test_split() stratifyを使用した場合

for i in range(5):
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)
    calc(i, y_train)

出力

1回目 (112件)
・Label 0 : 0.33 %
・Label 1 : 0.34 %
・Label 2 : 0.33 %
---
2回目 (112件)
・Label 0 : 0.33 %
・Label 1 : 0.33 %
・Label 2 : 0.34 %
---
3回目 (112件)
・Label 0 : 0.33 %
・Label 1 : 0.34 %
・Label 2 : 0.33 %
---
4回目 (112件)
・Label 0 : 0.34 %
・Label 1 : 0.33 %
・Label 2 : 0.33 %
---
5回目 (112件)
・Label 0 : 0.34 %
・Label 1 : 0.33 %
・Label 2 : 0.33 %
---

正解ラベルの割合が均一であることが確認できる。

参考サイト

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?