3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

PyTorchのAMPとSAMの併用

Last updated at Posted at 2022-09-18

SAM(Sharpness-Aware Minimization)

SAMについては下記などで詳しく紹介されているのでここでは説明を割愛.

SAMは2回のforward-backwardを必要とするという点から、
性能面としてはよいものではある(と思う)が、実装をなにかといじる必要がでてくる.
今回記述するAMP(AutoMixedPrecision)についても同様である.

AMP(AutoMixedPrecision)

これも公式が一番詳しいとは思うので割愛.

通常はfloat(fp32)であるが、Half(fp16)化によって演算速度アップや省メモリになるためのもの.

AMPとSAMの併用

PyTorchのAMPで利用されるGradScalerはstep(optimizer)を呼出してパラメータを更新する.
しかし、よく利用させていただいているSAMの実装では通常のoptimizer.step()ではなく、
first_step()second_step()を呼び出すようになっているため、
そのままGradScalerに渡すとassertion errorを生じてしまう.

この問題を回避するため下記実装例のような形をとる.
具体的にはoptimizer.stepを差替えるという方法でAMPを使う.

実装例

import torch

class SAM(torch.optim.Optimizer):
    #
    # SAMの実装については https://github.com/davda54/sam などを参照
    #

def train_loop(model, loader, loss_fn, scaler, optim):
    value = 0
    count = 0
    model.train()
    for i, itr in enumerate(loader):
        x, t = itr
        bs = x.shape[0]
        x = x.to(CFG.device)
        t = t.to(CFG.device).long()
        if isinstance(optim, SAM):
            # SAM 1st Path
            optim.step = optim.first_step
            with autocast(enabled=scaler._enabled):
                y = model(x)
                loss = loss_fn(y, t)
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
            # SAM 2nd Path
            optim.zero_grad()
            optim.step = optim.second_step
            with autocast(enabled=scaler._enabled):
                y = model(x)
                l = loss_fn(y, t)
            scaler.scale(l).backward()
            scaler.step(optim)
            scaler.update()
        else:
            # Default Optimizer
            optim.zero_grad()
            with autocast(enabled=scaler._enabled):
                y = model(x)
                loss = loss_fn(y, t)
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        value += bs * loss.item()
        count += bs
    value = value / count if count > 0 else 0
    return value

またコードとしてはOptimizerがSAMを継承しているかをisinstanceを用いて判断する.
SAMもしくはそのスーパークラスであれば2回のbackwardをするパスに入るようにしている.
これによりOptimizerに依存するコードの変更は少なくできるが、
逆にlossなどの変更点が増えてしまうなどのデメリットがある.
もっと良い書き方があるようにも思うが思いつかない....

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?