概要
nanochatはTransformerを作った人のうちの一人である、Karpathyさんが作った、Transformer (GPT) の事前学習、事後学習、評価、モデル出力(hugginface)、lossなどの記録(wandb)の全てをコマンド一つで行うことが出来るリポジトリである
条件分岐などをなるべく用いておらず、コードが読みやすい
また、最小限のモデルサイズで簡単な会話が出来るようにするために、全体的に新しい技術を導入していて、デフォルトのGPT2とは異なるものになっている
リポジトリは、8GPU (H100)を前提とするが、少しコードを変えて、モデルサイズを小さく(12層)すると1GPU (H100)でも出来た
実行方法
git cloneしてから、
bash speedrun.sh
環境としては、Linuxとuvさえ入っていればOK。仮想環境やuvのライブラリの設定はこのシェルスクリプト内で行ってくれる。
モデル
GPTをベースとするTransformerアーキテクチャ。層の数を変えることでスケーリング出来る。(デフォルトでは20層)
層の数に応じて、埋め込み次元なども変化するプログラムになっている。(層の数×64)
n_head !=n_kv_headならGQAに切り替わる。デフォルトではn_head==n_kv_headであり、シンプルなSelf-Attentionとなっている。
KV_cacheにも対応済み
12層では185Mパラメータであり、その内訳は以下である。
具体的なモデルのパラメータ内訳 (12層)
| Module Name | Class | Direct Params | Total Params |
|---|---|---|---|
[model] (root) |
OptimizedModule |
0 | 185,597,952 |
- _orig_mod |
GPT |
0 | 185,597,952 |
- transformer |
ModuleDict |
0 | 135,266,304 |
- wte |
Embedding |
50,331,648 | 50,331,648 |
- h |
ModuleList |
0 | 84,934,656 |
- [0-11] (x12) |
Block |
0 | 7,077,888 (each) |
- attn |
CausalSelfAttention |
0 | (2,359,296 each) |
- mlp |
MLP |
0 | (4,718,592 each) |
- lm_head |
Linear |
50,331,648 | 50,331,648 |
【補足】
-
transformer.h(ModuleList) は、12個のBlockを保持しています。 - 各
Blockの総パラメータ数は7,077,888です。 - したがって、
hの総パラメータ数は7,077,888 * 12 = 84,934,656となり、表のTotal Paramsと一致します。 -
Block内部のattnやmlpの詳細も、12個のブロックすべてで共通であるため、代表として1つ分の内訳を()書きで示しました。
事前学習
学習データセット
fineweb-edu-100b (100M rows)
fineweb-eduは良く事前学習で使われる。そのままだと多すぎるからkarpathyさんが小さくしたと思われる。
optimizer
AdamWとMuonの複合
Muonはよりlossが下がりやすい最新のオプティマイザ。
tokenizer
RushBPEとTiktokenの複合
Tiktokenは良く使われるトークナイザーで、RushBPEはトークナイザーの学習を早くするプログラム。トークナイザーの学習をすると、より良いトークン分割が出来る。
Positional Embedding
Rotary Positional Embedding (RoPE)
絶対位置埋め込みと相対位置埋め込みの組み合わせ
相対位置埋め込みを入れることでモデルが学習しやすくなり、性能が上がるらしい。
中盤学習
LLMで中盤学習(mid-train)はあまり聞かないが、この中盤学習では簡単な会話データを学習する。基本的な設定は事前学習と一緒
学習データセット
karpathyさん独自の会話データセット
事後学習
後々行う様々な評価タスクのための学習をする。
以下の評価タスクにおける学習用データを学習する。
学習データセット一覧 (計23K rows)
| 概要 | データ数 | |
|---|---|---|
| ARC-Easy | AIの推論能力を測るベンチマーク。人間には簡単だがAIには難しい(簡単な方) | 2.3k rows |
| ARC-Challenge | 上記ARCの、より難易度の高い質問セット。 | 1.1k rows |
| GSM8K | 小学校レベルの算数の文章問題 | 8K rows |
| SmolTalk | 会話形式のデータセット | 10K rows |
| CustomJSON | JSON形式データをどれとけ正確に解析・処理できるかを評価するためのデータセット。 | 1K rows |
| SimpleSpelling | 基本的なスペル修正能力を評価するためのデータセット | 300 rows |
| SpellingBee | 単語の正確なスペリング(綴り)をどれだけ理解しているかを評価するためのデータセット。 | 300 rows |
事前学習評価
事前学習の評価は22個の様々なタスクで行われる。CORE metricはその平均スコアである。
そのタスクの一覧と、d12 (12層、185M)のモデルのスコアは以下である。
事前学習結果 (12層)
| Metric | Score |
|---|---|
| CORE metric | 0.1379 |
| hellaswag_zeroshot | 0.1031 |
| jeopardy | 0.0142 |
| bigbench_qa_wikidata | 0.3788 |
| arc_easy | 0.3866 |
| arc_challenge | 0.0091 |
| copa | 0.2400 |
| commonsense_qa | 0.1145 |
| piqa | 0.2688 |
| openbook_qa | 0.0907 |
| lambada_openai | 0.2787 |
| hellaswag | 0.1056 |
| winograd | 0.1941 |
| winogrande | 0.0150 |
| bigbench_dyck_languages | 0.0690 |
| agi_eval_lsat_ar | 0.0707 |
| bigbench_cs_algorithms | 0.3962 |
| bigbench_operators | 0.1190 |
| bigbench_repeat_copy_logic | 0.0000 |
| squad | 0.0409 |
| coqa | 0.0992 |
| boolq | -0.1387 |
| bigbench_language_identification | 0.1779 |
事後学習評価
事後学習の評価は6つのタスクに対して行う。そのデータセットの一部のデータは事後学習で用いている。
以下は、その6つのタスクとそのスコアである。BASEは事前学習後、MIDは中盤学習後、SFTは事後学習後のスコアである。
| Metric | Stage | d12 (185M) | d20 (560M) | d32 (1.8b) |
|---|---|---|---|---|
| CORE | BASE | 0.1379 | 0.2219 | 0.3168 |
| ARC-Easy | MID | 0.3262 | 0.3561 | 0.6233 |
| SFT | 0.3228 | 0.3876 | 0.6797 | |
| ARC-Challenge | MID | 0.2901 | 0.2875 | 0.4787 |
| SFT | 0.2765 | 0.2807 | 0.4991 | |
| MMLU | MID | 0.3006 | 0.3111 | 0.3896 |
| SFT | 0.3000 | 0.3151 | 0.4049 | |
| GSM8K | MID | 0.0152 | 0.0250 | 0.1099 |
| SFT | 0.0190 | 0.0455 | 0.1274 | |
| HumanEval | MID | 0.0671 | 0.0671 | 0.1098 |
| SFT | 0.0488 | 0.0854 | 0.1280 | |
| ChatCORE | MID | 0.2083 | 0.0730 | 0.2417 |
| SFT | 0.2053 | 0.0884 | 0.2734 |
d12は10/31日時点のリポジトリを自分で学習を行った結果であるが、d20, d32はkarpathyさんがgithubに挙げている学習結果から取得したものである。したがって、リポジトリのバージョンが違うため学習手法やデータセットが異なる可能性があることは注意