0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Andrej Karpathy氏が公開したMicroGPTを読み解いてみる

0
Posted at

はじめに

先日OpenAIの創設者のひとり Andrej Karpathy氏よりmicrogpt.py( https://gist.github.com/karpathy/8627fe009c40f57531cb18360106ce95 )
が公開されました。これはChatGPTのようなLLMアルゴリズムを外部の高度なライブラリを使わず、純粋なPythonのみで構築されています。
すごくわかりやすく(といっても難しいですが)コード化されていると思ったのでこれを自分の理解のためにも備忘録化して残します。
※"GPTの仕組みを直観的に理解すること"に焦点を置いているため、全てを解説できているわけではありませんのでご了承ください。また不正確な説明があるかもしれませんがコメント等でご指摘いただけますと幸いです。

前提知識・こんな人に読んでほしい

  • Pythonコードはある程度読める
  • 細かい数式は苦手だけどGPTの仕組みについて知りたい
  • コードベースで確認してみたい

まずは全体像

MicroGPTの動きを簡単にまとめると
「次のトークン(文字)を予測 → 間違いを学習し、正解を導き出す」を繰り返しています。

さらに分解すると
「文字のトークン化(Tokenize) → 埋め込み(Embedding) → 文脈理解(Attention) → 変換・思考(MLP) → 予測(Predict) → 正解とのズレ(Loss) → 逆伝搬による学習(Backprop)」を何度も繰り返し学習。学習済モデルを用いて1文字ずつ正解の文字を生成しています。

ここからは各操作の仕組みについて該当箇所のコードを示しながら、1つ1つ詳しく見ていきます。
※Pythonコードのコメントも併用し解説していきます。

文字のトークン化(Tokenize)

  • 該当コード
    # データ準備
    docs = [line.strip() for line in open('input.txt') if line.strip()]
    random.shuffle(docs)
    # トークン化(文字→ID)
    uchars = sorted(set(''.join(docs)))
    BOS = len(uchars)
    vocab_size = len(uchars) + 1
    
  • 説明
    ここで言語をAIが計算できる「数値」に変換しています。最もシンプルなトークナイザーは、一意に文字に整数のIDを割り当てています。BOSはBeginning of Sequenceの略でドキュメントの開始と終了を示すマーカーです。
    例) "a" → 0, "b" → 1, …

  • 該当コード
    num_steps = 1000 # トレーニングの回数
    for step in range(num_steps):
    
        doc = docs[step % len(docs)]
        tokens = [BOS] + [uchars.index(ch) for ch in doc] + [BOS]
    
  • 説明
    ここで文字に割り振ったIDをもとに文字単位でトークナイズしています。
    例)
    "anna" → [BOS, a, n, n, a, BOS] → [26, 0, 13, 13, 0, 26]
    ※アルファベット小文字のみ(26文字)

埋め込み(Embedding)

  • 該当コード
    tok_emb = state_dict['wte'][token_id]
    pos_emb = state_dict['wpe'][pos_id]
    x = [t + p for t, p in zip(tok_emb, pos_emb)]
    
  • 埋め込みではトークン化されたIDに対して、意味と位置情報を持つベクトル(tok_emb, pos_emb)に変換しています。単なるIDのままでは情報を持たないため計算ができないためです。
    さらに意味と位置を示すベクトルを加算して(x)入力データを生成します。

意味ベクトル(tok_emb)は"a"と"b"の違いを区別できるようにし、位置(pos_emb)では文字の並び ("a","b" と "b","a")を区別できるようにするイメージ

文脈理解(Attention)

