はじめに
GPU で学習スクリプトを書いていると、precision="bf16" や dtype=torch.float16 みたいなオプションが出てきて「結局何が違うの…?」となりがちです。
さらに最近は、強化学習 (RL) での LLM ファインチューニングに関して
BF16 precision が学習時と推論時のミスマッチを生んで不安定になる、だから FP16 にすると安定する
みたいな話も出てきていて、ますます混乱します。(arXiv)
この記事では
- FP32 / FP16 / BF16 の違い
- BF16 がなぜ流行っているのか
- さっきの RL 論文が何を言っているのか
- 実務で precision をどう選べばよいか
あたりを、GPU で学習するエンジニア目線でまとめてみます。
まずは登場人物:FP32 / FP16 / BF16
深層学習でよく出てくるフォーマットを以下に示します。
-
FP32 (float32, 単精度)
32 bit 浮動小数点。ほとんどのフレームワークのデフォルト。精度もレンジも十分だが、メモリと計算が重い -
FP16 (float16, IEEE half, 半精度)
16 bit 浮動小数点。1 bit 符号 + 5 bit 指数 + 10 bit 仮数。メモリは FP32 の半分で速いが、ダイナミックレンジが狭くアンダーフローしやすい(ウィキペディア) -
BF16 (bfloat16)
16 bit 浮動小数点。1 bit 符号 + 8 bit 指数 + 7 bit 仮数。指数部は FP32 と同じでレンジ広いが、仮数が少なくて精度は低い。Google TPU や最近の GPU でサポートが増えている(ウィキペディア)
どれも「浮動小数点」ですが、同じ 16 bit でも FP16 と BF16 は性格がかなり違います。(Jarvis Labs)
ビット構造から見る FP16 / BF16 の違い
ビットを分解すると、性格の違いがかなりはっきりします。
| フォーマット | 全体 bit 数 | 指数部 bit | 仮数部 bit (有効桁) | 特徴 |
|---|---|---|---|---|
| FP32 | 32 | 8 | 23 | 精度もレンジも高いが重い |
| FP16 | 16 | 5 | 10 | レンジは狭いが精度は 3 桁ちょい |
| BF16 | 16 | 8 | 7 | FP32 並みのレンジだが精度は 2 桁台(ウィキペディア) |
まとめるとこんな感じになります。
-
FP16
- 指数が 5 bit なので、表現できる値の範囲は狭い
- その代わり仮数が 10 bit あるので、細かい値をそこそこ区別できる
-
BF16
- 指数が 8 bit と FP32 と同じなので、めちゃくちゃ広いレンジを扱える
- その代わり仮数が 7 bit しかなく、丸め誤差が大きい
だから
- オーバーフロー / アンダーフローしたくない → BF16 有利
- 丸め誤差を減らしたい → FP16 有利
というトレードオフになります。
なぜ BF16 が学習で流行っているのか
ここ数年で BF16 が一気にメジャーになったのには以下のような理由があります。
-
FP32 とほぼ同じレンジを保ちつつ、メモリ使用量を半減できる
BF16 は指数部を FP32 と共有していて、表現できる最小値〜最大値のレンジはほぼ同じです(ウィキペディア) -
学習が FP16 より安定しやすいケースが多い
FP16 はレンジが狭いので、大きな勾配や活性が出ると簡単にInfやNaNになりがちです。BF16 はレンジが広いので、多少無茶な値でもとりあえず表現できてしまう -
ハードウェアとフレームワークのサポート
Cloud TPU での BF16 サポートが先行し、その後 NVIDIA GPU でも Ampere (A100 など) 以降でネイティブ BF16 サポートが入り、PyTorch / JAX などもautocast(dtype=torch.bfloat16)のように簡単に使えるようになった(Google Cloud)
このあたりの理由から、特に 事前学習や大規模な教師あり学習では「とりあえず BF16」がデファクト になりつつあります。
さっきの RL 論文は何を言っているのか
Sea AI Lab と NUS のグループによる
「Defeating the Training-Inference Mismatch via FP16」 という 2025 年のプレプリントです。(arXiv)
翻訳して要約すると、論文の主張はこうです。
-
LLM を強化学習 (RL) でファインチューニングするとき
- 推論用エンジン (ロールアウトを出す部分)
- 学習用エンジン (勾配を計算する部分)
が別実装になっていることが多い
-
両方 BF16 で計算するものの、丸め誤差や実装の違いで
推論ポリシーと学習ポリシーが数値的にズレる -
そのズレが RL の学習を不安定にし、報酬が落ちたり collapse したりする
-
そこで BF16 をやめて FP16 に変えたところ
- 両エンジンの出力がずっと近くなり
- 学習が安定し
- 収束も速くなった
-
しかも、変更は「precision を FP16 にする」程度で、アルゴリズムの変更は不要
つまり
BF16 の「レンジは広いけど丸め誤差が大きい」という性質が
RL の「学習時と推論時で同じポリシーを前提とする」性質と相性悪かった
という話です。(arXiv)
引用ツイートを日本語にすると
元のツイート風の文を日本語にすると、だいたいこんな感じになります。
🚀 新しい研究を共有します
💊 問題: BF16 精度は、学習と推論の間に大きなミスマッチを生み、RL の学習を不安定にします
💡 解決: FP16 に切り替えるだけです
🎯 以上です
だいぶ勢いのある言い方ですが、論文自体はかなり真面目に
- いろいろな RL アルゴリズム
- いろいろなモデル (Dense、MoE、LoRA など)
- 複数のフレームワーク
で実験していて、「BF16 より FP16 の方が安定して良くなるケースが多い」と報告しています。(arXiv)
なんで FP16 にするとマシになるのか(直感的な説明)
BF16 と FP16 の一番の違いは 仮数部の bit 数 です。
この差は、一回の計算だけ見れば大したことないように思えますが、LLM の RL では
- ログ確率や確率の比
- 長いトークン列に渡る積や和
- 重要度サンプリングの比率
みたいな「ちょっとした差がどんどん増幅される」処理を延々と繰り返します。
ここで
- 推論エンジン A (BF16、実装 X)
- 学習エンジン B (BF16、実装 Y)
の両方を通したときの誤差が、それぞれそこそこ大きいと
- A を通した確率列
- B を通した確率列
が、同じパラメータなのに統計的に結構違うもの になってしまいます。
FP16 にすると、仮数部が増えて丸め誤差が小さくなるので
- A で計算した結果
- B で計算した結果
の差がかなり縮み、トレーニングポリシーと推論ポリシーのミスマッチも減る、というのが論文の主張です。(arXiv)
実務で precision をどう選べばよいか
じゃあ普段の学習ではどうすればいいのかをチェックリストにしてみます。
1. まずは「何をしているコードか」を見る
-
素の教師あり学習や事前学習
→ BF16 か FP16 どちらでも良いことが多い -
RL、数値的にシビアな最適化、重要度サンプリングを多用
→ 論文に倣って FP16 を検討する価値がある(arXiv) -
既に使っているフレームワークが「これは BF16 を前提に調整している」と明言している
→ まずはそのまま従う
2. GPU の種類とフレームワークのサポート
-
A100 / H100 / RTX 40 系など、BF16 ネイティブ対応の GPU
- BF16 / FP16 のどちらも速度は十分速いことが多い
-
もう少し古い GPU
- BF16 のサポートが怪しい場合があるので、FP16 の方が安定して速いケースもある(PyTorch Forums)
使っているフレームワークの precision オプションの説明は一度ちゃんと読むのがおすすめです。
3. 典型的な設定の例(PyTorch)
PyTorch なら、学習ループの中で autocast の dtype を変えるだけで BF16 / FP16 を切り替えられます。
import torch
from torch.amp import autocast, GradScaler
model = ...
optimizer = ...
use_bf16 = True # False にすると FP16
dtype = torch.bfloat16 if use_bf16 else torch.float16
scaler = GradScaler(enabled=not use_bf16) # BF16 のときはスケーラ不要なことが多い
for inputs, labels in dataloader:
optimizer.zero_grad()
with autocast(device_type="cuda", dtype=dtype):
outputs = model(inputs)
loss = loss_fn(outputs, labels)
# FP16 のときだけスケーラを使う
if use_bf16:
loss.backward()
optimizer.step()
else:
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
経験的には
- BF16: loss スケーリング無しでそのまま学習しやすい
- FP16:
GradScalerなどでの loss スケーリングがほぼ必須
という感触です。
どの precision を選べばよいかのまとめ
選び方をもう一度以下に示します。
-
とにかくメモリを節約して速くしたい
→ BF16 か FP16 のどちらかを使う -
レンジ広めで学習を安定させたい (特に pretrain 系)
→ BF16 をまず試す -
RL で「学習時と推論時のミスマッチ」が問題になっていそう
→ 今回の論文のように FP16 にしてみる価値がある(arXiv) -
どれを選べばいいか分からないけど、とりあえず壊したくない
→ まず FP32 で動かしてから、BF16 / FP16 に落として挙動を比較する
🐣 precision を変えるだけで学習が爆発したり安定したりするので、数値フォーマットは地味に侮れないなあと日々思っています。
おわりに
GPU 学習で出てくる BF16 / FP16 / FP32 は、単に「ビット数が違う」だけでなく
- ダイナミックレンジ
- 丸め誤差の大きさ
- 学習と推論のミスマッチ
といった形で学習挙動に影響します。
特に BF16 はレンジ広めで扱いやすい一方、RL のような数値的にシビアな場面では、今回の論文のように FP16 の方が良い場合もあります。
precision の設定は「なんとなくデフォルトのまま」ではなく、タスクとフレームワークに応じて意識的に選んだ方が、無駄なハマりポイントを減らせるはずです。
🐣 4bit 量子化などに比べれば 16bit と 32bit の違いなんて微々たるものかも!?