Preferred Networksのパラメータ探索ライブラリOptunaを使ってみたら便利だったので。
はじめに
機械学習を上手く使うには、ハイパーパラメータの設定が重要になります。
しかし、手動でいちいちパラメータを探すのは膨大な労力を必要とします。
また、グリッドサーチにより探索できるパラメータ数にも限界があります。
賢く; すなわち短時間で良質なパラメータを発見するための自動化を導入します。
Optunaを使うことで、簡単に自動化されたパラメータ探索ができます。
Optunaとは
ハイパーパラメータの最適化フレームワーク
主要機能:
- Define-by-Run
既存のフレームワークは検索空間と目的関数を別々に定義します。 Optunaは、サーチスペースは目的関数内で定義され、すべてのハイパーパラメータは実行時に定義されます。この機能により、Optunaで書かれたコードはよりモジュール化され修正しやすくなります。- Parallel distributed optimization
Optunaは、線形のスケーラビリティで並列化することができます。並列化を設定するには、ユーザーは複数の最適化プロセスを実行するだけで、Optunaは自動的に検証結果をバックグラウンドで共有します。- Pruning of unpromising trials
枝狩り機能は、トレーニングの初期段階で自動的に見込みのない試験を停止します。 Optunaは枝狩りを簡単に実装するためのインターフェースを提供します。
(出典: https://optuna.org)
要約すると、1.賢いパラメータ探索が簡単に導入できて、2.複数サーバーでの分散探索が簡単にスケールして、3.Chainerの学習中に見込みの薄い探索を自動で中断してくれるフレームワークです。読みやすくシンプルなAPIなので、Chainer以外のあらゆるパラメータ探索にも使うことができます。
サンプルプログラム
全文はこちらのGistに公開しています。
tune-mnist.py(一部抜粋)
import optuna
from optuna.integration import ChainerPruningExtension
# ここで探索空間を定義しつつ、探索自体も記述します
def objective(trial):
# 学習率だけでなく、バッチサイズやモデル自体の定義も探索できます
n_mid_units = trial.suggest_int('n_mid_units', 10, 1000)
batchsize = trial.suggest_int('batchsize', 16, 512)
lr = trial.suggest_loguniform('lr', 1e-5, 1e-1)
momentum = trial.suggest_uniform('momentum', 0, 1)
# Chainerの学習記述を分離して、trainerのみを管理すると便利です
trainer = get_trainer(gpu, n_mid_units, batchsize, lr, momentum)
# ここで枝狩りの処理を挿入しています
# この例ではvalidation/main/lossを1epochごとに監視します
trainer.extend(ChainerPruningExtension(
trial, 'validation/main/loss', (1, 'epoch')))
trainer.run()
accuracy = trainer.observation['validation/main/accuracy']
# 最終的な評価指数は1-accuracyすなわち誤差です
return 1.0 - float(accuracy)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Chainer example: MNIST')
parser.add_argument('--gpu', '-g', type=int, default=0)
args = parser.parse_args()
gpu = args.gpu # note: global variable
# 枝狩りの開始条件を2epoch後に設定します
pruner = optuna.pruners.MedianPruner(n_warmup_steps=2) # after 2epoch
# sqliteが作成されます
# dbをファイルを共有した、複数のサーバーで協調して探索を行えます
study = optuna.study.create_study(storage='sqlite:///example.db',
pruner=pruner, study_name='mnist', load_if_exists=True)
study.optimize(objective, n_trials=100)
試してみる
$ pip install chainer numpy cupy optuna
$ git clone https://gist.github.com/21b51db70252b7415d46069e464b2d65.git optuna
$ cd optuna
$ python tune-mnist.py
学習を進めるとたまに発生するTracebackがエラーのように見えるかもしれません。これは、枝狩りが成功した時の正しい挙動です。
まとめ
Chainer使ってるならOptuna便利そうです。
他に便利な機能などあればコメントで教えてください!