Andrej Karpathy「Let's build GPT」解説シリーズ 第4動画
はじめに
前回は、基本的な学習ループを実装し、モデルが学習できることを確認しました。しかし、GPT-2のような巨大なモデルを効率的かつ安定的に学習させるには、さらなる工夫が必要です。単純な学習ループでは、計算速度が遅すぎたり、学習が不安定になったりする問題に直面します。
今回は、高度な学習テクニックを導入し、我々のGPTモデルを本格的なものにアップグレードしていきます。具体的には、以下のテクニックを実装します。
- 重み共有 (Weight Sharing)
- 適切な重み初期化
- 混合精度学習 (
autocast
) とtorch.compile
- Flash Attention
- 学習率スケジューラと勾配クリッピング
なぜ高度な学習テクニックが必要なのか?
大規模言語モデルの学習は、限られた計算リソースの中で、いかに効率よく、安定して学習を進めるかが重要です。最高のパフォーマンスを引き出すためには、モデルアーキテクチャやデータ、オプティマイザの選択だけでなく、さまざまな学習テクニックによる細かな調整が不可欠です。これらのテクニックは、「より速く、より安定して、より賢く」学習を進めるための重要な役割を担っています。
例えば、計算の精度を少し落として(混合精度学習)速度を稼いだり、学習率をウォームアップさせたりすることで、学習の序盤での失敗を防ぎ、最終的により良い結果にたどり着くことができます。
具体的な実装
1. 重み共有 (Weight Tying)
これは、モデルのパラメータ数を削減し、学習を安定させるための古典的かつ効果的なテクニックです。
-
wte
(Token Embedding層): トークンIDをベクトルに変換します。 -
lm_head
(出力層): ベクトルをトークンIDの確率分布に変換します。
この2つの層は、本質的に逆の操作を行っています。そのため、これらの層の重み行列を共有(lm_head
の重みをwte
の重みの転置として使う)することで、意味的な一貫性が生まれ、パラメータ数を大幅に削減できます。
# GPTクラスの__init__内に追加
class GPT(nn.Module):
def __init__(self, config):
# ... (前回の実装)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# lm_headの重みをwteと共有する
self.transformer.wte.weight = self.lm_head.weight
2. 適切な重み初期化
ニューラルネットワークの初期値は、学習の成否を大きく左右します。PyTorchのデフォルト初期化は汎用的ですが、GPT-2のようなTransformerには専用の初期化が効果的です。
GPT-2の論文では、重みを平均0、標準偏差0.02の正規分布で初期化し、さらに残差接続の一部を特別にスケーリングすることが推奨されています。
# GPTクラスに__init_weightsメソッドを追加し、__init__の最後に呼び出す
class GPT(nn.Module):
def __init__(self, config):
# ...
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
std = 0.02
# 残差接続の出力層のスケーリング
if hasattr(module, 'NANOGPT_SCALE_INIT'):
std *= (2 * self.config.n_layer) ** -0.5
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
# Block内のc_projとMLP内のc_projにフラグを立てる
self.c_proj.NANOGPT_SCALE_INIT = 1
この初期化により、学習開始時の出力が爆発したり消失したりするのを防ぎ、学習を安定させます。
なぜ残差接続の出力層のスケーリングが必要?
残差接続では、各層の出力が累積的に加算されていきます。もし各層の出力が大きすぎると、層を重ねるごとに値がどんどん大きくなり、最終的に無限大に発散してしまいます。
出力層のスケーリングとは、層の重み初期値を小さくすることで出力の大きさ(分散)を抑制し、この発散を防ぐテクニックです。
3. 高速化: autocast
, torch.compile
, Flash Attention
a. 混合精度学習 (autocast
)
通常、計算は32ビット浮動小数点数(FP32)で行われますが、これを16ビット(FP16やBF16)に落とすことで、計算速度を大幅に向上させ、GPUメモリの使用量を削減できます。torch.autocast
コンテキストマネージャを使うだけで、これを自動的に適用できます。
# 学習ループ内
with torch.autocast(device_type=device, dtype=torch.bfloat16):
logits, loss = model(x, y)
loss = loss / grad_accum_steps # スケーリングはautocastの外で
loss.backward()
b. torch.compile
PyTorch 2.0から導入されたJIT(Just-In-Time)コンパイラです。モデルの計算グラフを解析し、Pythonのオーバーヘッドを削減したり、複数の計算を1つにまとめるカーネルフュージョンを行ったりすることで、劇的に実行速度を向上させます。
model = GPT(GPTConfig())
if use_compile:
model = torch.compile(model)
c. Flash Attention
Attention計算は、シーケンス長(T)の2乗に比例してメモリと計算量が増加するボトルネックです。Flash Attentionは、Attentionの計算方法を工夫し、巨大なT x T
のAttention行列をGPUメモリ上に作ることなく計算する手法です。
PyTorch 2.0以降では、F.scaled_dot_product_attention
としてネイティブにサポートされており、is_causal=True
フラグを立てるだけでCausal Attentionを高速に実行できます。
# CausalSelfAttentionのforwardを書き換え
def forward(self, x):
# ... q, k, vの準備 ...
# 以下の数行を置き換える
# att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
# att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
# att = F.softmax(att, dim=-1)
# y = att @ v
# これ1行でOK
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# ... headの結合 ...
return y
4. 学習の安定化: 学習率スケジューラと勾配クリッピング
a. 学習率スケジューラ (Learning Rate Scheduler)
学習率を固定するのではなく、学習の進捗に応じて変化させることで、より良い性能を達成できます。一般的には「ウォームアップ付きコサイン減衰」が使われます。
- ウォームアップ: 学習初期は小さな学習率から始め、徐々に目標値まで上げる。これにより、学習序盤の不安定な時期を乗り切る。
- コサイン減衰: ウォームアップ後は、コサインカーブに沿って学習率を滑らかに下げていく。これにより、学習の終盤で細やかな調整が可能になる。
b. 勾配クリッピング (Gradient Clipping)
学習中に勾配が異常に大きくなる「勾配爆発」を防ぐための安全装置です。全ての勾配の大きさ(ノルム)を計算し、しきい値(例: 1.0)を超えていたら、全体の勾配を縮小してしきい値以下に収めます。
# optimizer.step()の前に追加
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
まとめ
今回は、GPTモデルの学習を向上させる様々なテクニックを学びました。これらの最適化により、学習はより速く、安定し、最終的により高い性能を達成することが可能になります。
しかし、まだ課題は残っています。それはスケールの問題です。現在の実装は1つのGPUで動作しますが、GPT-2のような超巨大モデルは、単一のGPUではメモリが足りず、学習に何年もかかってしまいます。
次回は、「大規模学習へのスケールアップ編」として、複数のGPUを使って学習を分散・高速化する分散並列学習(DDP)と、GPUメモリに収まらない巨大なバッチサイズを扱うための勾配累積(Gradient Accumulation)、そして巨大なデータセットを扱うためのシャーディングについて解説します。
(この記事は研究室インターンで取り組みました:https://kojima-r.github.io/kojima/)