0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

GPTをゼロから実装して理解してみる(第9部:大規模学習へのスケールアップ編)

Last updated at Posted at 2025-09-06

Andrej Karpathy「Let's build GPT」解説シリーズ 第4動画

はじめに

前回は、重み初期化やFlash Attention、学習率スケジューラといった高度なテクニックを導入し、単一GPUでの学習を最適化しました。しかし、現代の言語モデルは数十億〜数兆パラメータに達し、学習には膨大な計算資源が必要です。単一GPUのメモリや計算能力には限界があります。

今回は、その壁を乗り越え、学習をスケールアップさせるための2つの重要な技術、勾配累積(Gradient Accumulation)と分散並列学習(Distributed Data Parallel, DDP)を実装します。さらに、巨大なデータセットを効率的に扱うためのデータシャーディングについても解説します。

なぜ学習をスケールアップさせる必要があるのか?

なぜスケールアップが必要なのでしょうか?

  1. 巨大なバッチサイズ: GPT-3の研究では、数百万トークンというとてつもなく大きなバッチサイズが性能向上に寄与したことが示されています。しかし、これをそのままGPUメモリに乗せることは不可能です。
  2. 学習時間の短縮: 1つのGPUで100年かかる計算も、100個のGPUを使えば1年で終わらせることができます(理論上は)。

この課題を解決するのが、勾配累積とDDPです。

  • 勾配累積: 1台のGPUで、メモリに乗る小さな「マイクロバッチ」を何回か処理し、勾配だけを足し合わせます。これにより、見かけ上大きなバッチサイズで学習したのと同じ効果を、少ないメモリで実現します。
  • DDP: 複数のGPUでそれぞれがデータの異なる部分を並行して学習します。各GPUで計算された勾配は、全GPUで同期・平均化され、全員が同じように賢くなっていきます。これにより、学習時間を大幅に短縮します。

具体的な実装

1. 勾配累積 (Gradient Accumulation)

これは、GPUメモリの制約内で実質的なバッチサイズを増やすためのテクニックです。

例えば、目標のバッチサイズ(total_batch_size)が524,288トークンで、1度に処理できるマイクロバッチ(B * T)が16,384トークンだとします。この場合、524,288 / 16,384 = 32回、勾配を累積すれば、目標のバッチサイズと同じだけのデータで学習したことになります。

実装は、学習ループを入れ子にするだけです。

total_batch_size = 524288
B = 16 # マイクロバッチのサンプル数
T = 1024 # シーケンス長
assert total_batch_size % (B * T) == 0
grad_accum_steps = total_batch_size // (B * T) # 32回

# ... 学習ループ ...
for step in range(max_steps):
    optimizer.zero_grad() # 勾配を外側のループでリセット
    
    # --- 勾配累積ループ --- 
    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        # ... forward & backward ...
        loss = loss / grad_accum_steps # 平均を取るために損失をスケーリング
        loss.backward()
    # ---------------------
    
    # 累積した勾配で重みを更新
    optimizer.step()

2. 分散並列学習 (Distributed Data Parallel, DDP)

DDPは、PyTorchでマルチGPU学習を行うための標準的な方法です。torchrunコマンドを使って起動し、各プロセス(通常は各GPUに1つ)が同じコードを実行します。

DDPのセットアップ

import os
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

ddp = int(os.environ.get('RANK', -1)) != -1

if ddp:
    # DDP環境では、RANK, LOCAL_RANK, WORLD_SIZEが環境変数として設定される
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK']) # 全プロセスの中でのランク
    ddp_local_rank = int(os.environ['LOCAL_RANK']) # このマシン内でのランク(GPU番号)
    ddp_world_size = int(os.environ['WORLD_SIZE']) # 全プロセス数
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0 # ランク0のプロセスがログ出力などを担当
else:
    # DDPでない場合の設定
    # ...

# モデルをDDPでラップする
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module if ddp else model # ラップ解除した素のモデル

DDPと勾配累積の連携

勾配累積中に毎回GPU間で勾配を同期(All-Reduce)するのは非効率です。DDPでは、model.require_backward_grad_syncFalseに設定することで、backward()時の勾配同期を一時的にオフにできます。最後のマイクロステップでのみ同期を有効にすればOKです。

# 勾配累積ループ内
for micro_step in range(grad_accum_steps):
    # 最後のマイクロステップでのみ勾配を同期する
    if ddp:
        model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
    # ... forward & backward ...

3. 巨大データセットのためのデータシャーディング

FineWebのような100億トークンを超えるデータセットは、単一のマシンのメモリやディスクには収まりきりません。そこで、データを小さなファイル(シャード)に分割して保存します。

fineweb.pyは、まさにこのためのスクリプトです。

# fineweb.py

import os
import multiprocessing as mp
import numpy as np
import tiktoken
from datasets import load_dataset
from tqdm import tqdm

local_dir = "edu_fineweb10B"
remote_name = "sample-10BT"
shard_size = int(1e8)  # 1億トークンずつシャードに分割

DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir)
os.makedirs(DATA_CACHE_DIR, exist_ok=True)

fw = load_dataset("HuggingFaceFW/fineweb-edu", name=remote_name, split="train")

enc = tiktoken.get_encoding("gpt2")
eot = enc._special_tokens['<|endoftext|>']

def tokenize(doc):
    tokens = [eot]  # 各文書の区切りに<|endoftext|>トークンを挿入
    tokens.extend(enc.encode_ordinary(doc["text"]))
    tokens_np = np.array(tokens)
    assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token dictionary too large for uint16"
    tokens_np_uint16 = tokens_np.astype(np.uint16)
    return tokens_np_uint16

def write_datafile(filename, tokens_np):
    np.save(filename, tokens_np)

nprocs = max(1, os.cpu_count()//2)
with mp.Pool(nprocs) as pool:
    shard_index = 0
    all_tokens_np = np.empty((shard_size,), dtype=np.uint16)
    token_count = 0
    progress_bar = None
    
    for tokens in pool.imap(tokenize, fw, chunksize=16):
        if token_count + len(tokens) < shard_size:
            all_tokens_np[token_count:token_count+len(tokens)] = tokens
            token_count += len(tokens)
            if progress_bar is None:
                progress_bar = tqdm(total=shard_size, unit="tokens", desc=f"Shard {shard_index}")
            progress_bar.update(len(tokens))
        else:
            # 現在のシャードを保存し、新しいシャードを開始
            split = "val" if shard_index == 0 else "train"
            filename = os.path.join(DATA_CACHE_DIR, f"edufineweb_{split}_{shard_index:06d}")
            remainder = shard_size - token_count
            progress_bar.update(remainder)
            all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
            write_datafile(filename, all_tokens_np)
            shard_index += 1
            progress_bar = None
            # 残りのトークンを次のシャードに
            all_tokens_np[0:len(tokens)-remainder] = tokens[remainder:]
            token_count = len(tokens)-remainder

    # 最後のシャードを保存
    if token_count != 0:
        split = "val" if shard_index == 0 else "train"
        filename = os.path.join(DATA_CACHE_DIR, f"edufineweb_{split}_{shard_index:06d}")
        write_datafile(filename, all_tokens_np[:token_count])
  • datasetsライブラリでデータをストリーミングダウンロード
  • multiprocessingで複数CPUコアを使い、並列でトークン化
  • 一定サイズ(例: 1億トークン)ごとにwrite_datafile.npyファイルとしてシャードを保存

そして、DataLoaderLiteをこのシャード構造に対応させます。

# DataLoaderLiteの改良
class DataLoaderLite:
    def __init__(self, ..., split):
        # ...
        # "train"または"val"のシャードファイルリストを取得
        shards = os.listdir(data_root)
        shards = [s for s in shards if split in s]
        self.shards = sorted(shards)
        self.reset()

    def reset(self):
        # 最初のシャードから開始
        self.current_shard = 0
        self.tokens = self.load_tokens(self.shards[self.current_shard])
        # ...

    def next_batch(self):
        # ...
        # もし現在のシャードを使い切ったら、次のシャードを読み込む
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = self.load_tokens(self.shards[self.current_shard])
            # ...
        return x, y

実行と検証

これらの仕組みをすべて統合したtrain_gpt2.pyは、以下のコマンドで8GPUを使って分散学習を実行できます。

torchrun --standalone --nproc_per_node=8 train_gpt2.py
  • torchrun: PyTorchの分散学習起動ツール。
  • --standalone: 単一ノード(1台のマシン)での実行を指示。
  • --nproc_per_node=8: 1ノードあたり8つのプロセス(=8GPU)を起動。

よくあるミス

  • データローダの重複: DDP環境では、各プロセスがデータセットの異なる部分を処理するようにDataLoaderを設計する必要があります。そうしないと、全GPUが同じデータを処理してしまい、分散学習の意味がなくなります。今回のDataLoaderLiteでは、process_rankに応じて初期位置をずらすことでこれを実現しています。
  • マスタープロセスでのみ実行: ログの書き込みやモデルのチェックポイント保存は、複数のプロセスが同時に行うとファイルが破損する可能性があるため、master_process(通常はランク0)のみで行うように制御することが重要です。

まとめ

今回は、勾配累積とDDPを導入することで、GPTモデルの学習をマルチGPU環境へとスケールアップさせました。また、シャーディングにより、メモリに収まらない巨大なデータセットを扱えるようになりました。

しかし、モデルをただ学習させるだけでは、その性能を客観的に評価することはできません。最終回となる次回は、「モデルの評価と検証編」として、学習中に検証用データセットで性能をモニタリングする方法や、HellaSwagのような標準ベンチマークを使ってモデルの能力を評価する手法について解説します。
(この記事は研究室インターンで取り組みました:https://kojima-r.github.io/kojima/)

参考動画・資料

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?