39
33

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

【機械学習】パラメータチューニングについて

Posted at

今回はパラメータチューニングについてまとめます。

パラメータチューニングとは

パラメータとは機械学習モデルにおける設定値や制限値のこと(ハイパーパラメータとも)。
モデルが最適解を出せるパラメータを走査し、設定することをパラメータチューニングという。

パラメータチューニングの種類

ここでは代表的なチューニング方法として2種についてまとめます。

グリッドサーチ

あらかじめパラメータの候補値を指定し、その候補パラメータを組み合わせて学習を試行することにより最適なパラメータを走査する方法。

ランダムサーチ

パラメータの設定範囲および試行回数を指定し、指定値範囲内から無作為に抽出したパラメータにより学習を試行することにより最適なパラメータを走査する方法。

パラメータチューニングの実践

sklearnのサンプルデータセットを用いて、パラメータチューニングを実践する。
使用モデルはランダムフォレスト。

ランダムフォレストモデルにおける n_estimators パラメータの初期値は10。
そのパラメータについて、

① 10,20,30,...100 までの数値の中で最も優れたパラメータを使用する(グリッドサーチ)
② 1~100 までの数値からランダムに抽出した10の数値の中で最も優れたパラメータを抽出する(ランダムサーチ)

という2種のチューニングを行う。

パラメータチューニングの実践
import pandas as pd
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RandomizedSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score


# データの準備、および分割
dataset = load_breast_cancer()
X = pd.DataFrame(dataset.data, columns=dataset.feature_names)
y = pd.DataFrame(dataset.target, columns=['target'])
train_x, test_x, train_y, test_y = train_test_split(X, y)


# グリッドサーチ(パラメータ候補指定)用のパラメータ10種
paramG = {'n_estimators':[1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]}
# ランダムサーチ(パラメータ範囲指定)用のパラメータ 1~100
paramR = {'n_estimators':np.arange(100)}

# モデル生成。上から順に、通常のランダムフォレスト、グリッドサーチ・ランダムフォレスト、
# ランダムサーチ・ランダムフォレスト。
RFC_raw  = RandomForestClassifier(random_state=0)
RFC_grid = GridSearchCV(estimator=RandomForestClassifier(random_state=0), param_grid=paramG, \
                        scoring='r2', cv=3)
RFC_rand = RandomizedSearchCV(estimator=RandomForestClassifier(random_state=0), param_distributions=paramR, \
                        scoring='r2', cv=3)

# 各モデルに学習を行わせる。
RFC_raw.fit (train_x, train_y.as_matrix().ravel())
RFC_grid.fit(train_x, train_y.as_matrix().ravel())
RFC_rand.fit(train_x, train_y.as_matrix().ravel())

# 各モデルの n_estimators パラメータの確認。
print('通常のランダムフォレストモデルにおける n_estimators         :  %d'  %RFC_raw.n_estimators)
print('グリッドサーチ・ランダムフォレストモデルにおける n_estimators   :  %d'  %RFC_grid.best_estimator_.n_estimators)
print('ランダムサーチ・ランダムフォレストモデルにおける n_estimators  :  %d'  %RFC_rand.best_estimator_.n_estimators)

###出力結果###
通常のランダムフォレストモデルにおける n_estimators           :  10
グリッドサーチランダムフォレストモデルにおける n_estimators     :  100
ランダムサーチランダムフォレストモデルにおける n_estimators    :  66

# 各モデルにより算出される予測値
print('通常のランダムフォレストモデルによる予測値         :  %.3f'  %r2_score(test_y, RFC_raw.predict(test_x)))
print('グリッドサーチ・ランダムフォレストモデルによる予測値   :  %.3f'  %r2_score(test_y, RFC_grid.predict(test_x)))
print('ランダムサーチ・ランダムフォレストモデルによる予測値  :  %.3f'  %r2_score(test_y, RFC_rand.predict(test_x)))

###出力結果###
通常のランダムフォレストモデルによる予測値           :  0.762
グリッドサーチランダムフォレストモデルによる予測値     :  0.881
ランダムサーチランダムフォレストモデルによる予測値    :  0.851

上記プログラムにより、グリッドサーチでは n_estimators=100 が最も優れたパラメータだと判定し、ランダムサーチでは n_estimators=66 が最も優れたパラメータだと判定したことがわかる。

このようにパラメータをチューニングすることで、より優れたモデルを生成することができるようになる。

以上。

39
33
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
39
33

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?