Edited at

賢いパラメータ探索: Optuna入門 with Chainer

Preferred Networksのパラメータ探索ライブラリOptunaを使ってみたら便利だったので。


はじめに

機械学習を上手く使うには、ハイパーパラメータの設定が重要になります。

しかし、手動でいちいちパラメータを探すのは膨大な労力を必要とします。

また、グリッドサーチにより探索できるパラメータ数にも限界があります。

賢く; すなわち短時間で良質なパラメータを発見するための自動化を導入します。

Optunaを使うことで、簡単に自動化されたパラメータ探索ができます。


Optunaとは


ハイパーパラメータの最適化フレームワーク

主要機能:

1. Define-by-Run

既存のフレームワークは検索空間と目的関数を別々に定義します。 Optunaは、サーチスペースは目的関数内で定義され、すべてのハイパーパラメータは実行時に定義されます。この機能により、Optunaで書かれたコードはよりモジュール化され修正しやすくなります。

2. Parallel distributed optimization

Optunaは、線形のスケーラビリティで並列化することができます。並列化を設定するには、ユーザーは複数の最適化プロセスを実行するだけで、Optunaは自動的に検証結果をバックグラウンドで共有します。

3. Pruning of unpromising trials

枝狩り機能は、トレーニングの初期段階で自動的に見込みのない試験を停止します。 Optunaは枝狩りを簡単に実装するためのインターフェースを提供します。

(出典: https://optuna.org)


要約すると、1.賢いパラメータ探索が簡単に導入できて、2.複数サーバーでの分散探索が簡単にスケールして、3.Chainerの学習中に見込みの薄い探索を自動で中断してくれるフレームワークです。読みやすくシンプルなAPIなので、Chainer以外のあらゆるパラメータ探索にも使うことができます。


サンプルプログラム

全文はこちらのGistに公開しています。


tune-mnist.py(一部抜粋)

import optuna

from optuna.integration import ChainerPruningExtension

# ここで探索空間を定義しつつ、探索自体も記述します
def objective(trial):
# 学習率だけでなく、バッチサイズやモデル自体の定義も探索できます
n_mid_units = trial.suggest_int('n_mid_units', 10, 1000)
batchsize = trial.suggest_int('batchsize', 16, 512)
lr = trial.suggest_loguniform('lr', 1e-5, 1e-1)
momentum = trial.suggest_uniform('momentum', 0, 1)

# Chainerの学習記述を分離して、trainerのみを管理すると便利です
trainer = get_trainer(gpu, n_mid_units, batchsize, lr, momentum)

# ここで枝狩りの処理を挿入しています
# この例ではvalidation/main/lossを1epochごとに監視します
trainer.extend(ChainerPruningExtension(
trial, 'validation/main/loss', (1, 'epoch')))

trainer.run()
accuracy = trainer.observation['validation/main/accuracy']

# 最終的な評価指数は1-accuracyすなわち誤差です
return 1.0 - float(accuracy)

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Chainer example: MNIST')
parser.add_argument('--gpu', '-g', type=int, default=0)
args = parser.parse_args()

gpu = args.gpu # note: global variable

# 枝狩りの開始条件を2epoch後に設定します
pruner = optuna.pruners.MedianPruner(n_warmup_steps=2) # after 2epoch

# sqliteが作成されます
# dbをファイルを共有した、複数のサーバーで協調して探索を行えます
study = optuna.study.create_study(storage='sqlite:///example.db',
pruner=pruner, study_name='mnist', load_if_exists=True)
study.optimize(objective, n_trials=100)



試してみる

$ pip install chainer numpy cupy optuna

$ git clone https://gist.github.com/21b51db70252b7415d46069e464b2d65.git optuna
$ cd optuna
$ python tune-mnist.py

学習を進めるとたまに発生するTracebackがエラーのように見えるかもしれません。これは、枝狩りが成功した時の正しい挙動です。


まとめ

Chainer使ってるならOptuna便利そうです。

他に便利な機能などあればコメントで教えてください!