最初からクライマックス
公式に書かれている利用方法は下.でもモデルが数値的に不安定な計算(Softmax, division by epsilon...)を含んでるといつかnanが出る.ちなみにTransformer, AttentionはSoftmaxを含んでいるので不安定.
# 公式の例.ダメではないがnanが出て死ぬことがある.
with torch.autocast():
# モデル計算とか誤差逆伝搬とか
...
AMP(torch.autocast)を使わなかったらnanが出ないのにな〜 という人は下のようにすると多分nanが出なくなる.
# nanがかなり出にくい.
with torch.autocast(dtype=torch.bfloat16):
# モデル計算とか誤差逆伝搬とか
...
以上.nanが出なくなったらおめでとう.
文章を読む人が嫌いな人への擬似コード
while nan appears without amp(autocast):
Don't use amp! Clean up your code!
try:
try:
use "torch.autocast()"
except nan appears:
use "torch.autocast(dtype=torch.bfloat16)" instead of "torch.autocast()"
except nan appears:
Sorry... please publish a new issue and run without autocast...
Device warn
bf16はAmpereアーキテクチャ以降の実装であることに注意してください。
For debug
torch.autocast(enabled=False)と設定すると簡単にamp機能をオフにできます.
ここからは蛇足
AMP公式情報
AMP(AutomaticMixedPrecision)についてはPyTorch公式ドキュメントとPyTorch公式サンプル例に詳しい内容はほぼ書いてあります.
ただしtorch.bfloat16を使うと良いとはどこにも書いてないので注意.
Advanced
Gradient Clippingを併用することでよりnanが出にくくなります.
公式以外の情報
なぜtorch.bfloat16を使うと良いのか
この問いには2つ理解すべきことがある.
- fp32, fp16, bf16の違い
- autocast中のnanの発生原因
1について,PyTorchデフォルトのfp32とautocastのデフォルトのfp16,そして今回のキーであるbf16の違いを分かりやすくまとめてくれた図が下.(tf32のことは気にしないで.)
浮動小数点は符号(sign)と指数部(range)と仮数部(precision)の3つが組み合わさって表現されている.
fp32と比べfp16は指数部も仮数部もbit数が減っており, 表現できる値域も狭まり 解像度も低下する.
一方bf16は仮数部は大きく減るものの指数部は減っておらず,値域はそのまま に解像度だけ低下している.
これがfp32, fp16, bf16の違いである.
次に2についてだが,ずばりautocast中でのnanの発生原因はautocastのデフォルトのfp16の値域が狭まくoverflowするためである.
例えばfp32での表現可能最大値はおよそ3.40×10^38だが,fp16の最大値は65504である.そのためfp32では十分計算可能な値でも65504を越える値はfp16では取り扱うことができずoverflowによるnanとなってしまうのだ.一方でbf16の表現可能最大値はほぼfp32と同じであるためoverflow問題はほぼ起きない.
torch.bfloat16を使うと代わりに何か問題が起きるのでは?
一般にunderflow問題が起きます.つまりは0付近の小さな値の取り扱いが苦手でほんの少しだけパラメータ更新することができなくなります.しかしながらこの問題はPyTorch公式が既に予測済みで,GradScalerを使うことで回避できます.公式サンプルにも堂々とAMPと一緒に書かれているのでセットで使うことが最初から想定されています.
そもそもなぜAMPを使うのか
AMPを使うとデフォルトのfp32からfp16 or bf16にモデルパラメータが変更され,GPU使用メモリ量が半分になるためbatchサイズを約2倍にでき,32->16bit計算になるので計算速度も向上するため.
追記
公式にbf16を使うと良いぞということがほんの少しだけ書いてありました.
Figure out by experimentation if your network is sensitive to range and/or precision of a format. For example fine-tuning bfloat16-pretrained models in float16 can easily run into range issues in float16 because of the potentially large range from training in bfloat16, so users should stick with bfloat16 fine-tuning if the model was trained in bfloat16.
from pytorch blog post
またredditに本記事と同様の意見が書かれてありました.
さらに言えば,多くのgoogle製モデルはbf16を利用しているみたいです.
BFloat16 offers better stability during training than FP16. Most google models are BFloat16 due to using TPUs, where BF16 is native. We're seeing more LLMs trained in BFloat16 out of superior stability (see the BigScience project by HuggingFace who noted better stability).