AdaBelief の謎 eps
最近、新しい Optimizer として AdaBelief が発表されました。
AdaBelief の更新式は以下になります。
m_0, v_0 \gets 0 \\
m_t \gets \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
s_t \gets \beta_2 s_{t-1} + (1 - \beta_1) (g_t - m_t)^2 \\
\hat{m}_t \gets \frac{m_t}{1 - beta_1^t} \\
\hat{s}_t \gets \frac{s_t + \epsilon}{1 - beta_2^t} \\
\theta_t = \theta_{t-1} - \alpha \frac{\alpha \hat{m}_t}{\sqrt{\hat{s}} + \epsilon}
Adam とは異なり、$\hat{s}_t$ に対して、なぜか $\epsilon$ を加えています。
# pytorch の実装を確認すると、$ \hat{s}_t $ ではなく、$ s_t $ に対して $ \epsilon $ を加えていますが...
これはなんで追加されているのだろうかという疑問があります。
一応、論文中に以下の言及がありますが、よくわかりません。
Note that an extra $\epsilon$ is added to $s_t$ during bias-correction, in order to better match the assumption that $s_t$ is bouded below (the lower bound is at leat $\epsilon$).
EAdam
一方 $\epsilon$ の位置を変えるだけで性能が良くなるという論文もでています。
EAdam では $v_t$ に $\epsilon$ を加えます。
eps の影響
そもそも、$\epsilon$ は 0 除算回避のための定数なので影響があっては困るのですが、実際には変更すると性能に非常に大きい影響があります。
以下の論文では通常あまり変更されない Adam の beta や eps も広く探索を行い、SGD以上の性能を出しています。
そこで、eps を入れる位置によってパラメータの更新幅に影響がどのように変わるかプロットしてみます。
通常、移動平均を使いますが、簡単化のためある時点の勾配の大きさにもを使って考えます。
プロットする式は以下の通りです。
Adam っぽい式
\frac{g}{\sqrt{g^2} + \epsilon}
EAdamっぽい式
\frac{g}{\sqrt{(g^2 + \epsilon)}}
Adam の式では、$\epsilon$ として 1e-8の他に1e-4もプロットしています。
これは平方根の外から中に移動する場合に、事前に2乗すると影響が近くなると考えたためです。
1e-8 を使った場合の影響は大きく違います。
実験
Cifar-10 を ResNet-18 で学習した結果を比較します。実験条件は EAdam の論文から拝借しました。
パラメータ | 値 |
---|---|
$\alpha$ | 1e-3 |
$\beta_2$ | 0.9 |
$\beta_1$ | 0.999 |
weight decay | 5e-4 |
学習エポック数 | 200 |
バッチサイズ | 128 |
150エポックで学習率を 0.1 倍にします。
比較は以下の3つで行います。
- AdamW(eps=1e-8)
- EAdam(eps=1e-8)
- AdamW(eps=7.254762501100119e-4)
weight decay はすべて AdamW 的な入れ方をしています。
3番目の条件はステップ幅が 0.9 になる大きさが $\epsilon$ が 1e-8 の EAdam とほぼ同じになるように調整したものです。
テストデータに対する正答率を以下に示します。
最終 epoch の値を以下に示します。
条件 | テストデータ正答率 |
---|---|
Adam(eps=1e-8) | 0.9377 |
EAdam(eps=1e-8) | 0.9326 |
Adam(eps=7.254762501100119e-4) | 0.9363 |
EAdam の論文と異なり EAdam の性能が低いです。
正答率の推移はAdamでは $\epsilon$ を変えても変わらず、EAdam で常に若干低いという傾向があります。
次に、テスト損失と学習データ正答率を示します。
こちらは比較的、EAdam と $\epsilon$ を調整した Adam が似たような動きをしています。
# 学習損失は記録し損ねました
結論
EAdam に変更して性能が向上するという実験の前提が崩れているので語ることがあるのかという感じですが…
以下のことは言えそうです。
- $\epsilon$ を調整しただけでは、Adam と EAdam で同じ挙動になるとはいえない
- $\epsilon$ の挿入位置としてどこが良いのかは、結局広くハイパーパラメータを探索しないとよくわからない
# 何も言ってないに等しい気もしますが…
コード
実験に使ったコードを最後に記載します。
このコードをそのまま実行すると学習損失を記録しないため、注意してください。
import math
import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.models.resnet import resnet18
import pytorch_lightning as pl
class EAdam(torch.optim.Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
# Perform stepweight decay
p.mul_(1 - group['lr'] * group['weight_decay'])
# Perform optimization step
grad = p.grad
if grad.is_sparse:
raise RuntimeError('AdamW does not support sparse gradients')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
# EAdam
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2).add_(group['eps'])
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2))
step_size = group['lr'] / bias_correction1
p.addcdiv_(exp_avg, denom, value=-step_size)
return loss
class LitModel(pl.LightningModule):
def __init__(self, optim_class, eps):
super().__init__()
self.save_hyperparameters()
net = resnet18(num_classes=10)
mods = []
for name, module in net.named_children():
if name == 'conv1':
module = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
if isinstance(module, torch.nn.MaxPool2d):
continue
if isinstance(module, torch.nn.Linear):
mods.append(torch.nn.Flatten(1))
mods.append(module)
self.resnet18 = torch.nn.Sequential(*mods)
self.train_acc = pl.metrics.Accuracy()
self.valid_acc = pl.metrics.Accuracy()
def forward(self, x):
return self.resnet18(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.train_acc(y_hat, y)
self.log('train_acc', self.train_acc, on_step=False, on_epoch=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.valid_acc(y_hat, y)
self.log('valid_loss', loss)
self.log('valid_acc', self.valid_acc, on_step=False, on_epoch=True)
def configure_optimizers(self):
optimizer = self.hparams.optim_class(self.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=self.hparams.eps, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.1)
return [optimizer], [scheduler]
def train_dataloader(self):
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10( root='./data', train=True, download=True, transform=transform_train)
return torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)
def val_dataloader(self):
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
return torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=4)
def main(optim_class, eps):
pl.seed_everything(1234)
model = LitModel(optim_class, eps)
trainer = pl.Trainer(
max_epochs=200,
gpus=1,
benchmark=True,
progress_bar_refresh_rate=0
)
trainer.fit(model)
main(torch.optim.AdamW, 1e-8)
main(EAdam, 1e-8)
main(torch.optim.AdamW, 0.0007254762501100119)