# A Simple Grid Search Template Using Scikit-learn


# Grid Search Using Scikit-learn

In this article, we run a simple grid search using scikit-learn (Python).

## Grid Search

If you are wondering what grid search is, please see the following page:

Cross-validation and grid search:
https://qiita.com/Takayoshi_Makabe/items/d35eed0c3064b495a08b

### Libraries Used

```python
import numpy as np  # needed for np.sqrt and np.logspace below

from sklearn.metrics import mean_absolute_error  # MAE
from sklearn.metrics import mean_squared_error  # MSE
from sklearn.metrics import make_scorer

from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import KFold
```

### RMSE

RMSE is not included in the scikit-learn package, so we define the function ourselves.

```python
def rmse(y_true, y_pred):
    # Compute RMSE
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    print('rmse', rmse)
    return rmse
```
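Note that RMSE is an error metric (smaller is better). In the sections below it is wrapped with `make_scorer(rmse, greater_is_better=False)`, so GridSearchCV internally works with the negated score; that is why `best_score_` is printed with a minus sign.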

### K Fold

```python
kf = KFold(n_splits=5, shuffle=True, random_state=0)
```
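The sections below assume that training data `X_trainval` and `y_trainval` already exist. As a minimal sketch, they could be prepared like this; the synthetic dataset and the `X_test`/`y_test` split shown here are placeholders, not part of the original template:

```python
# Minimal sketch (assumption): build placeholder data so that X_trainval and
# y_trainval used below are defined. Replace with your own dataset.
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

X, y = make_regression(n_samples=500, n_features=10, noise=10.0, random_state=0)
X_trainval, X_test, y_trainval, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)
```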

### Linear SVR

```python
from sklearn.svm import LinearSVR

params_cnt = 10
max_iter = 1000

params = {"C": np.logspace(0, 1, params_cnt), "epsilon": np.logspace(-1, 1, params_cnt)}
'''
epsilon : Epsilon parameter in the epsilon-insensitive loss function.
Note that the value of this parameter depends on the scale of the target variable y.
If unsure, set epsilon=0.
C : Regularization parameter.
The strength of the regularization is inversely proportional to C.
Must be strictly positive.
https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVR.html
'''
gridsearch = GridSearchCV(
    LinearSVR(max_iter=max_iter, random_state=0),
    params,
    cv=kf,
    scoring=make_scorer(rmse, greater_is_better=False),
    return_train_score=True,
    n_jobs=-1
)

gridsearch.fit(X_trainval, y_trainval)
print('The best parameter = ', gridsearch.best_params_)
print('RMSE = ', -gridsearch.best_score_)

LSVR = LinearSVR(
    max_iter=max_iter,
    random_state=0,
    C=gridsearch.best_params_["C"],
    epsilon=gridsearch.best_params_["epsilon"]
)
```
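The re-instantiated `LSVR` still needs to be fit before it can be used. As a follow-up sketch (the `X_test`/`y_test` split comes from the placeholder data above and is an assumption):

```python
# Sketch: retrain the tuned model on all of the training data and check the
# hold-out RMSE (X_test / y_test come from the placeholder split above).
LSVR.fit(X_trainval, y_trainval)
print('test RMSE =', rmse(y_test, LSVR.predict(X_test)))
```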

### Kernel SVR

```python
from sklearn.svm import SVR

params_cnt = 10
params = {"kernel": ['rbf'],
          "C": np.logspace(0, 1, params_cnt),
          "epsilon": np.logspace(-1, 1, params_cnt)}

gridsearch = GridSearchCV(
    SVR(gamma='auto'),
    params, cv=kf,
    scoring=make_scorer(rmse, greater_is_better=False),
    n_jobs=-1
)
'''
epsilon : Epsilon parameter in the epsilon-insensitive loss function.
Note that the value of this parameter depends on the scale of the target variable y.
If unsure, set epsilon=0.
C : Regularization parameter.
The strength of the regularization is inversely proportional to C.
Must be strictly positive.
https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVR.html
'''

gridsearch.fit(X_trainval, y_trainval)
print('The best parameter = ', gridsearch.best_params_)
print('RMSE = ', -gridsearch.best_score_)
print()

KSVR = SVR(
    kernel=gridsearch.best_params_['kernel'],
    C=gridsearch.best_params_["C"],
    epsilon=gridsearch.best_params_["epsilon"]
)
```
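As an alternative sketch: because GridSearchCV refits the best model on the whole training set by default (`refit=True`), `best_estimator_` can be used directly instead of re-instantiating SVR with the best parameters.

```python
# Alternative sketch: with the default refit=True, GridSearchCV has already
# refit the best model on all of X_trainval, so it can be reused directly.
KSVR = gridsearch.best_estimator_
```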

### RandomForest

Random forests usually work reasonably well without much hyperparameter tuning, so this part may not add much, but since I wrote it anyway, here it is.

```python
from sklearn.ensemble import RandomForestRegressor

params = {
    "max_depth": [2, 5, 10],
    # Larger n_estimators generally improves accuracy, so increase it when you
    # have time; note that training then takes longer.
    "n_estimators": [10, 20, 30, 40, 50]
}
gridsearch = GridSearchCV(
    RandomForestRegressor(random_state=0),
    params,
    cv=kf,
    scoring=make_scorer(rmse, greater_is_better=False),
    n_jobs=-1
)
'''
n_estimators : The number of trees in the forest.
max_depth : The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain less than min_samples_split samples.
https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html
'''
gridsearch.fit(X_trainval, y_trainval)
print('The best parameter = ', gridsearch.best_params_)
print('RMSE = ', -gridsearch.best_score_)
print()
RF = RandomForestRegressor(
    random_state=0,
    n_estimators=gridsearch.best_params_["n_estimators"],
    max_depth=gridsearch.best_params_["max_depth"]
)
```
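To see more than just the best combination, the full grid can be inspected via `cv_results_`. The following sketch assumes pandas is available (it is not used elsewhere in this article):

```python
# Sketch: view every tried parameter combination and its mean CV score.
import pandas as pd

results = pd.DataFrame(gridsearch.cv_results_)
print(results[['param_max_depth', 'param_n_estimators', 'mean_test_score']])
```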

## Closing

With GridSearchCV, you can tune hyperparameters in just a few lines, which is very convenient.
