前々から気になっていた問題で、「転移学習をするのにどのオプティマイザーを使うのが良くて、学習率はどのぐらいにしたらいいのか」という点があります。せっかくOptunaという便利なものが出たので、これを使って探してみました。Optunaの枝刈りについても見てみました。
環境:Optuna 0.4.0
問題設定
CIFAR-10を転移学習を行います。MobileNetを使い最後の5レイヤーのみ訓練させました。本当はVGGやInceptionでも調べる予定だったのですが、あまりに時間がかかる(1回50エポック×30試行やるだけでColab TPUで10時間以上かかる)ので、MobileNetだけで力尽きました。コードは載せておくので興味あったらやってみてください。
コード全体:https://gist.github.com/koshian2/497cf82479c6f9d1d92d19d400355705
Optunaのメインコードは以下の通りです。
def optuna_finding(network):
def objective(trial):
# ハイパーパラメータ(オプティマイザーと学習率を調べる)
optimizer = trial.suggest_categorical("optimizer", ["sgd", "momentum", "rmsprop", "adam"])
learning_rate = trial.suggest_loguniform("learning_rate", 1e-7, 1e0)
K.clear_session()
hist = train(network, optimizer, learning_rate, trial)
return 1.0 - np.max(hist["val_acc"])
train(…)は具体的な訓練が入っている関数です。オプティマイザーと学習率をパラメーターとし、
- SGD, Momentum, RMSProp, Adamの4種類
- 学習率は1e-7~1まで
で調べさせます。
Optunaのやっていることは、「試行が与えられたときにパラメーターを設定し、評価値を受け取り、いい感じに最適化する」だけなので、PFN製だからChainer専用というのは決してありません。Chainerで使うともう少しいいことあるかもしれませんが、OptunaがディープラーニングのフレームワークではChainerでしか使えないというのは100%嘘です。Kerasでの枝刈りは以下のようにコールバック(on_epoch_end)に挟むとできます。
class OptunaCallback(Callback):
def __init__(self, trial):
self.trial = trial
def on_epoch_end(self, epoch, logs):
current_val_error = 1.0 - logs["val_acc"]
self.trial.report(current_val_error, step=epoch)
# 打ち切り判定
if self.trial.should_prune(epoch):
raise optuna.structs.TrialPruned()
コールバックの引数に試行のインスタンスを与えましょう。trial.report(current_val_error, step=epoch)
で中間の値をOptunaに報告するだけで、あとはtrial.should_prune(epoch)
の関数を呼び出すだけで、Optunaが打ち切る(枝刈り)すべきか勝手に判定してくれます。
さて、この問題(転移学習のオプティマイザーと学習率)の人間側のインサイトは、あくまで個人的な経験則ですが「RMSPropやMomentumで低めの学習率が良いのではないか」というところです。ただ、あくまでそれは思い込みでしかない可能性も大いにあるので、Optuna先生に思い込みを矯正してもらいましょう。
ちなみに、人力でMobileNetを転移学習すると、固定するレイヤー数やエポック数が異なって正確な比較はできませんが、このコードと同じミニバッチ1024でテスト精度93.37%出ます1。このときの設定はRMSPropで学習率1e-4でした。
試行1~5
個々の試行を見ていきましょう。今回1~5回目では訓練の打ち切りは観測できませんでした。
試行 | ステータス | エラー率 | 学習率 | オプティマイザー |
---|---|---|---|---|
1 | TrialState.COMPLETE | 0.1822 | 4.91E-06 | adam |
2 | TrialState.COMPLETE | 0.4395 | 1.23E-05 | momentum |
3 | TrialState.COMPLETE | 0.2097 | 0.000100423 | momentum |
4 | TrialState.COMPLETE | 0.1336 | 0.001331091 | momentum |
5 | TrialState.COMPLETE | 0.8691 | 4.24E-06 | sgd |
2番目や5番目は人間の目から見たら打ち切ってもいいような気がしますが、まだ学習曲線のデータが集まっていないのか50エポック律儀に訓練させていました。
##試行6~10
ここから途中打ち切りが始まります。「TrialState.PRUNED」と書いてあるのが打ち切られた試行です。
試行 | ステータス | エラー率 | 学習率 | オプティマイザー |
---|---|---|---|---|
6 | TrialState.COMPLETE | 0.1291 | 0.001700487 | momentum |
7 | TrialState.COMPLETE | 0.1723 | 0.000224701 | momentum |
8 | TrialState.PRUNED | 0.9184 | 1.27E-07 | rmsprop |
9 | TrialState.COMPLETE | 0.1279 | 0.001740973 | momentum |
10 | TrialState.PRUNED | 0.8821 | 3.00E-07 | momentum |
学習曲線を書いてみましょう。
打ち切るときはかなり思い切った打ち切りをしていて、2エポックで打ち切っています。イメージ的には、「これはもうダメなパターンだから、さっさと打ち切りなさい」ということでしょうか。こういう思い切った枝刈りするの面白いですね。
追記:Optuna 0.5.0ではもう少し冷静な打ち切りをする(10エポックぐらいまで見る)傾向があります
##試行11~20
ここからは打ち切り大量発生です。11回目~20回目は全部打ち切りです。
試行 | ステータス | エラー率 | 学習率 | オプティマイザー | 打ち切ったエポック |
---|---|---|---|---|---|
11 | TrialState.PRUNED | 0.7235 | 0.162072869 | adam | 2 |
12 | TrialState.PRUNED | 0.8990 | 0.739446454 | adam | 2 |
13 | TrialState.PRUNED | 0.8937 | 6.25E-06 | sgd | 2 |
14 | TrialState.PRUNED | 0.2584 | 0.001552035 | sgd | 16 |
15 | TrialState.PRUNED | 0.8843 | 2.13E-07 | momentum | 2 |
16 | TrialState.PRUNED | 0.8752 | 1.10E-05 | rmsprop | 2 |
17 | TrialState.PRUNED | 0.4584 | 0.523973569 | momentum | 5 |
18 | TrialState.PRUNED | 0.9020 | 2.27E-06 | rmsprop | 2 |
19 | TrialState.PRUNED | 0.5996 | 0.202372759 | rmsprop | 4 |
20 | TrialState.PRUNED | 0.2402 | 0.000164255 | momentum | 20 |
11番目や16番目は設定値だけ見れば、人間的には必ずしも悪くはなさそうなのに、早い段階で打ち切ってしまうのですね。
学習曲線を書いてみると、打ち切ったのもわからなくはないなという感じはします。
##試行21~30
21回目以降も大半は打ち切りですが、若干打ち切りに対して冷静になるようになります。
試行 | ステータス | エラー率 | 学習率 | オプティマイザー | 打ち切ったエポック |
---|---|---|---|---|---|
21 | TrialState.COMPLETE | 0.0997 | 0.936072877 | sgd | - |
22 | TrialState.COMPLETE | 0.1107 | 0.002007988 | rmsprop | - |
23 | TrialState.PRUNED | 0.2123 | 0.23132905 | sgd | 12 |
24 | TrialState.PRUNED | 0.3779 | 0.961493955 | sgd | 4 |
25 | TrialState.PRUNED | 0.5878 | 0.720767242 | sgd | 2 |
26 | TrialState.PRUNED | 0.1949 | 0.874594255 | sgd | 17 |
27 | TrialState.PRUNED | 0.5594 | 0.642876547 | sgd | 2 |
28 | TrialState.PRUNED | 0.2917 | 0.198072291 | sgd | 5 |
29 | TrialState.PRUNED | 0.1984 | 0.347661825 | sgd | 16 |
30 | TrialState.PRUNED | 0.3493 | 0.689101571 | sgd | 4 |
SGDで良い値がでたからか、SGDがお気に入りになってしまいました。エラー率10%切りまでいかなくても、5回目~10回目でMomentumがそこそこ良い結果を残したのはとうに忘れているのではないでしょうか。
どのオプティマイザーがお気に入りになるかは、おそらく試行(偶然引いた良いケースに引きづられる?)によりけりかと思われるので、別に試したら違うオプティマイザーをよく使うということもあるかもしれません。しかし、偶然引いた良い結果に記憶が引きづられるのは人間も同じなので、Optunaが悪いということではないと思います。
学習曲線です。さすがに20回試行すると適切な打ち切りができるようになるのではないでしょうか。この打ち切りはかなり納得しやすいです。
SGDってファインチューニングだとこんなに良かったのですね。系列2がRMS Propで、正直RMS Propの学習率をもう少し下げてみたい感は否めないですが、SGDのほうがより良い解で停止しているのがわかります。これはSGDやRMSPropのアルゴリズムを考えれば当然です。RMSPropのように方向修正しながら進むのは、初期の訓練では訓練が加速しますが、最終的にはSGDの場合は停止するあたりエリアをぐるぐる回ることになるので、RMSPropはこのような副作用もあっても仕方ないと思われます。
SGDってもっと波打つ印象があって、これはバッチサイズが1024と一般的なGPUのケースと比べて大きいバッチサイズでやっているせいもあると思いますが、正直なところ自分の脳内で勝手に「枝刈り」した感が否めません。Optunaが「SGDのほうがいいよ」と偶然ながら提案してきたので、確かにそれはあるのかもなというところでした。
ファインチューニングでもオプティマイザーがSGDなら、学習率が1に近いような高い学習率でも、ちゃんと精度出るのですね。これは自分にとっては意表を突かれました。
ふりかえり
やってみたところ、試行回数20回目以降が本番のようで、最初の5回ぐらいは学習曲線のデータ集めをしているようでした。やはりOptunaをディープラーニングで使いこなすなら、計算リソースが一番のネックになってくるのではないでしょうか。
今回見たように、5回×50エポック分を打ち切りなしの学習曲線集めに使っていると(公式ドキュメントを見ていたら、MedianPrunerのデフォルトの設定で、最初の5回をスタートアップとして使っているとのことでした。この値を変えればもっと早く枝刈りが始まるかもしれません)、計250エポック分をバーンインとして計算しなければいけません。250エポックをバーンインとして要求するのは、Tesla P100が1000個ぐらいあれば一瞬で終わるでしょうが、普通はそんなに計算資源が潤沢ではないので、広い意味で「富豪的プログラミング」という印象は否めません。もちろんOptunaの訓練中断・再開機能を使えば多少はマシになります。
もしそこまで計算資源が潤沢でないのなら、少なくともディープラーニングにおいては、どれが悪いケースでどれが良いケースかというインサイトや視点を導入できたり、学習曲線の”転移学習”ができる人間のほうがまだ分があるのではないかなと思います2。前の記事で書いたように、計算リソースがディープラーニングほどは要求されないであろう、勾配ブースティングではOptunaがめちゃくちゃ強くなるだろうというのはこういう背景です。
Optunaのポイントはデータフレームに書き出したり、SQLiteと連携したり、容易に最適化の過程が可視化できることだと思います。自動でニューラルネットワークを作りたかったら多分AutoMLやNASを使ったほうがいいのではないでしょうか。Optunaは強化学習のアプローチではないはずなので、最適化の過程でみたらまだまだそういうのと比べると弱いと思います。
Optunaをニューラルネットワークのアーキテクチャを探すということにもできますが、それよりもハイパラを探させてその結果を人間が解釈する(つまりこのパラメーターを動かすと、精度がどの程度変わるかという理解、あるいはハイパラの変化に対して解がどのように分布しているかを知る)、つまりOptunaの試行からインサイトを得て、人間が学習するほうが有意義な使い方ができるのではないかと思います。もしろん計算リソースが湯水のように使えれば、ニューラルネットワークといえどリソースの暴力で殴り倒すこともできるでしょう。現にわれわれがRandom ForestやサポートベクターマシンをOptunaで殴り倒しているのがまさにそれです。
-
ハイパラに対しての局所解も疑問あります。例えば今回見たように、学習率とオプティマイザーをとっても、同一のオプティマイザーなら学習率に対して検証曲線がV字型になっても、それを他のオプティマイザーと組み合わせて比較すると、オプティマイザー間で学習率のスケールが異なって局所解が生まれるのではないでしょうか。ディープラーニングでは、理論考えた人が割と好き勝手にハイパラ突っ込んでいるので、へんてこりんなハイパラがあるとそれだけで最適解には行きにくくなるのかなと自分は思います。 ↩