AttentionはGPTの心臓部といっても過言ではありませんので、しっかり見ていきます。

  • 該当コード
        for li in range(n_layer):
            # 残差接続のため一時保存
            x_residual = x
            # rmsnormで入力を安定させる。
            x = rmsnorm(x)
            # xからQuery(q), Key(k), Value(v)を生成
            q = linear(x, state_dict[f'layer{li}.attn_wq'])
            k = linear(x, state_dict[f'layer{li}.attn_wk'])
            v = linear(x, state_dict[f'layer{li}.attn_wv'])
            # これまでの全トークンを記憶
            keys[li].append(k)
            values[li].append(v)
            x_attn = []
            for h in range(n_head):
                hs = h * head_dim
                # ベクトルを分割して複数視点で処理(Multi-head)
                q_h = q[hs:hs+head_dim]
                k_h = [ki[hs:hs+head_dim] for ki in keys[li]]
                v_h = [vi[hs:hs+head_dim] for vi in values[li]]
                # qと過去のkの類似度
                attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim**0.5 for t in range(len(k_h))]
                # 重みに変換
                attn_weights = softmax(attn_logits)
                # 情報の集約
                head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h))) for j in range(head_dim)]
                # 結合
                x_attn.extend(head_out)
            x = linear(x_attn, state_dict[f'layer{li}.attn_wo'])
            # 残差接続
            x = [a + b for a, b in zip(x, x_residual)]
    
  • 説明
    Attentionとは一言で言うと現在の単語が、過去のどの単語を見るべきかを決める仕組みです。
    Attentionの流れとしては
    x(現在の単語) → Query(q), Key(k), Value(v)の3つのベクトルを生成 → qと過去のkの類似度を計算 → softmax関数で重み化(生の数値を確立分布て意味付け) → vを重み付き平均し出力しています。
    Query(q), Key(k), Value(v)の下りが少しわかりずらいのでかみ砕くと
    Queryは何を探しているか(現在のトークンが求める鍵穴)
    Keyは自分が何を持っているか(過去のトークンが持つ鍵)
    Valueは実際の情報(鍵が開いたときに取り出せる宝箱)
    のような役割があります。
    ここでqとkの類似度(内積)を計算し、値が大きく似ている(現在の鍵穴に最も合う過去の鍵)を「重要な情報」として判断します。
    1つだけ選ぶのではなく、すべての過去トークンに対して類似度を計算し、
    それをsoftmaxによって確率(重み)に変換します。つまり、よく合う鍵 → 重みが大きい(よく参照する)あまり合わない鍵 → 重みが小さい(ほぼ無視する)
    という状態になります。その後、この重みを使ってValue(宝箱の中身)を重み付きで合成します。
    「完全に一致する鍵だけを使う」のではなく、「いくつかの鍵を“どれくらい合っているか”の割合で混ぜて使う」という感じです。
    最終的に「現在のトークンにとって最も適切な文脈情報」が生成されます。

このAttentionの処理が、GPTが「文脈を理解しているように見える」正体です。

  • 残差接続:元の情報を消さずに残しておく仕組み。xにattention処理をしたものに、元のxを加算している。
    • 該当コード
      x_residual = x
      # ~~attention~~
      x = [a + b for a, b in zip(x, x_residual)]
    
  • RmsNorm:モデルを安定させるためにベクトルの大きさを整える処理(正規化)。該当コードはrmsnorm関数の定義参照
  • Softmax:指数関数を用いて0~1の値を出力し、出力値の合計を1(100%)とする関数。該当コードはsoftmax関数の定義参照

変換・思考(MLP)

次にMLPの解説です。MLP(多層パーセプトロン)はAttentionで集めた情報を「解釈・変換」する部分です。

  • 該当コード
    # 前述のattentionの出力を残差接続のため一時保存
    x_residual = x
    # rmsnormで入力を安定させる。
    x = rmsnorm(x)
    # === ここからがMLP ===
    # 次元拡張(16次元→64次元)
    x = linear(x, state_dict[f'layer{li}.mlp_fc1'])
    # 非線形変換
    x = [xi.relu() for xi in x]
    # 次元圧縮(64次元→16次元)
    x = linear(x, state_dict[f'layer{li}.mlp_fc2'])
    # === ここまでがMLP ===
    # 残差接続
    x = [a + b for a, b in zip(x, x_residual)]
    
  • 説明
    全体の流れとしては x(Attentionの出力)→ 線形変換(拡張)→ ReLU(非線形)→ 線形変換(圧縮)→ 残差接続 となっています。
    かみくだくと次元を拡張し、表現力UP(情報を細かく分解)しています。そしてReLU(非線形)では情報の取捨選択(いらない情報を消す)。そして次元を元の状態に戻し、新しい意味として出力しています。これだけでは直感的に理解しずらいため、例を示します。
    例)
    「I am not happy」という文章を考えます。
    Attentionの役割は「not」という打ち消しの言葉が「happy」にかかっていることを見つけ出し、情報を集めることです。一方で、MLPは「not」と「happy」という情報を組み合わせて「これは悲しい(ネガティブ)である」と解釈します。もしMLPがないと否定の意味として処理できず、「happyという単語があるからポジティブである」と解釈してしまいます。

