0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

FSDP2で大規模モデル学習: ZeRO-3の数学的解説と実装ガイド

Posted at

TL;DR(要点まとめ)

  • FSDP2はPyTorchの最新ZeRO-3実装で、限られたGPUメモリで大規模モデル学習を可能にする
  • 7Bモデルでも通常110GB+のメモリが必要だが、FSDP2により複数GPUで分散可能
  • 重要なのは初期化手順:Meta-device → シャーディング → 重み読み込み
  • 初期化後の学習ループは通常のPyTorchと同じコードで動作

ZeRO-3の基礎: FSDP2の土台となる概念

大規模言語モデルの学習時、GPUメモリは以下の要素で消費される:

  • モデルパラメータ: 7Bモデルで約14GB(bf16)
  • 勾配: パラメータと同サイズで約14GB(bf16)
  • オプティマイザ状態: パラメータの12倍で約84GB(fp32)
  • 活性化値: バッチサイズとシーケンス長に依存

7Bモデルでは、GPU1台あたり110GB+のメモリが必要になる!

なぜオプティマイザ状態が12倍なのか?

Adam/AdamWオプティマイザは、FP32でパラメータの3つのコピーを保持する:

  • マスターコピー: パラメータ本体(4バイト/パラメータ)
  • モーメンタム状態: 1次モーメント(4バイト/パラメータ)
  • 分散状態: 2次モーメント(4バイト/パラメータ)

合計: 4 + 4 + 4 = 12バイト/パラメータ

FSDP2のメモリ使用量公式

FSDP2フルシャーディング時のGPU1台あたりのピークメモリ使用量は以下で推定できる:

M = \frac{16P}{N} + 2bshL + O

ここで:

  • P: 総モデルパラメータ数
  • N: GPU数
  • 16バイト/パラメータ: 2(パラメータbf16)+ 2(勾配bf16)+ 12(オプティマイザfp32)
  • 2bshL: 勾配チェックポイント使用時の活性化メモリ(bf16)
  • O: システムオーバーヘッド(約6.5GB)

重要な洞察: FSDP2は16PをN個のGPUで分散するため、大規模モデルが実現可能になる!

FSDP2: PyTorchの最新ZeRO-3実装

FSDP2はPyTorchによるZeRO-3概念の最新実装で、大幅な改善を含む:

  • DTensorベース: パラメータ単位のシャーディングでより細かい制御
  • 組み合わせ可能: 他の並列化戦略との簡単な統合
  • 高性能: 通信オーバーラップの改善とオーバーヘッド削減

重要な部分: モデル初期化

FSDP2の複雑さは主にモデル初期化にある。正しく設定すれば、学習は通常のPyTorchと同じ!

ステップ1: デバイスメッシュのセットアップ

from torch.distributed.device_mesh import init_device_mesh

# 分散環境の初期化
dist.init_process_group(backend="nccl")
mesh = init_device_mesh("cuda", (world_size,))  # 純粋なFSDP用の1Dメッシュ

ステップ2: Meta-Device初期化(大規模モデル用)

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# 重みなしでコンフィグを読み込み
config = AutoConfig.from_pretrained(model_name)
config.torch_dtype = torch.bfloat16

# meta deviceでモデル構造を作成(メモリ割り当てなし)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

なぜMeta-Device? 7Bモデルを通常通り読み込むと各GPUで14GB+のRAMが必要。Meta-deviceはメモリ割り当てなしで構造のみ作成する。

ステップ3: FSDP2シャーディングの適用

from torch.distributed.fsdp import fully_shard
from torch.distributed.fsdp.api import MixedPrecisionPolicy

# FSDP2パラメータの設定
fsdp_kwargs = {
    "mesh": mesh,
    "reshard_after_forward": True,  # ZeRO-3動作: forward後に再シャード
    "mp_policy": MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,  # 勾配により高い精度
        cast_forward_inputs=True,
    ),
}

# ボトムアップでシャーディング適用: まずレイヤー、次にルート
for i, layer in enumerate(model.model.layers):
    model.model.layers[i] = fully_shard(layer, **fsdp_kwargs)

model = fully_shard(model, **fsdp_kwargs)  # ルートシャーディング

重要な原則: 小さなモジュールではなく、Transformerレイヤーレベルでシャード。通信効率とメモリ節約のバランスを取る。

ステップ4: 事前学習済み重みの読み込み

from torch.distributed.checkpoint.state_dict import set_model_state_dict, StateDictOptions

# rank 0でのみ重みを読み込み
if local_rank == 0:
    temp_model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, device_map="cpu"
    )
    state_dict = temp_model.state_dict()
    del temp_model
else:
    state_dict = None

# 自動的にブロードキャストとシャード
set_model_state_dict(
    model,
    model_state_dict=state_dict,
    options=StateDictOptions(
        full_state_dict=True,
        broadcast_from_rank0=True,
    ),
)

ステップ5: 勾配チェックポイントの有効化(オプション)

# HuggingFaceモデルの場合、読み込み後に有効化
model.enable_input_require_grads()
model.gradient_checkpointing_enable()

学習: 通常のPyTorchと同じ!

初期化が完了すれば、学習は通常と同じ:

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

for batch in dataloader:
    outputs = model(batch)
    loss = compute_loss(outputs, batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

実現される主要メリット

馴染みのあるAPI: 学習ループは通常のPyTorchから変更なし
自動シャーディング: FSDP2が全ての通信を自動処理
柔軟性: 勾配チェックポイントなど他の最適化との簡単な組み合わせ

学習の実行

torchrun --nproc_per_node=4 your_training_script.py

まとめ

FSDP2は以下により大規模モデル学習を可能にする:

  1. ZeRO-3の理解: 全てをシャードし、必要時に収集
  2. 慎重な初期化: Meta-device → シャード → 重み読み込み
  3. レイヤーレベルシャーディング: 効率性とメモリのバランス
  4. 通常の学習: 設定後は単なるPyTorch!

複雑さは初期化に集中しているが、見返りは大きい:GPUに収まらないモデルの学習が、学習ループへの最小限のコード変更で実現できる。


質問やコメントがあれば、ぜひお聞かせください!

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?