Python
機械学習
DeepLearning
Chainer
Optuna

賢いパラメータ探索: Optuna入門 with Chainer

Preferred Networksのパラメータ探索ライブラリOptunaを使ってみたら便利だったので。

はじめに

機械学習を上手く使うには、ハイパーパラメータの設定が重要になります。
しかし、手動でいちいちパラメータを探すのは膨大な労力を必要とします。
また、グリッドサーチにより探索できるパラメータ数にも限界があります。

賢く; すなわち短時間で良質なパラメータを発見するための自動化を導入します。
Optunaを使うことで、簡単に自動化されたパラメータ探索ができます。

Optunaとは

ハイパーパラメータの最適化フレームワーク

主要機能:
1. Define-by-Run
既存のフレームワークは検索空間と目的関数を別々に定義します。 Optunaは、サーチスペースは目的関数内で定義され、すべてのハイパーパラメータは実行時に定義されます。この機能により、Optunaで書かれたコードはよりモジュール化され修正しやすくなります。
2. Parallel distributed optimization
Optunaは、線形のスケーラビリティで並列化することができます。並列化を設定するには、ユーザーは複数の最適化プロセスを実行するだけで、Optunaは自動的に検証結果をバックグラウンドで共有します。
3. Pruning of unpromising trials
枝狩り機能は、トレーニングの初期段階で自動的に見込みのない試験を停止します。 Optunaは枝狩りを簡単に実装するためのインターフェースを提供します。
(出典: https://optuna.org)

要約すると、1.賢いパラメータ探索が簡単に導入できて、2.複数サーバーでの分散探索が簡単にスケールして、3.Chainerの学習中に見込みの薄い探索を自動で中断してくれるフレームワークです。読みやすくシンプルなAPIなので、Chainer以外のあらゆるパラメータ探索にも使うことができます。

サンプルプログラム

全文はこちらのGistに公開しています。

tune-mnist.py(一部抜粋)
import optuna
from optuna.integration import ChainerPruningExtension


# ここで探索空間を定義しつつ、探索自体も記述します
def objective(trial):
    # 学習率だけでなく、バッチサイズやモデル自体の定義も探索できます
    n_mid_units = trial.suggest_int('n_mid_units', 10, 1000)
    batchsize = trial.suggest_int('batchsize', 16, 512)
    lr = trial.suggest_loguniform('lr', 1e-5, 1e-1)
    momentum = trial.suggest_uniform('momentum', 0, 1)

    # Chainerの学習記述を分離して、trainerのみを管理すると便利です
    trainer = get_trainer(gpu, n_mid_units, batchsize, lr, momentum)

    # ここで枝狩りの処理を挿入しています
    # この例ではvalidation/main/lossを1epochごとに監視します
    trainer.extend(ChainerPruningExtension(
        trial, 'validation/main/loss', (1, 'epoch')))

    trainer.run()
    accuracy = trainer.observation['validation/main/accuracy']

    # 最終的な評価指数は1-accuracyすなわち誤差です
    return 1.0 - float(accuracy)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--gpu', '-g', type=int, default=0)
    args = parser.parse_args()

    gpu = args.gpu  # note: global variable

    # 枝狩りの開始条件を2epoch後に設定します
    pruner = optuna.pruners.MedianPruner(n_warmup_steps=2)  # after 2epoch

    # sqliteが作成されます
    # dbをファイルを共有した、複数のサーバーで協調して探索を行えます
    study = optuna.study.create_study(storage='sqlite:///example.db',
        pruner=pruner, study_name='mnist', load_if_exists=True)
    study.optimize(objective, n_trials=100)

試してみる

$ pip install chainer numpy cupy optuna
$ git clone https://gist.github.com/21b51db70252b7415d46069e464b2d65.git optuna
$ cd optuna
$ python tune-mnist.py

学習を進めるとたまに発生するTracebackがエラーのように見えるかもしれません。これは、枝狩りが成功した時の正しい挙動です。

まとめ

Chainer使ってるならOptuna便利そうです。
他に便利な機能などあればコメントで教えてください!