0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【深層学習】大きなバッチサイズで学習効率化(半精度・勾配蓄積)

Last updated at Posted at 2025-09-15

こんにちは!大きなバッチサイズを扱いたいときのPyTorchのテクニックとして、① 半精度② 勾配蓄積の2つを取り上げます。

通常、GPUのメモリ数には限りがあり一気にデータを送ることは不可能なのでミニバッチにデータを分けて送ることになります。
その際、バッチサイズ=サンプル数が大きければデータのばらつきが小さくなるので 学習の安定化 に繫がります。
この記事では、バッチサイズの大きさを実現するためのテクニックとして以下を記載します。

  1. 半精度自動切替 (AMP)
  2. 勾配蓄積 (Gradient Accumulation)
  3. 半精度×勾配蓄積
  4. バッチサイズに応じて学習率も変化させたい

1. 半精度自動切替 (AMP)

torch.float32 で計算されていたものを一部の演算において、torch.float16(半精度)で計算するようにします。
線形層や畳み込み演算などはfloat16(半精度)で、縮約(mean等)にはfloat32で自動的に計算するというように、各演算を適切なデータ型にマッチングさせることができます。
通常、torch.autocasttorch.amp.GradScaler を組み合わせて使用しますが、二つともモジュール化されており、必要に応じて個別に使用もできます。

torch.autocast

前向き計算の時に、どの演算をどの精度で計算するかを自動で振り分けます。

  • 使い方
    前向きを with torch.amp.autocast('cuda', dtype=...)で包みます。
with torch.amp.autocast('cuda', dtype=torch.float16):
    logits = model(x)  # 速い演算は半精度、センシティブな演算はFP32に自動

torch.amp.GradScaler

半精度で生じやすい勾配のアンダーフロー(極小値が0)対策です。

  • 流れ
  1. 損失lossにスケール係数を掛ける
  2. 後向きbackward()
  3. 最適化直前に scaler.unscale_(optimizer) で元スケールへと戻す
  4. 勾配クリップ clip_grad_norm_
  5. scaler.step(optimizer)
  6. scaler.update() で増減を自動調整
  • 使い方
scaler = torch.amp.GradScaler('cuda')

for step, (x, y) in enumerate(train_loader, start=1):
    with torch.amp.autocast('cuda', dtype=torch.float16):
        loss = criterion(model(x), y)
    
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer) # ここで実勾配に戻す
    torch.nn.utils.clip_grad_norm_(params=model.parameters(), max_norm)  # 勾配クリップや勾配ノルム計算は `unscale_()` の後に行うことがポイント
    scaler.step(optimizer)
    scaler.update()
  • 勾配クリップ clip_grad_norm_ とは
    勾配ベクトルの長さが max_norm を超えたら全勾配を同じ比率で縮めて長さがちょうど max_norm となるようにします。これで、勾配爆発を防ぎます。

片方だけ使うパターン

訓練のときは両方使いますが、推論だけなら逆伝播がないので autocast だけでよいです。

    with torch.no_grad():
        for imgs, labels in val_loader:
            with torch.amp.autocast('cuda', dtype=torch.float16):
                imgs, labels = imgs.to(device), labels.to(device)
                logits = model(imgs)
                loss = criterion(logits, labels)
            preds = torch.argmax(logits, dim=1)

2. 勾配蓄積 (Gradient Accumulation)

一度に大きなバッチをGPUに載せられないので、複数の小バッチで勾配を貯めてからまとめて更新する方法です。

  • 流れ
  1. loss.backward()で勾配を足し続ける(何バッチ分貯めるかというハイパラ:accum_iter
  2. step % accum_iter == 0 のタイミングでだけ optimizer.step() と optimizer.zero_grad() を行う
accum_iter = 4
optimizer.zero_grad()

for step, (x,y) in enumerate(train_loader, start=1):
    logits = model(x)
    loss = criterion(logits, y)
    loss /=accum_iter # 蓄積回数で割る
    loss.backward() # 勾配を足す(更新はしない)

    # 蓄積の境目のとき
    if step % accum_iter == 0:
        optimizer.step()
        optimizer.zero_grad() # 次の勾配蓄積に向けてクリア

推論では勾配の計算はないので設定しません。

3. 半精度×勾配蓄積

二つを組み合わせてみます。

accum_iter = 4
optimizer.zero_grad()
scaler = torch.amp.GradScaler('cuda') 

for step, (x,y) in enumerate(train_loader, start=1):
    with torch.amp.autocast('cuda', dtype=torch.float16):
        logits = model(x)
        loss = criterion(logits, y)
        loss /=accum_iter # 蓄積回数で割る
    
    scaler.scale(loss).backward() # 蓄積
    
    if step % accum_iter == 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(params=model.parameters(), max_norm) 
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad() # 次の勾配蓄積に向けてクリア

NativeScalerWithGradNormCount

もうちょっと可読性を上げたい場合、このラッパはtorch.cuda.amp.GradScaler を内包しており以下を行ってくれます。

  • AMPの勾配スケーリング(scale/unscale/step/update)
  • 勾配ノルムの計測とクリップ
  • 勾配蓄積と相性のよい更新タイミング制御(update_grad)
class NativeScalerWithGradNormCount:
    state_dict_key = "amp_scaler" # チェックポイント保存時のキー名

    def __init__(self, device_type='cuda', enabled=True):
        self._scaler = torch.amp.GradScaler(device_type) if enabled else torch.amp.GradScaler(device_type, enabled=False)

    def __call__(self, loss, optimizer, clip_grad=None, skip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph) # 勾配を貯める

        # 蓄積の境目なら更新する
        if update_grad:
            # 勾配クリップあり
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer) # 実勾配へ戻す
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) # 勾配ノルムの上限をclip_gradにそろえる
            # 勾配クリップなし、更新スキップあり
            elif skip_grad is not None:
                self._scaler.unscale_(optimizer) # 実勾配へ戻す
                norm = get_grad_norm_(parameters)
                
                # ノルムが閾値以上なら更新をスキップ
                if norm >= skip_grad:
                    self._scaler.update()
                    return norm #optimizer.step() を呼ばない

            # 勾配クリップなし、更新スキップなし
            else:
                self._scaler.unscale_(optimizer) # 実勾配へ戻す
                norm = get_grad_norm_(parameters)
                
            self._scaler.step(optimizer) # NaN/Infなら自動スキップ
            self._scaler.update() # スケール係数を動的調整
            
        # 蓄積ステップ(勾配を貯めるだけ)
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)

