LoginSignup
24
20

More than 3 years have passed since last update.

Deep learningのOptimizer早見表

Last updated at Posted at 2020-04-01

Deep learningのOptimizerのまとめを読んだので早見表を作りました

読んだやつ(記事中の図は当論文から引用しています)
- An overview of gradient descent optimization algorithms (Ruder, 2016)
- On Empirical Comparisons of Optimizers for Deep Learning (Choi et al., 2019)


勾配降下法は大きく分けて3種類

区分 1itrで使う
サンプル
良い所 良くない所
Batch
Gradient Descent
全部 ・エポックを進めれば大域 or 局所最適解に収束 ・激重
・大量のメモリ要
Stochastic
Gradient Descent
1コ ・非凸関数でも大域解に行ける可能性 ・エポックを進めると目的関数が大きく上下する
Mini-Batch
Gradient Descent
バッチサイズNコ ・上2つのいいとこ取り ・バッチサイズNは自分で決める必要がある
  • バッチサイズは全体の1/10位のオーダーで、2の倍数にすることが多い
  • 慣例ではMini-Batch Gradient DescentのこともSGD(Stochastic Gradient Descent)と呼ぶ
  • 以下もそれに習います

いろんなSGDを一言で紹介

バニラSGD

名前 一言で 良い所 良くない所
SGD ・上記の通り ・上記の通り ・スケールの大小によっては学習が進まない次元があるかも

Momentum系(要するに「さっき進んだ方に進みやすい」)

名前 一言で 良い所 良くない所
Momentum ・運動量項をもたせる ・次元のスケール差に左右されにくい ・鞍点から抜けられない
NAG ・未来の情報を与えたMomentum ・↑より効率良 ・鞍点から抜けられない

Adagrad系(要するに「めっちゃ動いた次元は学習率を下げる」)

名前 一言で 良い所 良くない所
Adagrad ・勾配二乗和の累積が大きい次元は学習率を下げる ・スパースデータに向いてる
・学習率は初期値だけ決めればOK
・学習率が下がり続けるので学習が進まなくなる
Adadelta ・改良Adagrad
・勾配二乗和に時間ペナルティ
・学習率単調減少を抑える
・学習率の初期値を決めなくて良い
・ハイパラが多い
RMSprop Adadeltaとほぼ同じ
(同じ時期に別の人が作った)
同上 同上

Adam系(AdagradとMomentumのハイブリッド)

名前 一言で 良い所 良くない所
Adam ・AdagradとMomentumのハイブリッド ・いいとこ取り ・ハイパラが多い
Adamax ・変形Adam
・L2の代わりにL∞ノルム
同上 同上
NAdam ・AdagradとNAGのハイブリッド 同上 同上

ようするに
- Momentum系は目的関数の勾配に項を足し引きする手法
- Adagrad系は目的関数の学習率スケジューリングを自動化する手法
- だから両立可能(その代わりハイパラは多くなる)

実験の結果
- きちんとハイパラチューニングすればAdam系がつよい(NAdamが一番性能が良かった)
- テキトーにやったらバニラSGD+学習率スケジューリングにも負けることも
- Optimizerを比較するなら、ハイパラの検索スペースも明記してね


Screenshot from 2020-04-01 21-53-43.png
↑分散の違う二次元ガウス分布(バニラSGDだとダメだけどMomentumだとうまくいく例)

saddle_point.jpg

↑鞍点(MomentumだとダメだけどAdagrad系・Adam系だとうまくいく例)


おまけ : 鞍点のプロットのコード

saddle_point.ipynb

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()
ax = fig.add_subplot(1,1,1, projection='3d')

x = np.arange(-1, 1, 0.05)
y = np.arange(-1, 1, 0.05)
x, y = np.meshgrid(x, y)

ax.scatter(x, y, x**2 - y**2, s=2, marker=".")
ax.scatter(0, 0, 0, color="red", s=10, marker="o")

ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

plt.savefig("./saddle_point.jpg")
plt.show()
24
20
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
24
20