MLPによって単に文字をつなげるだけでなく、文脈を読み解く「知能」のような働きを実現しています。

ReLU:マイナスの数値を全て0にし、0以上の数値をそのまま出力する関数

ここまででgptの頭脳が読み解けました、ここからはgptの学習を解説していきます。

学習(予測 → 誤差計算 → 誤差逆伝搬 → パラメータ更新のループ)

  • 該当コード
    num_steps = 1000 #学習ループの回数
    for step in range(num_steps):
        # データ取り出し+トークン化
        doc = docs[step % len(docs)]
        tokens = [BOS] + [uchars.index(ch) for ch in doc] + [BOS]
        # 長さ調整
        n = min(block_size, len(tokens) - 1)
    
        keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
        losses = []
    
        # 順番に予測
        for pos_id in range(n):
            token_id, target_id = tokens[pos_id], tokens[pos_id + 1]
            # gptモデルに通し、予測
            logits = gpt(token_id, pos_id, keys, values)
            probs = softmax(logits)
            # 損失
            loss_t = -probs[target_id].log()
            losses.append(loss_t)
        loss = (1 / n) * sum(losses)
        # 逆伝搬
        loss.backward()
        # 学習率減衰
        lr_t = learning_rate * (1 - step / num_steps)
        # パラメータ更新
        for i, p in enumerate(params):
            m[i] = beta1 * m[i] + (1 - beta1) * p.grad
            v[i] = beta2 * v[i] + (1 - beta2) * p.grad ** 2
            m_hat = m[i] / (1 - beta1 ** (step + 1))
            v_hat = v[i] / (1 - beta2 ** (step + 1))
            p.data -= lr_t * m_hat / (v_hat ** 0.5 + eps_adam)
            p.grad = 0
    
        print(f"step {step+1:4d} / {num_steps:4d} | loss {loss.data:.4f}", end='\r')
    
  • 説明
    ここでは学習(予測 → 間違い探し → 修正の繰り返し)を実施しています。
    token_id, target_id = tokens[pos_id], tokens[pos_id + 1]
    
    → 次に来るコードを予測
    logits = gpt(token_id, pos_id, keys, values)
    probs = softmax(logits)
    
    → 現在の文脈から「次の文字の確率」を計算
    loss_t = -probs[target_id].log()
    
    → 間違いを数値化(損失)
    ※lossが小さいほど正解である確率が高く、lossが大きいほど正解の確率が低い
    loss = (1 / n) * sum(losses)
    
    → 文章全体でどれくらい間違えたかを計算
    loss.backward()
    
    → 誤差逆伝搬(損失をもとに、どこが悪かったを全パラメータに対して計算)
    p.data -= lr_t * m_hat / (v_hat ** 0.5 + eps_adam)
    
    → 間違いを減らす方向に修正
    lr_t = learning_rate * (1 - step / num_steps)
    
    → 学習率減衰(学習の後半になるほど慎重に学習するように調整)

MicroGPTではこのような学習を1000回繰り返しています。
ここまでの説明でようやく学習フェーズが終了しました。最後に学習済モデルを用いた文章生成です。

文章の生成

学習済モデルを用いて生成を実施しています。このコードではありそうな名前を20個生成しています。

temperature = 0.5
print("\n--- inference (new, hallucinated names) ---")
for sample_idx in range(20):
    keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
    token_id = BOS
    sample = []
    for pos_id in range(block_size):
        logits = gpt(token_id, pos_id, keys, values)
        probs = softmax([l / temperature for l in logits])
        token_id = random.choices(range(vocab_size), weights=[p.data for p in probs])[0]
        if token_id == BOS:
            break
        sample.append(uchars[token_id])
    print(f"sample {sample_idx+1:2d}: {''.join(sample)}")

temperature: 0~1の間で操作可能。低いと毎回同じような文字を生成し、高いとランダム性が増しより創造的になる。

これで解説終了です。

おわりに

今回はMicroGPTを読み解いてみました。長文になってしまいましたがここまで見ていただきありがとうございます。実際のChatGPTはもっと複雑で、膨大なコード量ですが、このようなピュアなモデルを真面目に見ることで世の中の生成AIの基本原理が理解でき、新たな発見や技術的に正しい判断につながるのではないかと思います。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?