Help us understand the problem. What is going on with this article?

Chainer+OptunaでFashion MNISTの正解率を90%以上にしたい

Optunaとは

ハイパーパラメータの自動最適化フレームワークです。
Pythonで利用できます。
Optunaには3つの特徴があります。
① Define by RunスタイルのAPI
② 学習曲線を用いた試行の枝刈り
③ 並列分散最適化
要はコードが簡単に書け、改修もしやすく、速いってことです。
詳しくは公式ページで確認できます。

今回やること

Fashion MNISTのテスト正解率が90%以上となるハイパーパラメータを、Optunaにて探索します。
Qiita記事「ChainerでFashion mnistのテスト精度を90%以上にする」を参考にしました。

実験条件

実験条件の縛りは上記記事と同様です。

  • データセットはFashion MNIST(MNISTの衣類版。サイズ、枚数はMNISTと同じ。訓練50,000枚、テスト10,000枚。)
  • Google Colaboratory上(GPU)で実行
  • エポック数10以下
  • 訓練時間100秒以内
  • ネットワークは全結合層のみ

その他の確定条件は以下のようにします。

  • 活性化関数はRelu関数
  • 5エポック後、学習率に1/10をかける
  • 試行回数は100

自動最適化を行うパラメータは以下のものとします。

  • ネットワークの層数:1~6
  • 各ユニットのノード数:100~4000
  • オプティマイザ:MomentumSGD, Adam
  • 学習率:1e-5~1e-1
  • バッチサイズ:50, 100, 200, 500

また、PRUNER_INTERVALを3とします。
つまり、3エポック毎に学習曲線をチェックし、見込みがなければその試行を中止します。

コード

コードは以下の通りです。
ほぼ、Optunaの公式exampleのコピーです。

import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np
import optuna
from time import time


EPOCH = 10
PRUNER_INTERVAL = 3


def create_model(trial):
    n_layers = trial.suggest_int('n_layers', 1, 6)

    layers = []
    for i in range(n_layers):
        n_units = int(trial.suggest_loguniform(
                'n_units_l{}'.format(i), 100, 4000))
        layers.append(L.Linear(None, n_units))
        layers.append(F.relu)
    layers.append(L.Linear(None, 10))

    return chainer.Sequential(*layers)


def create_optimizer(trial, model):
    optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'MomentumSGD'])
    if optimizer_name == 'Adam':
        adam_alpha = trial.suggest_loguniform('adam_alpha', 1e-5, 1e-1)
        optimizer = chainer.optimizers.Adam(alpha=adam_alpha)
    else:
        momentum_sgd_lr = trial.suggest_loguniform('momentum_sgd_lr', 1e-5, 1e-1)
        optimizer = chainer.optimizers.MomentumSGD(lr=momentum_sgd_lr)

    optimizer.setup(model)
    return optimizer


def objective(trial):
    gpu_id = 0
    model = L.Classifier(create_model(trial))
    model.to_gpu(gpu_id)
    optimizer = create_optimizer(trial, model)

    rng = np.random.RandomState(0)
    batchsize = trial.suggest_categorical(
            'batchsize', [50, 100, 200, 500])
    train, test = chainer.datasets.fashion_mnist.get_fashion_mnist()
    train_iter = chainer.iterators.SerialIterator(train, batchsize)
    test_iter = chainer.iterators.SerialIterator(
            test, batchsize, repeat=False, shuffle=False)

    updater = chainer.training.StandardUpdater(
            train_iter, optimizer, device=gpu_id)
    trainer = chainer.training.Trainer(updater, (EPOCH, 'epoch'))

    optimizer_name = trial.params["optimizer"]
    if optimizer_name == "MomentumSGD":
        trainer.extend(chainer.training.extensions.ExponentialShift('lr', 0.1),
                   trigger=(5, 'epoch'))
    else:
        trainer.extend(chainer.training.extensions.ExponentialShift('alpha', 0.1),
                   trigger=(5, 'epoch'))

    trainer.extend(optuna.integration.ChainerPruningExtension(
            trial, "validation/main/loss", (PRUNER_INTERVAL, "epoch")))

    trainer.extend(chainer.training.extensions.Evaluator(
            test_iter, model, device=gpu_id))
    log_report_extension = chainer.training.extensions.LogReport(log_name=None)
    trainer.extend(
        chainer.training.extensions.PrintReport([
            'epoch', 'main/loss', 'validation/main/loss', 'main/accuracy',
            'validation/main/accuracy'
        ]))
    trainer.extend(log_report_extension)

    trainer.run(show_loop_exception_msg=False)

    log_last = log_report_extension.log[-1]
    for key, value in log_last.items():
        trial.set_user_attr(key, value)

    val_err = 1.0 - log_report_extension.log[-1]['validation/main/accuracy']
    return val_err


if __name__ == "__main__":
    start = time()

    study = optuna.create_study()
    study.optimize(objective, n_trials=100)

    elapsed_time = time() - start
    print("elapsed_time:", elapsed_time)

    print('Number of finished trials: ', len(study.trials))

    print('Best trial:')
    trial = study.best_trial

    print('  Value: ', trial.value)

    print('  Params: ')
    for key, value in trial.params.items():
        print('    {}: {}'.format(key, value))

    print('  User attrs:')
    for key, value in trial.user_attrs.items():
        print('    {}: {}'.format(key, value))

実験結果

ベストパラメータは以下のようになりました。

  • ネットワークの層数:2
  • 各ユニットのノード数:[unit1: 3145, unit2: 3846]
  • オプティマイザ:Adam
  • 学習率:4.156e-4
  • バッチサイズ:100

ベストパラメータでの結果は以下の通りです。

  • 訓練正解率:94.42%
  • テスト正解率:90.40%
  • 学習時間:70秒

無事、条件を満たしつつテスト正解率を90%以上にすることができました。

考察

今回は特徴②の枝刈りを用いました。
100回の試行での総学習時間は2904秒でした。
単純計算で70x100=7,000秒程かかっておかしくないので、枝刈りが非常に時間短縮に役立っていると考えられます。

ただ、オプティマイザは

  • MomentumSGDは、学習は遅いが汎化性能が高い
  • Adamは、学習は速いが汎化性能が低い

という特徴があるので、「MomentumSGDで学習が遅くても最終的に正解率が良くなるもの」も枝刈りされてしまった可能性があります。
そういった意味では、この枝仮はMomentumSGDにとって不利であったかもしれません。
そこらへんの枝刈り問題は、結構難しいらしいです。
(参照:OptunaのPruningが抱える課題

また、バッチサイズをある程度小さくすると正解率が良くなります。
上記のコードとは別に、今回の範囲よりも小さいバッチサイズを候補に入れたときは、正解率がより良くなりました。(バッチサイズが16のとき)
ただし、学習時間が100秒を余裕で超えてしまいました。
バッチサイズを小さくするとオプティマイザによる更新回数が増えるからなんでしょうね。
ここらへんは、時間との相談です。

まとめ

Optunaは、ほんと便利です。
手動ではとてもできないようなハイパーパラメータを探し出してくれます。
ただし、ある程度のハイパーパラメータ調整に関する知識(モデルの大枠の組み方や学習率のスケジューリングなど)は必要です。
モデルやハイパーパラメータの大枠は自分で決めて、Optunaを使って細かい調整をしていくのが良い使い方なんだと思います。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away