SAM(Sharpness-Aware Minimization)
SAMについては下記などで詳しく紹介されているのでここでは説明を割愛.
SAMは2回のforward-backwardを必要とするという点から、
性能面としてはよいものではある(と思う)が、実装をなにかといじる必要がでてくる.
今回記述するAMP(AutoMixedPrecision)
についても同様である.
AMP(AutoMixedPrecision)
これも公式が一番詳しいとは思うので割愛.
通常はfloat(fp32)であるが、Half(fp16)化によって演算速度アップや省メモリになるためのもの.
AMPとSAMの併用
PyTorchのAMPで利用されるGradScalerはstep(optimizer)
を呼出してパラメータを更新する.
しかし、よく利用させていただいているSAMの実装では通常のoptimizer.step()
ではなく、
first_step()
とsecond_step()
を呼び出すようになっているため、
そのままGradScalerに渡すとassertion errorを生じてしまう.
この問題を回避するため下記実装例のような形をとる.
具体的にはoptimizer.step
を差替えるという方法でAMPを使う.
実装例
import torch
class SAM(torch.optim.Optimizer):
#
# SAMの実装については https://github.com/davda54/sam などを参照
#
def train_loop(model, loader, loss_fn, scaler, optim):
value = 0
count = 0
model.train()
for i, itr in enumerate(loader):
x, t = itr
bs = x.shape[0]
x = x.to(CFG.device)
t = t.to(CFG.device).long()
if isinstance(optim, SAM):
# SAM 1st Path
optim.step = optim.first_step
with autocast(enabled=scaler._enabled):
y = model(x)
loss = loss_fn(y, t)
scaler.scale(loss).backward()
scaler.step(optim)
scaler.update()
# SAM 2nd Path
optim.zero_grad()
optim.step = optim.second_step
with autocast(enabled=scaler._enabled):
y = model(x)
l = loss_fn(y, t)
scaler.scale(l).backward()
scaler.step(optim)
scaler.update()
else:
# Default Optimizer
optim.zero_grad()
with autocast(enabled=scaler._enabled):
y = model(x)
loss = loss_fn(y, t)
scaler.scale(loss).backward()
scaler.step(optim)
scaler.update()
value += bs * loss.item()
count += bs
value = value / count if count > 0 else 0
return value
またコードとしてはOptimizerがSAMを継承しているかをisinstance
を用いて判断する.
SAMもしくはそのスーパークラスであれば2回のbackwardをするパスに入るようにしている.
これによりOptimizerに依存するコードの変更は少なくできるが、
逆にlossなどの変更点が増えてしまうなどのデメリットがある.
もっと良い書き方があるようにも思うが思いつかない....