概要
深層学習モデルの大規模な訓練では、訓練の効率を最大化するために、大きなサイズのミニバッチを使用します。しかし、ミニバッチサイズを大きくするといくつかの問題のためにモデルの精度が劣化するという現象が知られています。
本論文が提案する**LARS(Layer-wise Adaptive Rate Scaling)**は、そのような問題に対処するための、もっともポピュラーな方法です。
書誌情報
- You, Yang, Igor Gitman, and Boris Ginsburg. "Large batch training of convolutional networks." arXiv preprint arXiv:1708.03888 (2017).
- https://arxiv.org/abs/1708.03888
Generalization Gap
なぜ、ミニバッチのサイズを大きくすることが問題なのでしょうか? この点を明らかにした論文1では、この問題をgeneralization gapと名付けています。
大きなミニバッチは、損失関数(Training Function)を最小化するために局所的な最適解(Sharp Minimum)への早期の収束をもたらします。一方、十分に小さなミニバッチは、損失関数を最小化するための解をすぐには見つけられず、結果的に広範囲を探索することになるため、大域的な最適解(Flat Minimum)への収束をもたらします。
Flat MinimumではValidationセットでの評価関数(Test Function)もそれなりに良い値になりますが、Sharp Minimumでは著しい劣化が生じることが知られています。
このような現象を、大きなバッチサイズがもたらすgeneralization gapと呼びます。
確率的勾配降下法
本題に入る前に、LARSの基礎となっている確率的勾配降下法について復習しておきます。
深層学習モデルのパラメータ$w$は、確率的勾配降下法(SGD)によって以下のように更新され、訓練されます。
w_{t+1}=w_{t}-\lambda \frac{1}{B} \sum_{i=1}^{B} \nabla L\left(x_{i}, w_{t}\right)
ここで、$B$はバッチサイズを指します。
ところで、少し話はそれますが、$B$を大きくすると先ほど説明したgeneralization gapが生じますが、それ以外にも$B$が大きくなることによる不具合が知られています。
通常、バッチサイズを大きくする目的は、訓練を高速化させることです。しかし、$B$が$k$倍されると、1エポックあたりのパラメータ更新回数は$\frac{1}{k}$になってしまいます。パラメータが収束に至るまでの更新回数が一定であるとすると、単にバッチサイズを$k$倍しただけでは、収束にかかるエポック数も$k$倍となってしまい、高速化という目的は果たせなくなります。
そこで、バッチサイズを大きくして高速な収束を目指す場合は、同時に学習率$\lambda$を大きくする必要があります。バッチサイズを$k$倍するのであれば、学習率も$k$倍する、というlinear LR scalingというテクニックが知られています。
しかし、linear LR scalingにも問題があります。特に、訓練の初期に学習率を大きくしすぎると、パラメータが発散してしまい、うまくモデルの訓練が進まないということが知られています。この問題を回避するためのテクニックとして、訓練の初期は学習率を低めにし、徐々に学習率を大きくしていくというLR warm-upが知られています。
この2つのテクニックを組み合わせた学習率のスケジューリングはLARSとよく組み合わされて使用されます。
LARSの理論
本題に戻ります。SGDの更新則のバッチサイズに関わらない部分だけ抜き出してみますと、以下のように書くことができます。
w_{t+1}=w_{t}-\lambda\nabla L\left(w_{t}\right)
ここで、訓練の初期にパラメータ$w$の大きさに対して、$\lambda\nabla L\left(w_{t}\right)$の大きさが過剰だと、パラメータ$w$が発散してしまってうまく訓練が進まないと考えられます。これは、linear LR scalingを行う際に問題となります。
ひとつのアプローチとして、$\lambda\nabla L\left(w_{t}\right)$を十分小さくするために、先に述べたLR warm-upを用いて、徐々に学習率$\lambda$を大きくしていくことで、初期のパラメータの発散を回避するということが考えられます。
しかし、LARSでは異なるアプローチを採用します。
以下は、AlexNetにBatchNormalizationを併用したモデルを訓練したときの1イテレーション目で、各レイヤーでのパラメータと勾配のL2ノルムがどのような値になるのかを示した表です。これをみてみると、$w$と$\lambda\nabla L\left(w_{t}\right)$のノルムの比である$\frac{\left|w^{l}\right|}{\left|\nabla L\left(w^{l}\right)\right|}$は5.76~1345という幅広い値をとっており、レイヤーごとにだいぶ異なる値になっていることがわかります。
LARSでは、$\frac{\left|w^{l}\right|}{\left|\nabla L\left(w^{l}\right)\right|}$の違いに関わらず、どのレイヤーでも安定した更新が行えるように、レイヤーごとに異なる学習率$\lambda^l$を導入しようというアプローチを採用します。
レイヤーごとのパラメータ$w^l$に対する学習率$\lambda^l$を、パラメータ更新のタイミングでその都度、以下のように決めます。ここで$\eta$はtrustと呼ばれ、パラメータの大きさに対してどの程度まで更新を許すのか、を意味するハイパーパラメータです。
\lambda^{l}=\eta \times \frac{\left\|w^{l}\right\|}{\left\|\nabla L\left(w^{l}\right)\right\|}
このレイヤーごとの学習率$\lambda^l$とグローバルな学習率$\gamma$を組み合わせ、SGDの更新則は、以下のように書き直すことができます。もちろん、グローバルな学習率$\gamma$には、LR warm-up、polynomial decayといったさまざまな学習率スケジューリングを適用することができます。
\triangle w_{t}^{l}=\gamma * \lambda^{l} * \nabla L\left(w_{t}^{l}\right)
以上のように、LARSの導入によって、大きなバッチサイズでも$w$が発散することなく、安定して更新されることが保証されます。
なお、LARSにおけるWeight Decayは、以下のような形で簡単に導入できます。
\lambda^{l}=\eta \times \frac{\left\|w^{l}\right\|}{\left\|\nabla L\left(w^{l}\right)\right\|+\beta *\left\|w^{l}\right\|}
LARSの実装(PyTorch)
2020年11月14日現在、いくつかの実装が公開されています。
いずれの実装も、PyTorch標準のOptimizerを継承またはラッピングすることでLARSを実現しています。
-
Keskar, Nitish Shirish, et al. "On large-batch training for deep learning: Generalization gap and sharp minima." arXiv preprint arXiv:1609.04836 (2016). ↩