Andrej Karpathy「Let's build GPT」解説シリーズ 第4動画
はじめに
前回は、GPTモデルの骨格となるTransformerアーキテクチャを実装しました。しかし、現状のモデルはランダムに初期化された「空っぽの脳」であり、意味のあるテキストを生成できません。今回は、このモデルに意味を持たせる学習のプロセスを実装します。
今回は、モデルを訓練するための基本的な要素である損失関数、オプティマイザ、そしてデータを供給するためのDataLoaderを実装していきます。
モデルを「学習」させるとはどういうことか?
モデルの学習とは、一言で言えば「モデルの予測と正解のズレ(=損失)を計算し、そのズレが小さくなるようにモデルの重みを少しずつ調整していく」作業です。
GPTのような言語モデルの場合、タスクは「ある文脈が与えられたとき、次に来る単語は何か?」を予測することです。学習データ(大量のテキスト)を使って、以下を繰り返します。
- モデルにテキストの一部(コンテキスト)を入力する。
- モデルが予測した「次の単語の確率分布」を出力する。
- 実際の「次の単語」(正解ラベル)と、モデルの予測を比較して、損失(Loss)を計算する。
- その損失を最小化するように、モデルの全パラメータ(重み)を微調整する(最適化)。
このプロセスを何回も繰り返すことで、モデルは徐々に言語のパターンを学習し、もっともらしい文章を生成できるようになります。
実装の全体像
学習プロセスを実装するために、以下のステップを踏みます。
- 損失の計算: 学習を始める前に、ランダムなモデルがどれくらいデタラメな予測をするか、初期損失を計算して理論値と比較します。
-
オプティマイザの導入: モデルの重みを更新するための最適化アルゴリズム(
AdamW
)を設定します。 -
DataLoaderの実装: 学習データを効率的にモデルに供給するためのシンプルな
DataLoaderLite
クラスを作成します。 - 学習ループ: 上記を組み合わせて、基本的な学習ループを実装します。
具体的な実装
1. 学習前の初期損失を確認する
学習を始める前に、ランダムに初期化されたモデルの性能を評価してみましょう。全く学習していないモデルの損失は、理論的に計算できます。
語彙数(vocab_size
)が50,257個ある場合、ランダムなモデルは全ての単語を等しい確率(1/50257)で予測するはずです。クロスエントロピー損失は -ln(正解の確率)
で計算されるため、理論的な初期損失は以下のようになります。
Loss = -ln(1 / 50257) ≈ 10.82
実際にコードで確認してみましょう。
import tiktoken
# 小さなテキストで試す
enc = tiktoken.get_encoding("gpt2")
text = "Hello, I'm a language model and I'm here to help you."
tokens = torch.tensor(enc.encode(text))
B, T = 4, 8 # バッチサイズ4, シーケンス長8
# 入力(x)とターゲット(y)を作成
buf = tokens[:B*T + 1]
x = buf[:-1].view(B, T) # 最初のトークンからB*T個
y = buf[1:].view(B, T) # 1つずらしたトークンが正解ラベル
# モデルの初期化
model = GPT(GPTConfig())
model.to(device)
# 損失の計算
logits, loss = model(x.to(device), y.to(device))
print(loss)
# 出力例:
# tensor(10.9239, device='mps:0', grad_fn=<NllLossBackward0>)
出力された損失は10.92
であり、理論値の10.82
に非常に近い値となりました。これは、モデルがまだ何も学習しておらず、ランダムな予測を行っていることを示しています。
2. オプティマイザと学習ステップ
損失が計算できたら、次はその損失を元にモデルの重みを更新するオプティマイザを導入します。ここでは、Transformerで広く使われているAdamW
(Adam with Weight Decay)を使用します。
オプティマイザについては第2部GPTをゼロから実装して理解してみる(第2部:Bigramモデルと基本的な言語モデル編)でも解説してますので、ご参照ください。
基本的な学習ステップは以下のようになります。
# オプティマイザの初期化
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# 50回学習ステップを回す
for i in range(50):
# 1. 勾配をリセット
optimizer.zero_grad()
# 2. forwardパスで損失を計算
logits, loss = model(x, y)
# 3. backwardパスで勾配を計算
loss.backward()
# 4. オプティマイザで重みを更新
optimizer.step()
print(f"step {i}: loss {loss.item()}")
# 出力例:
# step 0: loss 10.9993...
# step 1: loss 10.5099...
# ...
# step 49: loss 3.7565...
ステップが進むにつれて損失が着実に減少していることがわかります。モデルは、この小さなテキストの断片に対して、次の単語を予測する方法を学習し始めています。
3. DataLoaderの実装
しかし、上記の方法では常に同じテキストの断片で学習しているため、モデルはその部分にだけ過剰適合(overfitting)してしまいます。大規模なデータセット全体を効率的に学習するためには、データをミニバッチに分割して供給するDataLoaderが必要です。
ここでは、巨大なトークン配列からバッチを切り出すシンプルなDataLoaderLite
を実装します。
class DataLoaderLite:
def __init__(self, B, T, file_path='input.txt'):
self.B = B
self.T = T
with open(file_path, 'r') as f:
text = f.read()
enc = tiktoken.get_encoding("gpt2")
tokens = enc.encode(text)
self.tokens = torch.tensor(tokens)
print(f"loaded {len(self.tokens)} tokens")
print(f"1 epoch = {len(self.tokens) // (B * T)} batches")
# 現在位置
self.current_position = 0
def next_batch(self):
B, T = self.B, self.T
# トークン配列からバッチを切り出す
buf = self.tokens[self.current_position : self.current_position + B*T + 1]
x = buf[:-1].view(B, T)
y = buf[1:].view(B, T)
# 位置を更新
self.current_position += B*T
# データセットの終端に来たら、先頭に戻る
if self.current_position + (B*T + 1) >= len(self.tokens):
self.current_position = 0
return x, y
このDataLoaderを使えば、学習ループはデータセット全体を順次処理していくようになります。
実験と検証
作成したDataLoaderLite
を学習ループに組み込んでみましょう。
train_loader = DataLoaderLite(B=4, T=32)
for i in range(50):
optimizer.zero_grad()
# データローダーから次のバッチを取得
x, y = train_loader.next_batch()
x, y = x.to(device), y.to(device)
logits, loss = model(x, y)
loss.backward()
optimizer.step()
print(f"step {i}: loss {loss.item()}")
# 出力例:
# step 0: loss 11.0090...
# step 1: loss 9.82779...
# ...
# step 49: loss 6.6667...
これにより、モデルは学習の各ステップで異なるテキストの断片を見ることになり、より汎用的な言語パターンを学習することができます。
よくあるミス
- 過学習 (Overfitting): 小さなデータセットで学習を続けると、モデルはそのデータに特化しすぎてしまい、未知のデータに対して性能が著しく低下します。対策として、大規模なデータセットを使う、正則化を行う、早期終了(Early Stopping)するなどの方法があります。
-
学習率 (Learning Rate):
lr
は学習の進み具合を決める最も重要なハイパーパラメータの一つです。大きすぎると学習が発散し、小さすぎると学習が全く進みません。適切な値を見つける必要があり、次の回で登場する学習率スケジューラを使うのが一般的です。
まとめ
今回は、モデルを学習させるための基本的な要素(損失計算、オプティマイザ、データローダ)を実装し、簡単な学習ループを完成させました。これにより、モデルはテキストデータを学習し、損失を下げることができるようになりました。
しかし、現在の実装はまだ初歩的です。GPT-2のような大規模モデルを安定して効率的に学習させるためには、さらに多くの高度なテクニックが必要です。
次回は、「高度な学習テクニック編」として、重み共有、適切な重み初期化、混合精度学習(autocast
)、torch.compile
による高速化、Flash Attentionなど、実践的な最適化手法を導入していきます。
(この記事は研究室インターンで取り組みました:https://kojima-r.github.io/kojima/)