7
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Scikit-learnを用いた簡単なグリッドサーチ テンプレート

Posted at

#Scikit-learnを用いたグリッドサーチ
この記事では、scikit-learn(Python)を用いた簡単なグリッドサーチを行います。
毎回調べるのが面倒なので、テンプレにしました。

##グリッドサーチ
グリッドサーチとは:

今回は、scikit-learnのGridSearchCVを用いてグリッドサーチします。
公式ページ:https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html

という方は、以下のページを参照ください。

クロスバリデーションとグリッドサーチ:
https://qiita.com/Takayoshi_Makabe/items/d35eed0c3064b495a08b

###使用するライブラリ
今回は、回帰問題を想定してグリッドサーチします。

from sklearn.metrics import mean_absolute_error #MAE
from sklearn.metrics import mean_squared_error #MSE
from sklearn.metrics import make_scorer

from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import KFold

###RMSE
RMSEは、scikit-learnのパッケージにないので、自分で関数を定義します。

def rmse(y_true,y_pred):
    #RMSEを算出
    rmse = np.sqrt(mean_squared_error(y_true,y_pred))
    print('rmse',rmse)
    return rmse

###K Fold

kf = KFold(n_splits=5,shuffle=True,random_state=0)

###Linear SVR
線形サポートベクトルを行う場合、SVRを使うより、LinearSVRを使った方が、早いらしい。

from sklearn.svm import LinearSVR

params_cnt = 10
max_iter = 1000

params = {"C":np.logspace(0,1,params_cnt), "epsilon":np.logspace(-1,1,params_cnt)}
'''
epsilon : Epsilon parameter in the epsilon-insensitive loss function.
          Note that the value of this parameter depends on the scale of the target variable y.
          If unsure, set epsilon=0.
C : Regularization parameter.
    The strength of the regularization is inversely proportional to C.
    Must be strictly positive.
https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVR.html
'''
gridsearch = GridSearchCV(
 LinearSVR(max_iter=max_iter,random_state=0),
 params,
 cv=kf,
 scoring=make_scorer(rmse,greater_is_better=False),
 return_train_score=True,
 n_jobs=-1
 )

gridsearch.fit(X_trainval, y_trainval)
print('The best parameter = ',gridsearch.best_params_)
print('RMSE = ',-gridsearch.best_score_)

LSVR = LinearSVR(max_iter=max_iter,random_state=0,C=gridsearch.best_params_["C"], epsilon=gridsearch.best_params_["epsilon"])

###Kernel SVR

from sklearn.svm import SVR

params_cnt = 10
params = {"kernel":['rbf'],
 "C":np.logspace(0,1,params_cnt),
 "epsilon":np.logspace(-1,1,params_cnt)}

gridsearch = GridSearchCV(
 SVR(gamma='auto'),
 params, cv=kf,
 scoring=make_scorer(rmse,greater_is_better=False),
 n_jobs=-1
 )
'''
epsilon : Epsilon parameter in the epsilon-insensitive loss function.
          Note that the value of this parameter depends on the scale of the target variable y.
          If unsure, set epsilon=0.
C : Regularization parameter.
    The strength of the regularization is inversely proportional to C.
    Must be strictly positive.
https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVR.html
'''

gridsearch.fit(X_trainval, y_trainval)
print('The best parameter = ',gridsearch.best_params_)
print('RMSE = ',-gridsearch.best_score_)
print()

KSVR =SVR(
 kernel=gridsearch.best_params_['kernel'],
 C=gridsearch.best_params_["C"],
 epsilon=gridsearch.best_params_["epsilon"]
)

###RandomForest
ランダムフォレストは、あんまりハイパーパラメータ チューニングしなくても良いやつなので、
そんなに意味ないかもしれないが、作っちゃったので、載せます。

from sklearn.ensemble import RandomForestRegressor

params = {
 "max_depth":[2,5,10],
 "n_estimators":[10,20,30,40,50] n_estimatorsは、大きいほど精度が上がるので、時間がある時は、大きくすべき。だたし時間がかかる
 }
gridsearch = GridSearchCV(
 RandomForestRegressor(random_state=0),
 params,
 cv=kf,
 scoring=make_scorer(rmse,greater_is_better=False),
 n_jobs=-1
 )
'''
n_estimators : The number of trees in the forest.
max_depth : The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain less than min_samples_split samples.
https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html
'''
gridsearch.fit(X_trainval, y_trainval)
print('The best parameter = ',gridsearch.best_params_)
print('RMSE = ',-gridsearch.best_score_)
print()
RF = RandomForestRegressor(random_state=0,n_estimators=gridsearch.best_params_["n_estimators"], max_depth=gridsearch.best_params_["max_depth"])

##最後に
GridSearchCVを使えば、数行でチューニングできるので便利です。
今回は、3つのモデルで作りましたが、他のモデルももちろん、できます。

7
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?