いろいろ条件分岐があります。
勾配クリップはmax_normを超えれば、縮める設定ですが、更新は実行されます。
更新スキップは更新もしないので、そのバッチの処理をまるごとなくすイメージです。

  • update_grad=False:貯めるだけ(backward のみ)
  • update_grad=True & clip_grad 指定:unscale → クリップ → step → update
  • update_grad=True & skip_grad 指定:unscale → ノルム測定 → 閾値超なら step をスキップ(updateだけ)
  • update_grad=True & どちらも指定なし:unscale → ノルム測定 → step → update
accum_iter = 4
scaler = NativeScalerWithGradNormCount()  # 中で GradScaler を保持している
optimizer.zero_grad()

for step, (x, y) in enumerate(train_loader, start=1):
    with torch.amp.autocast('cuda', dtype=torch.float16):
        logits  = model(x)
        loss = criterion(logits, y) / accum_iter

    # accum境目かどうかで update_grad を切り替え
    update_now = (step % accum_iter == 0) # Trueなら有効バッチの境目

    # clip したいときは引数に渡す(unscale_→clip→step→update まで一括)
    grad_norm = scaler(
        loss,
        optimizer,
        clip_grad=1.0, # クリップしないなら None
        parameters=model.parameters(),
        update_grad=update_now
    )

    if update_now:
        optimizer.zero_grad()

4. バッチサイズに応じて学習率も変化させたい

バッチサイズが大きいと勾配のばらつきが小さくなり、1回の学習率の更新を大きめにしても安定しやすくなります。
学習率をバッチの大きさに合わせて調整したいときは、まずは、1回の optimizer.step() で使うサンプル数(= eff_batch_size)を計算します。

eff_batch_size の計算

  • eff_batch_size = per_device_batch × accum_iter × world_size
    • per_device_batch:各GPU(各プロセス)の DataLoader が1イテレーションで取り出す個数
    • accum_iter:勾配蓄積の回数(何ミニバッチぶんを足してから1回だけ step するか)
    • world_size:分散の総プロセス数(= GPU数、単GPUなら1)

実際に使うベース学習率(固定)を定める

  • ベースLR = 基準LR × eff_batch_size / 256
    • eff_batch_sizeに合わせて基準LRをスケールした結果になります。256という数字は慣例のようです
    • 基準LRは基準の有効バッチサイズ(慣例で 256)に適した学習率で設計のものさし
    • ベースLRは実行前に一度だけ決める固定の学習率

もし学習中の時刻によって変化するスケジューラや層別の学習率を設定したい場合は、ベースLRに掛けます。

時刻依存LRを定める

  • 時刻依存LR = s(t) × ベースLR
    • s(t)はウォームアップ中や減衰中などで異なるスケジューラ

さらに、層別倍率も定める場合

  • 層別LR = 時刻依存LR × param_groupごとのlr_scale
    • lr_scaleはoptimizer を作る段階で、param group に {"lr_scale": 倍率}を入れておけば設定できる
base_lr = config['train']['lr']  # eff_batch でスケール済みのベースLR
wd = config['train']['weight_decay']

param_groups = [
    # 本体(バックボーン):倍率1.0
    {"params": (p for n,p in model.named_parameters() if not n.startswith("head") and p.requires_grad),
     "lr": base_lr, "weight_decay": wd, "lr_scale": 1.0},

    # ヘッド:倍率10.0
    {"params": (p for n,p in model.named_parameters() if     n.startswith("head") and p.requires_grad),
     "lr": base_lr, "weight_decay": wd, "lr_scale": 10.0},
]

optimizer = torch.optim.AdamW(param_groups, betas=(0.9, 0.999))

まとめると、各層lの学習率は以下のようになります。

層lの学習率 = s(t) × m_l × base_lr
※ base_lr はeff_batch_size ( = batch × accum_iter × world_size ) で固定


以上です。読んでいただきありがとうございました。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?