29
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Automatic Mixed Precision (AMP):PyTorchでの学習を高速化

Posted at

AMPを使うとNaNに出くわしてうまく学習できない場合があったので,そのための備忘録と,AMP自体のまとめ.あまり検索してもでてこない注意点があったので,参考になればうれしいです.

Averaged Mixed Precision(AMP)とは

PyTorch中の一部の計算を,通常float32で計算するところ,float16で計算して,高速化するための技術です.そんなことやると性能が下がるのでは?という疑いもありますが,NVIDIAのページとかを見ると,あまり下がらなさそうです.
基本的には,計算も計算結果もfloat16で持ちますが,パラメタはfloat32で持つので,テスト時にfloat16でしか計算できない,ということは起こりません.(全部float16で持つこともできますが,性能の低下が起きます.)

使い方

PyTorchのドキュメントの通りにやればいいです.基本的にGPU用ですが,autocast('cuda')の部分を'cpu'に変えればCPUでもできます.

元のソースコード:

import torch
from modeling import Network
from torch.utils.data import DataLoader, TensorDataset
from torch import nn

model = Network()
optimizer = torch.optim.Adam(model.parameters, lr=args.lr)

dataset = TensorDataset(X, Y)
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
criterion = nn.CrossEntropyLoss()

for epoch in range(args.epoch):
    for x, y in loader:
        optimizer.zero_grad()

        # forward
        y = model(x)
        loss = criterion(y)

        # backward
        loss.backward()

        # 大きすぎる勾配をClip
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)

        # パラメタの更新
        optimizer.step()

...

AMP使用時:

import torch
from modeling import Network
from torch.utils.data import DataLoader, TensorDataset
from torch import nn

model = Network()
optimizer = torch.optim.Adam(model.parameters, lr=args.lr)

dataset = TensorDataset(X, Y)
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
criterion = nn.CrossEntropyLoss()

# scaler for gpu training

scaler = torch.cuda.amp.GradScaler(init_scale=args.init_scale)

for epoch in range(args.epoch):
    for x, y in loader:
        optimizer.zero_grad()
        # forward
        with torch.amp.autocast('cuda', dtype=torch.float16):
            y = model(x)
            loss = criterion(y)

        # backward
        scaler.scale(loss).backward()

        # クリップ時に正しくできるように一度スケールを戻す
        scaler.unscale_(optimizer)
        # 大きすぎる勾配をクリップ
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)

        # パラメタの更新
        scaler.step(optimizer)
        # スケールの更新
        scaler.update()

解説

  • autoscast
    この環境の下で行う計算はAMPで型がキャストされるようになります.ただし,対応している演算が限られていたり,計算の精度上やるとよくないもの(Batch Normalizationとか)は自動でスルーしてくれます.もちろん深層学習で一番出てくる行列積はキャストされます.

  • GradScaler
    勾配をスケール(大きくする)するもので,実はかなり重要なポイントです.具体的には,勾配がアンダーフローしてしまうのを防ぐ役割を持っています.
    float16で表現できる桁数は限られているので,小さい数値はアンダーフローで消えてしまいます.特に深層学習で顕著なのは勾配計算で,誤差逆伝播において連鎖率により勾配は掛け合わされていくので,小さくなりやすいです(勾配消失).表現できる桁数が少ないfloat16では勾配消失の影響がさらに大きくなり,float32で消えなかった勾配がアンダーフローで完全に消えてしまうことがあります.
    これを解決するのがGradScalerで,損失をかなり大きくしてから逆伝播して(scaler.scale(loss).backward()),逆伝播が終わってからもとの勾配の大きさに直してパラメタを更新(scaler.step(optimizer))します.こうすることで,逆伝播の計算時のアンダーフローを防げます.

  • scaler.unscale_
    勾配を直接操作する場合には,勾配の大きさを一度戻してから操作をします.例では勾配のクリッピング(勾配の最大値を決めておいて,超えたら小さくする)を挙げていますが,勾配が大きいままクリッピングすると間違った結果(おそらくすべてクリッピングされる)が出てしまいます.

注意点

学習中にNaNに出くわして,学習がうまくいかないことがしばしばあったので,実は最近まで使っていませんでした.その原因を見つけたので,ここにたどり着く方がいればご参考になれば幸いです.

  • NaNの原因
    結論から言うと,GradScalerにありました.GradScalerで勾配を大きくしすぎたがために,逆伝播中にオーバーフローしたのが原因でした.初期値では,GradScalerは65536をlossにかけるので,もともとの損失が小さくない場合,オーバーフローの原因になります.
  • 直し方
    単純にGradScalerでかける値を小さくしました.私の場合,4096にしたら大丈夫になりました.(2の指数なのは,アンダーフローを防ぐためなのでビットが動けばいいからです.)具体的には,torch.cuda.amp.GradScaler(init_scale=4096)みたいにしました.
  • 私の症状
    もし同じ症状の方が検索した場合に,引っかかるようにその時のエラーメッセージを置いておきます.
    RuntimeError: Function 'MmBackward0' returned nan values in its 0th output.

まとめ

PyTorchでの学習を高速化するAutomatic Mixed Precisionについて解説しました.GradScalerでかけられる値が大きすぎると,アンダーフローを防ぐどころか,逆にオーバーフローするので,気を付けてください.

29
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
29
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?