LoginSignup
5
11

More than 5 years have passed since last update.

ニューラルネットワークのハイパーパラメータをランダムに探索してみる

Last updated at Posted at 2017-03-05

はじめに

機械学習の中でもDeep Learningはハイパーパラメータの数が多く、精度を高めるためにはパラメータ調整の試行錯誤が必要になります。この最適なモデルを見つけるための試行錯誤を、自動化できないかと思い調べてみたところ、大きく二つの方法があります。

  • グリッドサーチ:総当たりでパラメータを探索する
  • ランダムサーチ:パラメータの組み合わせをランダムに探索する

グリッドサーチについては、下記の記事に詳しく書かれていますので、

この記事では、ランダムサーチについて試してみました。ランダムサーチ自体の解説は、「ゼロから作るDeep Learning」の 6.5.2章にわかりやすく説明されていますので参照ください。

ランダムサーチ

日本語の記事があまり見つからなかったので、こちらの記事を参考に試してみました。
Comparing randomized search and grid search for hyperparameter estimation
(http://scikit-learn.org/stable/auto_examples/model_selection/randomized_search.html)

基本的に実装方法はグリッドサーチと同じですが、何パターンまで調べるか (n_iter_search) の指定が追加になります。

from sklearn.model_selection import RandomizedSearchCV
...
model = KerasClassifier(build_fn=tr_model, verbose=0)
param_grid = dict(activation=activation, 
                  optimizer=optimizer, 
                  dout=dout,
                  init_mode=init_mode,
                  out_dim1=out_dim1, 
                  out_dim2=out_dim2, 
                  out_dim3=out_dim3, 
                  out_dim4=out_dim4, 
                  out_dim5=out_dim5, 
                  nb_epoch=nb_epoch, 
                  batch_size=batch_size) 
n_iter_search = 5
random_search = RandomizedSearchCV(estimator=model, param_distributions=param_grid, 
                                   n_iter=n_iter_search, cv=4, n_jobs=1, verbose=2)
random_result=random_search.fit(train_x, y_train)

結果の表示は、scikit-learnのscore_reportがわかりやすいと思います。

from sklearn.metrics import classification_report
...
score_report(random_search.cv_results_, 10)
Model with rank: 1
Mean validation score: 0.701 (std: 0.011)
Parameters: {'dout': 0.1, 'out_dim4': 10, 'out_dim3': 30, 'nb_epoch': 10, 'out_dim5': 20, 'activation': 'tanh', 'init_mode': 'he_normal', 'batch_size': 500, 'out_dim1': 50, 'optimizer': 'Adagrad', 'out_dim2': 40}

Model with rank: 2
Mean validation score: 0.701 (std: 0.015)
Parameters: {'dout': 0.1, 'out_dim4': 10, 'out_dim3': 30, 'nb_epoch': 10, 'out_dim5': 20, 'activation': 'tanh', 'init_mode': 'he_normal', 'batch_size': 500, 'out_dim1': 50, 'optimizer': 'Adam', 'out_dim2': 40}

Model with rank: 3
Mean validation score: 0.677 (std: 0.039)
Parameters: {'dout': 0.1, 'out_dim4': 10, 'out_dim3': 30, 'nb_epoch': 10, 'out_dim5': 20, 'activation': 'tanh', 'init_mode': 'glorot_uniform', 'batch_size': 500, 'out_dim1': 50, 'optimizer': 'Adam', 'out_dim2': 40}

....

探索した中のベストのモデルで、テストデータセットを使った結果も見ることもできます。

# テストデータ
y_true, y_pred = test_y, random_search.predict(test_x)
print(classification_report(y_true, y_pred))
             precision    recall  f1-score   support

          0       0.81      0.97      0.88       102
          1       0.95      0.99      0.97       109
          2       0.92      0.95      0.93       217
          3       0.54      0.38      0.45       193
          4       0.48      0.85      0.61       214
          5       0.99      0.97      0.98       202
          6       1.00      0.96      0.98       180
          7       0.00      0.00      0.00       196
          8       0.49      1.00      0.66       186
          9       1.00      0.99      0.99       208
         10       0.93      0.99      0.96       184
         11       0.76      0.38      0.51       204
         12       0.78      0.66      0.72       205

avg / total       0.73      0.76      0.73      2400
5
11
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
11