TL;DR(要点まとめ)
- FSDP2はPyTorchの最新ZeRO-3実装で、限られたGPUメモリで大規模モデル学習を可能にする
- 7Bモデルでも通常110GB+のメモリが必要だが、FSDP2により複数GPUで分散可能
- 重要なのは初期化手順:Meta-device → シャーディング → 重み読み込み
- 初期化後の学習ループは通常のPyTorchと同じコードで動作
ZeRO-3の基礎: FSDP2の土台となる概念
大規模言語モデルの学習時、GPUメモリは以下の要素で消費される:
- モデルパラメータ: 7Bモデルで約14GB(bf16)
- 勾配: パラメータと同サイズで約14GB(bf16)
- オプティマイザ状態: パラメータの12倍で約84GB(fp32)
- 活性化値: バッチサイズとシーケンス長に依存
7Bモデルでは、GPU1台あたり110GB+のメモリが必要になる!
なぜオプティマイザ状態が12倍なのか?
Adam/AdamWオプティマイザは、FP32でパラメータの3つのコピーを保持する:
- マスターコピー: パラメータ本体(4バイト/パラメータ)
- モーメンタム状態: 1次モーメント(4バイト/パラメータ)
- 分散状態: 2次モーメント(4バイト/パラメータ)
合計: 4 + 4 + 4 = 12バイト/パラメータ
FSDP2のメモリ使用量公式
FSDP2フルシャーディング時のGPU1台あたりのピークメモリ使用量は以下で推定できる:
M = \frac{16P}{N} + 2bshL + O
ここで:
- P: 総モデルパラメータ数
- N: GPU数
- 16バイト/パラメータ: 2(パラメータbf16)+ 2(勾配bf16)+ 12(オプティマイザfp32)
- 2bshL: 勾配チェックポイント使用時の活性化メモリ(bf16)
- O: システムオーバーヘッド(約6.5GB)
重要な洞察: FSDP2は16PをN個のGPUで分散するため、大規模モデルが実現可能になる!
FSDP2: PyTorchの最新ZeRO-3実装
FSDP2はPyTorchによるZeRO-3概念の最新実装で、大幅な改善を含む:
- DTensorベース: パラメータ単位のシャーディングでより細かい制御
- 組み合わせ可能: 他の並列化戦略との簡単な統合
- 高性能: 通信オーバーラップの改善とオーバーヘッド削減
重要な部分: モデル初期化
FSDP2の複雑さは主にモデル初期化にある。正しく設定すれば、学習は通常のPyTorchと同じ!
ステップ1: デバイスメッシュのセットアップ
from torch.distributed.device_mesh import init_device_mesh
# 分散環境の初期化
dist.init_process_group(backend="nccl")
mesh = init_device_mesh("cuda", (world_size,)) # 純粋なFSDP用の1Dメッシュ
ステップ2: Meta-Device初期化(大規模モデル用)
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
# 重みなしでコンフィグを読み込み
config = AutoConfig.from_pretrained(model_name)
config.torch_dtype = torch.bfloat16
# meta deviceでモデル構造を作成(メモリ割り当てなし)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
なぜMeta-Device? 7Bモデルを通常通り読み込むと各GPUで14GB+のRAMが必要。Meta-deviceはメモリ割り当てなしで構造のみ作成する。
ステップ3: FSDP2シャーディングの適用
from torch.distributed.fsdp import fully_shard
from torch.distributed.fsdp.api import MixedPrecisionPolicy
# FSDP2パラメータの設定
fsdp_kwargs = {
"mesh": mesh,
"reshard_after_forward": True, # ZeRO-3動作: forward後に再シャード
"mp_policy": MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32, # 勾配により高い精度
cast_forward_inputs=True,
),
}
# ボトムアップでシャーディング適用: まずレイヤー、次にルート
for i, layer in enumerate(model.model.layers):
model.model.layers[i] = fully_shard(layer, **fsdp_kwargs)
model = fully_shard(model, **fsdp_kwargs) # ルートシャーディング
重要な原則: 小さなモジュールではなく、Transformerレイヤーレベルでシャード。通信効率とメモリ節約のバランスを取る。
ステップ4: 事前学習済み重みの読み込み
from torch.distributed.checkpoint.state_dict import set_model_state_dict, StateDictOptions
# rank 0でのみ重みを読み込み
if local_rank == 0:
temp_model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch.bfloat16, device_map="cpu"
)
state_dict = temp_model.state_dict()
del temp_model
else:
state_dict = None
# 自動的にブロードキャストとシャード
set_model_state_dict(
model,
model_state_dict=state_dict,
options=StateDictOptions(
full_state_dict=True,
broadcast_from_rank0=True,
),
)
ステップ5: 勾配チェックポイントの有効化(オプション)
# HuggingFaceモデルの場合、読み込み後に有効化
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
学習: 通常のPyTorchと同じ!
初期化が完了すれば、学習は通常と同じ:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
for batch in dataloader:
outputs = model(batch)
loss = compute_loss(outputs, batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
実現される主要メリット
✅ 馴染みのあるAPI: 学習ループは通常のPyTorchから変更なし
✅ 自動シャーディング: FSDP2が全ての通信を自動処理
✅ 柔軟性: 勾配チェックポイントなど他の最適化との簡単な組み合わせ
学習の実行
torchrun --nproc_per_node=4 your_training_script.py
まとめ
FSDP2は以下により大規模モデル学習を可能にする:
- ZeRO-3の理解: 全てをシャードし、必要時に収集
- 慎重な初期化: Meta-device → シャード → 重み読み込み
- レイヤーレベルシャーディング: 効率性とメモリのバランス
- 通常の学習: 設定後は単なるPyTorch!
複雑さは初期化に集中しているが、見返りは大きい:GPUに収まらないモデルの学習が、学習ループへの最小限のコード変更で実現できる。
質問やコメントがあれば、ぜひお聞かせください!