4
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Adam の eps について考えてみた

Posted at

AdaBelief の謎 eps

最近、新しい Optimizer として AdaBelief が発表されました。
AdaBelief の更新式は以下になります。

m_0, v_0 \gets 0 \\ 
m_t \gets \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ 
s_t \gets \beta_2 s_{t-1} + (1 - \beta_1) (g_t - m_t)^2 \\ 
\hat{m}_t \gets \frac{m_t}{1 - beta_1^t} \\ 
\hat{s}_t \gets \frac{s_t + \epsilon}{1 - beta_2^t} \\ 
\theta_t = \theta_{t-1}  - \alpha \frac{\alpha \hat{m}_t}{\sqrt{\hat{s}} + \epsilon} 

Adam とは異なり、$\hat{s}_t$ に対して、なぜか $\epsilon$ を加えています。

# pytorch の実装を確認すると、$ \hat{s}_t $ ではなく、$ s_t $ に対して $ \epsilon $ を加えていますが...

これはなんで追加されているのだろうかという疑問があります。

一応、論文中に以下の言及がありますが、よくわかりません。

Note that an extra $\epsilon$ is added to $s_t$ during bias-correction, in order to better match the assumption that $s_t$ is bouded below (the lower bound is at leat $\epsilon$).

EAdam

一方 $\epsilon$ の位置を変えるだけで性能が良くなるという論文もでています。

EAdam では $v_t$ に $\epsilon$ を加えます。

eps の影響

そもそも、$\epsilon$ は 0 除算回避のための定数なので影響があっては困るのですが、実際には変更すると性能に非常に大きい影響があります。

以下の論文では通常あまり変更されない Adam の beta や eps も広く探索を行い、SGD以上の性能を出しています。

そこで、eps を入れる位置によってパラメータの更新幅に影響がどのように変わるかプロットしてみます。
通常、移動平均を使いますが、簡単化のためある時点の勾配の大きさにもを使って考えます。

プロットする式は以下の通りです。

Adam っぽい式

\frac{g}{\sqrt{g^2} + \epsilon}

EAdamっぽい式

\frac{g}{\sqrt{(g^2 + \epsilon)}}

eps.png

Adam の式では、$\epsilon$ として 1e-8の他に1e-4もプロットしています。
これは平方根の外から中に移動する場合に、事前に2乗すると影響が近くなると考えたためです。

1e-8 を使った場合の影響は大きく違います。

実験

Cifar-10 を ResNet-18 で学習した結果を比較します。実験条件は EAdam の論文から拝借しました。

パラメータ
$\alpha$ 1e-3
$\beta_2$ 0.9
$\beta_1$ 0.999
weight decay 5e-4
学習エポック数 200
バッチサイズ 128

150エポックで学習率を 0.1 倍にします。

比較は以下の3つで行います。

  1. AdamW(eps=1e-8)
  2. EAdam(eps=1e-8)
  3. AdamW(eps=7.254762501100119e-4)

weight decay はすべて AdamW 的な入れ方をしています。
3番目の条件はステップ幅が 0.9 になる大きさが $\epsilon$ が 1e-8 の EAdam とほぼ同じになるように調整したものです。

eps_y0.9.png

テストデータに対する正答率を以下に示します。

val_acc.png

最終 epoch の値を以下に示します。

条件 テストデータ正答率
Adam(eps=1e-8) 0.9377
EAdam(eps=1e-8) 0.9326
Adam(eps=7.254762501100119e-4) 0.9363

EAdam の論文と異なり EAdam の性能が低いです。
正答率の推移はAdamでは $\epsilon$ を変えても変わらず、EAdam で常に若干低いという傾向があります。

次に、テスト損失と学習データ正答率を示します。

val_loss.png
train_acc.png

こちらは比較的、EAdam と $\epsilon$ を調整した Adam が似たような動きをしています。

# 学習損失は記録し損ねました

結論

EAdam に変更して性能が向上するという実験の前提が崩れているので語ることがあるのかという感じですが…

以下のことは言えそうです。

  • $\epsilon$ を調整しただけでは、Adam と EAdam で同じ挙動になるとはいえない
  • $\epsilon$ の挿入位置としてどこが良いのかは、結局広くハイパーパラメータを探索しないとよくわからない

# 何も言ってないに等しい気もしますが…

コード

実験に使ったコードを最後に記載します。
このコードをそのまま実行すると学習損失を記録しないため、注意してください。

import math
import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.models.resnet import resnet18
import pytorch_lightning as pl

class EAdam(torch.optim.Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)
                
    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
                
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # Perform stepweight decay
                p.mul_(1 - group['lr'] * group['weight_decay'])

                # Perform optimization step
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # EAdam
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2).add_(group['eps'])
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2))

                step_size = group['lr'] / bias_correction1

                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss

class LitModel(pl.LightningModule):
    def __init__(self, optim_class, eps):
        super().__init__()
        self.save_hyperparameters()

        net = resnet18(num_classes=10)
        mods = []
        for name, module in net.named_children():
            if name == 'conv1':
                module = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            if isinstance(module, torch.nn.MaxPool2d):
                continue
            if isinstance(module, torch.nn.Linear):
                mods.append(torch.nn.Flatten(1))
            mods.append(module)

        self.resnet18 = torch.nn.Sequential(*mods)
        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()

    def forward(self, x):
        return self.resnet18(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(y_hat, y)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.valid_acc(y_hat, y)

        self.log('valid_loss', loss)
        self.log('valid_acc', self.valid_acc, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        optimizer = self.hparams.optim_class(self.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=self.hparams.eps, weight_decay=5e-4) 
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.1)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), 
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        trainset = torchvision.datasets.CIFAR10( root='./data', train=True, download=True, transform=transform_train)

        return torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)

    def val_dataloader(self):
        transform_test = transforms.Compose([
            transforms.ToTensor(), 
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
        return torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=4)

def main(optim_class, eps):
    pl.seed_everything(1234)

    model = LitModel(optim_class, eps)
    trainer = pl.Trainer(
        max_epochs=200,
        gpus=1,
        benchmark=True,
        progress_bar_refresh_rate=0
    )
    trainer.fit(model)

main(torch.optim.AdamW, 1e-8)
main(EAdam, 1e-8)
main(torch.optim.AdamW, 0.0007254762501100119)
4
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?