はじめに
以前、以下の記事で固定学習率 + EWA の性能が良いと述べました。
実際に実行してみたところ、最終的な性能では ScheduleFree よりも良い結果になりました。しかしながら、序盤の性能では ScheduleFree の方が良いものとなりました。
その原因について検討を行い、その原因の一部について改善案を提案します。
pytorch での EWA の実現方法
まず、前提として pytorch で EWA を実現する方法について説明を行います。以下のページを読んでくださいで終わる話にはなりますが、AveragedModel を使うことで簡単に EWA を実現できます。
このページのサンプル通り AveragedModel で torch.optim.swa_utils.get_ema_multi_avg_fn を使うことで EWA となります。
# Compute exponential moving averages of the weights and buffers
ema_model = torch.optim.swa_utils.AveragedModel(model,
torch.optim.swa_utils.get_ema_multi_avg_fn(0.9), use_buffers=True)
序盤の性能低下要因について
Schedule Free の方が序盤の性能が良い原因として以下の要因がありました。
- Warmup start の考慮
- pytorch の EMA の実装の問題
ただし、この2点を改善しても、なお Schedule Free の方が序盤の性能は良かったため、まだ他にも要因がありそうです。
Warmup start の考慮
Schedule Free の x の更新式は以下の通りです。
$$
x_{t+1} = (1 - \frac{1}{t + 1})x_t + \frac{1}{t + 1}z_{t + 1}
$$
しかし、実際にソースコードを確認すると、それぞれのステップの z に対して学習率を使って重み付けをしています。
つまり、Schedule Free では Warmup start を行った際に、序盤の少ないステップ幅で更新した重みの比重を少なくして、加重平均を計算しています。この処理により序盤の性能が向上します。
pytorch の AveragedModel の枠組みでこの仕組みを EWA に実装しようとすると、Warmup 時の学習率を AveragedModel の multi_avg_fn 側に通知する仕組みが必要となり、非常に面倒です。実際に実装したところ、モジュール間の依存が激しいものとなったため、廃棄しました。長く実行すれば序盤の影響はそのうち消えるので、そこまでするのもなと思いまして
pytorch の EWA の実装の問題
次に pytorch の EWA の実装の問題なります。
pytorch の AveragedModel では最初の AveragedModel#update_parameters() の呼び出しの際にパラメータの値を丸ごとコピーします。そして、2回目以降の AveragedModel#update_parameters() の呼び出しで、multi_avg_fn を呼び出して、パラメータを更新します。
しかし、pytorch の torch.optim.swa_utils.get_ema_multi_avg_fn() では単純に引数で与えた減衰率を使ってパラメータを混ぜ合わせているだけなので、最初にコピーしたパラメータの比率が非常に大きなものとなります。学習初期はこの影響が非常に大きく性能が伸びません。
指数移動平均にて最初の要素の加重が大きくならないように補正を行っているものとして Adam が有名です。Adam では以下の処理を行うことで最初の要素の加重を大きくならないようにしています。
- 指数移動平均の初期値を 0 とする
- 最初の処理において、そのまま値をコピーする等の特別な処理を行わない。新しい要素に対しては常に $1 - \beta$ との積を計算する
- 計算した指数移動平均に対して、$\frac{1}{1 - \beta^t}$ を使って補正を行う(t はステップ数)
具体的な式は以下の通りです。
$$
\begin{align}
m_t &= \beta_1 m_{t - 1} + (1 - \beta_1)g_t \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t}
\end{align}
$$
このような補正を EWA でも行えば、序盤の性能の改善が期待できます。
バイアス補正された EWA の実装方法
Adam と同様の補正を行う EWA を実装するにあたり、単純な方法としては補正を行う AveragedModel をゼロから実装するというのが一番単純な方法となります。そうは言っても、できれば AveragedModel はそのままで multi_avg_fn の作成だけで済ませたいところです。
そこで参考になったのがこの実装の Adam です。
通常、Adam では $m_t$, $v_t$ を保持しています。しかし、この実装では $\hat{m}_t$, $\hat{v}_t$ を保持するようになっています。そして、$\beta1$, $\beta2$ を使って $m_t$, $v_t$ を計算するのではなく以下の値を使って計算します。
beta1hat = beta1 * (1 - beta1 ** (step - 1)) / (1 - beta1**step)
beta2hat = beta2 * (1 - beta2 ** (step - 1)) / (1 - beta2**step)
この式だけ見てもよくわからないため、Adam の式から変形を行い、導出してみます。まず、$\hat{m}_t$ から $m_t$ を求めます。これは単に式を逆にするだけです。
$$
\begin{align}
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t} \\
m_t &= (1 - \beta_1^t)\hat{m}_t
\end{align}
$$
次に、以下の Adam の更新式の $m_{t-1}$ を上記の式で $\hat{m}_{t-1}$ に置き換えます。
$$
\begin{align}
m_t &= \beta_1 m_{t - 1} + (1 - \beta_1)g_t \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t}
\end{align}
$$
置き換えたものが以下になります。
$$
\begin{align}
m_t &= \beta_1 (1 - \beta_1^{t-1}) \hat{m}_{t - 1} + (1 - \beta_1)g_t \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t}
\end{align}
$$
さらに $m_t$ を消すと次のようになります。
$$
\begin{align}
m_t &= \frac{
\beta_1 (1 - \beta_1^{t-1}) \hat{m}_{t - 1} + (1 - \beta_1)g_t}{1 - \beta_1^t} \\
&= \frac{\beta_1 (1 - \beta_1^{t-1})}{1 - \beta_1^t} \hat{m} _{t - 1} + \frac{1 - \beta_1}{1 - \beta_1^t}{g_t}
\end{align}
$$
このとき、$\hat{m}_{t-1}$ の係数を $a$、$g_t$ の係数を $b$ とします。
$$
\begin{align}
a &= \frac{\beta_1 (1 - \beta_1^{t-1})}{1 - \beta_1^t} \\
b &= \frac{1 - \beta_1}{1 - \beta_1^t}
\end{align}
$$
$a$ を変形して $b$ を使った式にします。
$$
\begin{align}
a &= \frac{\beta_1 (1 - \beta_1^{t-1})}{1 - \beta_1^t} \\
&= \frac{\beta_1 - \beta_1^t}{1 - \beta_1^t} \\
&= \frac{(\beta_1 - 1) - (\beta_1^t - 1)}{1 - \beta_1^t} \\
&= \frac{- (1 - \beta_1) + (1 - \beta_1^t)}{1 - \beta_1^t} \\
&= - \frac{1 - \beta_1}{1 - \beta_1^t} + 1 \\
&= 1 - b
\end{align}
$$
都合が良いことに $a$ を $1 - b$ で表すことができました。
そのため、$\beta_1$、$\beta_2$ の代わりに $a$ の式を適用したものを使うことで、$\hat{m}_t$、$\hat{v}_t$ を直接計算することができます。
なお、Adam にてこのような実装にして何が嬉しいかといいますと、$\beta_1$、$\beta_2$ を実行中に変更できるようになります。
この手法を EWA でも利用することで EWA の序盤の性能を向上させることができます。
実装は以下の通りです。この実装ではべき乗の計算が少ないため、$a$ ではなく $1 - b$ を利用しています。
from typing import Callable, Union
import torch
from torch import Tensor
from torch.optim.swa_utils import PARAM_LIST
def get_bias_corrected_ewa_multi_avg_fn(decay: float) -> Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None]:
if decay < 0.0 or decay > 1.0:
raise ValueError(
f"Invalid decay value {decay} provided. Please provide a value in [0,1] range."
)
@torch.no_grad()
def ema_update(ema_param_list: PARAM_LIST,
current_param_list: PARAM_LIST,
num_averaged: Union[Tensor, int]) -> None:
decay_hat = 1.0 - (1.0 - decay) / (1.0 - decay**(num_averaged + 1))
# foreach lerp only handles float and complex
if torch.is_floating_point(ema_param_list[0]) or torch.is_complex(ema_param_list[0]):
torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay_hat)
else:
for p_ema, p_model in zip(ema_param_list, current_param_list):
p_ema.copy_(p_ema * decay_hat + p_model * (1 - decay_hat))
return ema_update
シンプルにこの関数を AveragedModel の multi_avg_fn にこの関数の結果を指定することで利用可能です。なお、私は decay として 0.9995 を利用して、すべてのステップで update_parameters() を呼び出して AveragedModel のパラメータの更新を行っています。ハイパーパラメータ探索はしていないのでもっと良い値はありそうですが、十分性能が良いです。
おわりに
pytorch の EWA にて序盤の性能低下要因の説明とその一部の改善案を提示としてバイアス補正済み EWA の実装を示しました。
以上