LoginSignup
28
15

More than 1 year has passed since last update.

sklearnのtrain_test_splitのrandom_stateの必要性

Last updated at Posted at 2020-11-09

train_test_splitとは

機械学習でよく使われる関数としてtrain_test_splitがあります。
これはlistや、numpy.arrayや、pandas.DataFrameなどを、学習用のtrainデータとtestデータに分割してくれる関数です。

random_stateとは

まず、train_test_splitのデフォルトの引数であるshuffle=Trueによってデータを分割する前に、データの行の順番がランダムにされています。そして、random_stateとはこの時のデータのランダムな行の順番を固定する引数です。固定するにはrandom_stateにint型の任意の値を設定します。(0、42など)
random_stateのデフォルトにはNoneが入っており、Noneのままだと、データのランダムな行の順番が固定されず、train_test_splitを再実行すると再実行前と異なるtrainデータ、testデータを返します。

random_stateで行の順番を固定する必要性

testデータが実行する度に異なると、ハイパーパラメータのチューニングが無意味になったり、複数モデルを作って多数決を取る際(アンサンブル法)に予測結果が異様に高い精度を出す現象を引き起こします。

以下に予測結果が異様に高い精度を出す実例を紹介します。

まず実験用のデータフレームを作ります。

import pandas as pd
from sklearn.model_selection import train_test_split
df=pd.DataFrame({"a":[1,2,3,4,5],"b":[1,2,3,4,5]})
print(df)
   a  b
0  1  1
1  2  2
2  3  3
3  4  4
4  5  5

次にtrain_test_splitをします。引数であるtrain_sizeとtest_sizeのどちらか片方にfloat型の任意の値を入れることで行数に対しての割合、int型の任意の値を入れることで行数の内何個であるという個数を指定できます。
片方だけ指定するとtrainデータとtestデータのどちらを指定しているか紛らわしいので、本例のようにデータ数が5個ならばtrain_size=4(個)、test_size=1(個)と両者明示的に指定しておくのがいいかと思います。

train_x, test_x = train_test_split(df, train_size=4, test_size=1)
print(train_x)
print(test_x)
   a  b
1  2  2
2  3  3
0  1  1
3  4  4

random_stateに何も指定しない場合のコード

random_stateに何も指定しない場合(つまりtrain_test_splitのデフォルトであるrandom_state=Noneである場合)train_test_splitを実行する度に違うtrainデータとtestデータを返します。

for i in range(3):
    train_x, test_x = train_test_split(df, train_size=4, test_size=1)
    print()
    print(i,"回目")
    print(train_x)
    print(test_x)
0 回目
   a  b
2  3  3
3  4  4
4  5  5
0  1  1
   a  b
1  2  2

1 回目
   a  b
2  3  3
4  5  5
1  2  2
0  1  1
   a  b
3  4  4

2 回目
   a  b
4  5  5
0  1  1
1  2  2
3  4  4
   a  b
2  3  3

例えば、この3データを3つの異なるモデルにそれぞれ学習させtestデータの予測結果で、多数決を取ろうとする(アンサンブル法)と
asdfasdf.png
この画像のようにモデルごとに学習したtrainデータが異なり、モデルに学習データとして未知の部分がないためtestデータの予測結果の多数決の際に異様に高い精度を出してしまうことが予想されます。これを避けるためにrandom_stateが必要です。

random_stateに42を指定した場合のコード

for i in range(3):
    train_x, test_x= train_test_split(df, train_size=4, test_size=1, random_state=42)
    print()
    print(i,"回目")
    print(train_x)
    print(test_x)
0 回目
   a  b
4  5  5
2  3  3
0  1  1
3  4  4
   a  b
1  2  2

1 回目
   a  b
4  5  5
2  3  3
0  1  1
3  4  4
   a  b
1  2  2

2 回目
   a  b
4  5  5
2  3  3
0  1  1
3  4  4
   a  b
1  2  2

random_stateを42に指定したため、test_xが[2 2]に固定されているように、train_test_splitを再実行しても固定されたランダムなtrainデータが返ってくるので、学習の際に未知のデータを確保でき、予測結果が異様に高い精度となるのを避けることが出来ます。

28
15
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
28
15