search
LoginSignup
11

More than 5 years have passed since last update.

posted at

updated at

ニューラルネットワークのハイパーパラメータをランダムに探索してみる

はじめに

機械学習の中でもDeep Learningはハイパーパラメータの数が多く、精度を高めるためにはパラメータ調整の試行錯誤が必要になります。この最適なモデルを見つけるための試行錯誤を、自動化できないかと思い調べてみたところ、大きく二つの方法があります。

  • グリッドサーチ:総当たりでパラメータを探索する
  • ランダムサーチ:パラメータの組み合わせをランダムに探索する

グリッドサーチについては、下記の記事に詳しく書かれていますので、

この記事では、ランダムサーチについて試してみました。ランダムサーチ自体の解説は、「ゼロから作るDeep Learning」の 6.5.2章にわかりやすく説明されていますので参照ください。

ランダムサーチ

日本語の記事があまり見つからなかったので、こちらの記事を参考に試してみました。
Comparing randomized search and grid search for hyperparameter estimation
(http://scikit-learn.org/stable/auto_examples/model_selection/randomized_search.html)

基本的に実装方法はグリッドサーチと同じですが、何パターンまで調べるか (n_iter_search) の指定が追加になります。

from sklearn.model_selection import RandomizedSearchCV
...
model = KerasClassifier(build_fn=tr_model, verbose=0)
param_grid = dict(activation=activation, 
                  optimizer=optimizer, 
                  dout=dout,
                  init_mode=init_mode,
                  out_dim1=out_dim1, 
                  out_dim2=out_dim2, 
                  out_dim3=out_dim3, 
                  out_dim4=out_dim4, 
                  out_dim5=out_dim5, 
                  nb_epoch=nb_epoch, 
                  batch_size=batch_size) 
n_iter_search = 5
random_search = RandomizedSearchCV(estimator=model, param_distributions=param_grid, 
                                   n_iter=n_iter_search, cv=4, n_jobs=1, verbose=2)
random_result=random_search.fit(train_x, y_train)

結果の表示は、scikit-learnのscore_reportがわかりやすいと思います。

from sklearn.metrics import classification_report
...
score_report(random_search.cv_results_, 10)
Model with rank: 1
Mean validation score: 0.701 (std: 0.011)
Parameters: {'dout': 0.1, 'out_dim4': 10, 'out_dim3': 30, 'nb_epoch': 10, 'out_dim5': 20, 'activation': 'tanh', 'init_mode': 'he_normal', 'batch_size': 500, 'out_dim1': 50, 'optimizer': 'Adagrad', 'out_dim2': 40}

Model with rank: 2
Mean validation score: 0.701 (std: 0.015)
Parameters: {'dout': 0.1, 'out_dim4': 10, 'out_dim3': 30, 'nb_epoch': 10, 'out_dim5': 20, 'activation': 'tanh', 'init_mode': 'he_normal', 'batch_size': 500, 'out_dim1': 50, 'optimizer': 'Adam', 'out_dim2': 40}

Model with rank: 3
Mean validation score: 0.677 (std: 0.039)
Parameters: {'dout': 0.1, 'out_dim4': 10, 'out_dim3': 30, 'nb_epoch': 10, 'out_dim5': 20, 'activation': 'tanh', 'init_mode': 'glorot_uniform', 'batch_size': 500, 'out_dim1': 50, 'optimizer': 'Adam', 'out_dim2': 40}

....

探索した中のベストのモデルで、テストデータセットを使った結果も見ることもできます。

# テストデータ
y_true, y_pred = test_y, random_search.predict(test_x)
print(classification_report(y_true, y_pred))
             precision    recall  f1-score   support

          0       0.81      0.97      0.88       102
          1       0.95      0.99      0.97       109
          2       0.92      0.95      0.93       217
          3       0.54      0.38      0.45       193
          4       0.48      0.85      0.61       214
          5       0.99      0.97      0.98       202
          6       1.00      0.96      0.98       180
          7       0.00      0.00      0.00       196
          8       0.49      1.00      0.66       186
          9       1.00      0.99      0.99       208
         10       0.93      0.99      0.96       184
         11       0.76      0.38      0.51       204
         12       0.78      0.66      0.72       205

avg / total       0.73      0.76      0.73      2400

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
11