Deep learningのOptimizerのまとめを読んだので早見表を作りました
読んだやつ(記事中の図は当論文から引用しています)
- An overview of gradient descent optimization algorithms (Ruder, 2016)
- On Empirical Comparisons of Optimizers for Deep Learning (Choi et al., 2019)
勾配降下法は大きく分けて3種類
区分 | 1itrで使う サンプル |
良い所 | 良くない所 |
---|---|---|---|
Batch Gradient Descent |
全部 | ・エポックを進めれば大域 or 局所最適解に収束 | ・激重 ・大量のメモリ要 |
Stochastic Gradient Descent |
1コ | ・非凸関数でも大域解に行ける可能性 | ・エポックを進めると目的関数が大きく上下する |
Mini-Batch Gradient Descent |
バッチサイズNコ | ・上2つのいいとこ取り | ・バッチサイズNは自分で決める必要がある |
- バッチサイズは全体の1/10位のオーダーで、2の倍数にすることが多い
- 慣例ではMini-Batch Gradient DescentのこともSGD(Stochastic Gradient Descent)と呼ぶ
- 以下もそれに習います
いろんなSGDを一言で紹介
バニラSGD
名前 | 一言で | 良い所 | 良くない所 |
---|---|---|---|
SGD | ・上記の通り | ・上記の通り | ・スケールの大小によっては学習が進まない次元があるかも |
Momentum系(要するに「さっき進んだ方に進みやすい」)
名前 | 一言で | 良い所 | 良くない所 |
---|---|---|---|
Momentum | ・運動量項をもたせる | ・次元のスケール差に左右されにくい | ・鞍点から抜けられない |
NAG | ・未来の情報を与えたMomentum | ・↑より効率良 | ・鞍点から抜けられない |
Adagrad系(要するに「めっちゃ動いた次元は学習率を下げる」)
名前 | 一言で | 良い所 | 良くない所 |
---|---|---|---|
Adagrad | ・勾配二乗和の累積が大きい次元は学習率を下げる | ・スパースデータに向いてる ・学習率は初期値だけ決めればOK |
・学習率が下がり続けるので学習が進まなくなる |
Adadelta | ・改良Adagrad ・勾配二乗和に時間ペナルティ |
・学習率単調減少を抑える ・学習率の初期値を決めなくて良い |
・ハイパラが多い |
RMSprop | Adadeltaとほぼ同じ (同じ時期に別の人が作った) |
同上 | 同上 |
Adam系(AdagradとMomentumのハイブリッド)
名前 | 一言で | 良い所 | 良くない所 |
---|---|---|---|
Adam | ・AdagradとMomentumのハイブリッド | ・いいとこ取り | ・ハイパラが多い |
Adamax | ・変形Adam ・L2の代わりにL∞ノルム |
同上 | 同上 |
NAdam | ・AdagradとNAGのハイブリッド | 同上 | 同上 |
ようするに
- Momentum系は目的関数の勾配に項を足し引きする手法
- Adagrad系は目的関数の学習率スケジューリングを自動化する手法
- だから両立可能(その代わりハイパラは多くなる)
実験の結果
- きちんとハイパラチューニングすればAdam系がつよい(NAdamが一番性能が良かった)
- テキトーにやったらバニラSGD+学習率スケジューリングにも負けることも
- Optimizerを比較するなら、ハイパラの検索スペースも明記してね
↑分散の違う二次元ガウス分布(バニラSGDだとダメだけどMomentumだとうまくいく例)
↑鞍点(MomentumだとダメだけどAdagrad系・Adam系だとうまくいく例)
おまけ : 鞍点のプロットのコード
saddle_point.ipynb
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(1,1,1, projection='3d')
x = np.arange(-1, 1, 0.05)
y = np.arange(-1, 1, 0.05)
x, y = np.meshgrid(x, y)
ax.scatter(x, y, x**2 - y**2, s=2, marker=".")
ax.scatter(0, 0, 0, color="red", s=10, marker="o")
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
plt.savefig("./saddle_point.jpg")
plt.show()