0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

GPTをゼロから実装して理解してみる(第8部:高度な学習テクニック編)

Last updated at Posted at 2025-09-06

Andrej Karpathy「Let's build GPT」解説シリーズ 第4動画

はじめに

前回は、基本的な学習ループを実装し、モデルが学習できることを確認しました。しかし、GPT-2のような巨大なモデルを効率的かつ安定的に学習させるには、さらなる工夫が必要です。単純な学習ループでは、計算速度が遅すぎたり、学習が不安定になったりする問題に直面します。

今回は、高度な学習テクニックを導入し、我々のGPTモデルを本格的なものにアップグレードしていきます。具体的には、以下のテクニックを実装します。

  • 重み共有 (Weight Sharing)
  • 適切な重み初期化
  • 混合精度学習 (autocast) と torch.compile
  • Flash Attention
  • 学習率スケジューラと勾配クリッピング

なぜ高度な学習テクニックが必要なのか?

大規模言語モデルの学習は、限られた計算リソースの中で、いかに効率よく、安定して学習を進めるかが重要です。最高のパフォーマンスを引き出すためには、モデルアーキテクチャやデータ、オプティマイザの選択だけでなく、さまざまな学習テクニックによる細かな調整が不可欠です。これらのテクニックは、「より速く、より安定して、より賢く」学習を進めるための重要な役割を担っています。

例えば、計算の精度を少し落として(混合精度学習)速度を稼いだり、学習率をウォームアップさせたりすることで、学習の序盤での失敗を防ぎ、最終的により良い結果にたどり着くことができます。

具体的な実装

1. 重み共有 (Weight Tying)

これは、モデルのパラメータ数を削減し、学習を安定させるための古典的かつ効果的なテクニックです。

  • wte (Token Embedding層): トークンIDをベクトルに変換します。
  • lm_head (出力層): ベクトルをトークンIDの確率分布に変換します。

この2つの層は、本質的に逆の操作を行っています。そのため、これらの層の重み行列を共有(lm_headの重みをwteの重みの転置として使う)することで、意味的な一貫性が生まれ、パラメータ数を大幅に削減できます。

# GPTクラスの__init__内に追加
class GPT(nn.Module):
    def __init__(self, config):
        # ... (前回の実装)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # lm_headの重みをwteと共有する
        self.transformer.wte.weight = self.lm_head.weight

2. 適切な重み初期化

ニューラルネットワークの初期値は、学習の成否を大きく左右します。PyTorchのデフォルト初期化は汎用的ですが、GPT-2のようなTransformerには専用の初期化が効果的です。

GPT-2の論文では、重みを平均0、標準偏差0.02の正規分布で初期化し、さらに残差接続の一部を特別にスケーリングすることが推奨されています。

# GPTクラスに__init_weightsメソッドを追加し、__init__の最後に呼び出す
class GPT(nn.Module):
    def __init__(self, config):
        # ...
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            # 残差接続の出力層のスケーリング
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

# Block内のc_projとMLP内のc_projにフラグを立てる
self.c_proj.NANOGPT_SCALE_INIT = 1

この初期化により、学習開始時の出力が爆発したり消失したりするのを防ぎ、学習を安定させます。

なぜ残差接続の出力層のスケーリングが必要?

残差接続では、各層の出力が累積的に加算されていきます。もし各層の出力が大きすぎると、層を重ねるごとに値がどんどん大きくなり、最終的に無限大に発散してしまいます。

出力層のスケーリングとは、層の重み初期値を小さくすることで出力の大きさ(分散)を抑制し、この発散を防ぐテクニックです。

3. 高速化: autocast, torch.compile, Flash Attention

a. 混合精度学習 (autocast)

通常、計算は32ビット浮動小数点数(FP32)で行われますが、これを16ビット(FP16やBF16)に落とすことで、計算速度を大幅に向上させ、GPUメモリの使用量を削減できます。torch.autocastコンテキストマネージャを使うだけで、これを自動的に適用できます。

# 学習ループ内
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)
loss = loss / grad_accum_steps # スケーリングはautocastの外で
loss.backward()

b. torch.compile

PyTorch 2.0から導入されたJIT(Just-In-Time)コンパイラです。モデルの計算グラフを解析し、Pythonのオーバーヘッドを削減したり、複数の計算を1つにまとめるカーネルフュージョンを行ったりすることで、劇的に実行速度を向上させます。

model = GPT(GPTConfig())
if use_compile:
    model = torch.compile(model)

c. Flash Attention

Attention計算は、シーケンス長(T)の2乗に比例してメモリと計算量が増加するボトルネックです。Flash Attentionは、Attentionの計算方法を工夫し、巨大なT x TのAttention行列をGPUメモリ上に作ることなく計算する手法です。
PyTorch 2.0以降では、F.scaled_dot_product_attentionとしてネイティブにサポートされており、is_causal=Trueフラグを立てるだけでCausal Attentionを高速に実行できます。

# CausalSelfAttentionのforwardを書き換え
def forward(self, x):
    # ... q, k, vの準備 ...
    # 以下の数行を置き換える
    # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    # att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
    # att = F.softmax(att, dim=-1)
    # y = att @ v
    
    # これ1行でOK
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    
    # ... headの結合 ...
    return y

4. 学習の安定化: 学習率スケジューラと勾配クリッピング

a. 学習率スケジューラ (Learning Rate Scheduler)

学習率を固定するのではなく、学習の進捗に応じて変化させることで、より良い性能を達成できます。一般的には「ウォームアップ付きコサイン減衰」が使われます。

  • ウォームアップ: 学習初期は小さな学習率から始め、徐々に目標値まで上げる。これにより、学習序盤の不安定な時期を乗り切る。
  • コサイン減衰: ウォームアップ後は、コサインカーブに沿って学習率を滑らかに下げていく。これにより、学習の終盤で細やかな調整が可能になる。

cosine_decay.png

b. 勾配クリッピング (Gradient Clipping)

学習中に勾配が異常に大きくなる「勾配爆発」を防ぐための安全装置です。全ての勾配の大きさ(ノルム)を計算し、しきい値(例: 1.0)を超えていたら、全体の勾配を縮小してしきい値以下に収めます。

# optimizer.step()の前に追加
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()

まとめ

今回は、GPTモデルの学習を向上させる様々なテクニックを学びました。これらの最適化により、学習はより速く、安定し、最終的により高い性能を達成することが可能になります。

しかし、まだ課題は残っています。それはスケールの問題です。現在の実装は1つのGPUで動作しますが、GPT-2のような超巨大モデルは、単一のGPUではメモリが足りず、学習に何年もかかってしまいます。

次回は、「大規模学習へのスケールアップ編」として、複数のGPUを使って学習を分散・高速化する分散並列学習(DDP)と、GPUメモリに収まらない巨大なバッチサイズを扱うための勾配累積(Gradient Accumulation)、そして巨大なデータセットを扱うためのシャーディングについて解説します。
(この記事は研究室インターンで取り組みました:https://kojima-r.github.io/kojima/)

参考動画・資料

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?