30
40

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

賢いパラメータ探索: Optuna入門 with Chainer

Last updated at Posted at 2019-01-10

Preferred Networksのパラメータ探索ライブラリOptunaを使ってみたら便利だったので。

はじめに

機械学習を上手く使うには、ハイパーパラメータの設定が重要になります。
しかし、手動でいちいちパラメータを探すのは膨大な労力を必要とします。
また、グリッドサーチにより探索できるパラメータ数にも限界があります。

賢く; すなわち短時間で良質なパラメータを発見するための自動化を導入します。
Optunaを使うことで、簡単に自動化されたパラメータ探索ができます。

Optunaとは

ハイパーパラメータの最適化フレームワーク

主要機能:

  1. Define-by-Run
    既存のフレームワークは検索空間と目的関数を別々に定義します。 Optunaは、サーチスペースは目的関数内で定義され、すべてのハイパーパラメータは実行時に定義されます。この機能により、Optunaで書かれたコードはよりモジュール化され修正しやすくなります。
  2. Parallel distributed optimization
    Optunaは、線形のスケーラビリティで並列化することができます。並列化を設定するには、ユーザーは複数の最適化プロセスを実行するだけで、Optunaは自動的に検証結果をバックグラウンドで共有します。
  3. Pruning of unpromising trials
    枝狩り機能は、トレーニングの初期段階で自動的に見込みのない試験を停止します。 Optunaは枝狩りを簡単に実装するためのインターフェースを提供します。
    (出典: https://optuna.org)

要約すると、1.賢いパラメータ探索が簡単に導入できて、2.複数サーバーでの分散探索が簡単にスケールして、3.Chainerの学習中に見込みの薄い探索を自動で中断してくれるフレームワークです。読みやすくシンプルなAPIなので、Chainer以外のあらゆるパラメータ探索にも使うことができます。

サンプルプログラム

全文はこちらのGistに公開しています。

tune-mnist.py(一部抜粋)
import optuna
from optuna.integration import ChainerPruningExtension


# ここで探索空間を定義しつつ、探索自体も記述します
def objective(trial):
    # 学習率だけでなく、バッチサイズやモデル自体の定義も探索できます
    n_mid_units = trial.suggest_int('n_mid_units', 10, 1000)
    batchsize = trial.suggest_int('batchsize', 16, 512)
    lr = trial.suggest_loguniform('lr', 1e-5, 1e-1)
    momentum = trial.suggest_uniform('momentum', 0, 1)

    # Chainerの学習記述を分離して、trainerのみを管理すると便利です
    trainer = get_trainer(gpu, n_mid_units, batchsize, lr, momentum)

    # ここで枝狩りの処理を挿入しています
    # この例ではvalidation/main/lossを1epochごとに監視します
    trainer.extend(ChainerPruningExtension(
        trial, 'validation/main/loss', (1, 'epoch')))

    trainer.run()
    accuracy = trainer.observation['validation/main/accuracy']

    # 最終的な評価指数は1-accuracyすなわち誤差です
    return 1.0 - float(accuracy)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--gpu', '-g', type=int, default=0)
    args = parser.parse_args()

    gpu = args.gpu  # note: global variable

    # 枝狩りの開始条件を2epoch後に設定します
    pruner = optuna.pruners.MedianPruner(n_warmup_steps=2)  # after 2epoch

    # sqliteが作成されます
    # dbをファイルを共有した、複数のサーバーで協調して探索を行えます
    study = optuna.study.create_study(storage='sqlite:///example.db',
        pruner=pruner, study_name='mnist', load_if_exists=True)
    study.optimize(objective, n_trials=100)

試してみる

$ pip install chainer numpy cupy optuna
$ git clone https://gist.github.com/21b51db70252b7415d46069e464b2d65.git optuna
$ cd optuna
$ python tune-mnist.py

学習を進めるとたまに発生するTracebackがエラーのように見えるかもしれません。これは、枝狩りが成功した時の正しい挙動です。

まとめ

Chainer使ってるならOptuna便利そうです。
他に便利な機能などあればコメントで教えてください!

30
40
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
30
40

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?