こんにちは!逆瀬川 ( https://x.com/gyakuse ) です!
このアドベントカレンダーでは生成AIのアプリケーションを実際に作り、どのように作ればいいのか、ということをわかりやすく書いていければと思います。アプリケーションだけではなく、プロダクト開発に必要なモデルの調査方法、training方法、基礎知識等にも触れていければと思います。
0. 今回の記事について
今日は言語モデルの推論部分について軽くお話します。
それではやっていきましょう
1. そもそも: LLMってどうやって動いてる?
ハルシネーションについて考える前に、ぱぱっとLLMの挙動について思い出していきましょう。
LLMは膨大なテキストデータを使ってトレーニングされ、「次に来るであろうトークン (サブワード単位の単語)」を予測する仕組みになっています。ChatGPTやLlamaなどのモデルは、Transformerのdecoder部分のみを用いたアーキテクチャをベースに、与えられたテキストの続きを自己回帰的に1トークンずつ生成します。
ここでは、Llama3 の実装を軽くみていきます。推論に関するコードはllama/generation.pyにあります。generation.pyには text_completion
と chat_completion
(テキスト生成とチャット生成の関数) がありますが、これらは基本的に同じです。チャット生成専用モデル(Llama-3.1-8B-Instructなど)はInstructionデータで学習され、[INST]
や [SYS]
などの特殊タグを入れる必要がある程度です。ちなみに [INST]
や [SYS]
は 特殊トークン ではありません (たぶん、間違ってたら教えて下さい)。
アーキテクチャについて
Attention is all you need の図はもう100億回くらい見たと思うので省略しますが、Llamaはdecoderのみのアーキテクチャで、自己回帰的にテキストを生成します。学習時は「前のテキスト(コンテキスト)」から次のトークンを予測してパラメータを更新し、確率分布を大量に覚えさせられます。マジでえらい。推論についてはわかりやすくまとめると以下のような流れになります。
推論の流れ
入力テキストをトークン化
↓
モデル内部で次トークンの確率分布を計算
↓
デコード戦略を使ってトークンを決定
↓
これを出力するトークン分行う
2. 推論の実装を見ていこう
では、実装を見ていきましょう。
2-1. 入力テキストをトークン化する
モデルはテキストをそのまま扱えません。そこで、文章をトークンと呼ばれるIDのリストへ変換します。この処理をトークナイズといいます (厳密にはtokenize + encode)
Llama3では以下でトークン化が行われています
prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
Llama2まではSentencePieceでしたが、Llama3はBPEベースのTiktokenを採用しています。
例えば「こんにちは」をトークン化すると ['こんに', 'ち', 'わ']
のようなサブワード列になり (厳密にはバイトベースですが)、内部の語彙と1:1対応しIDになります。
2-2. 次トークンの予測処理呼び出し
Transformerブロックの詳細等は今回は省きます。
logits = self.model.forward(tokens, prev_pos)
2-3. デコード
シンプルにtop-pを用いたサンプリングをしています。
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
3. デコード戦略について
- 貪欲法 (Greedy Decoding)
- 毎回、最も確率が高いトークンをそのまま選ぶ
- 毎回同じ出力になってしまっておもんない
- あと繰り返しもよく起きる
- サンプリング法 (特に Pure Sampling):
- 確率に従ってランダムにトークンを選ぶ
- 文脈崩壊しやすいみたいなとこある
- top-k サンプリング
- 上位k個のトークンに絞ってからサンプリング
- k=3なら、上位3候補からサンプリング
- やばいトークンを排除できて便利
- top-p サンプリング
- Llama3のオフィシャル実装で使われてたやつ
- 確率累積で一定割合pになるまで候補を取り、その範囲内でサンプリング
- 低確率のトークンは自然に消え去ってくれて便利
- temperature
- 確率分布を平坦化したり尖らせたりするパラメータ. softmaxのとこのTのやつ
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
- 確率分布を平坦化したり尖らせたりするパラメータ. softmaxのとこのTのやつ
例: 織田信長は美少女です
例えば、以下のテキストがプロンプトにあるとき、
たとえば織田信長が美少女だというゲーム・アニメを中心にトレーニングしたら織田信長が美少女だと思ってしまう。織田信長は美少女
モデルは『織田信長は美少女』の次にくるトークンの確率分布を推論します。
順位 | 予測 | 確率 |
---|---|---|
1 | だ | 0.2617 |
2 | だった | 0.0850 |
3 | である | 0.0752 |
4 | ではない | 0.0623 |
5 | な | 0.0457 |
6 | 。 | 0.0334 |
7 | で | 0.0334 |
8 | の | 0.0215 |
9 | 。 | 0.0203 |
10 | だから | 0.0178 |
上位候補を眺めると、「だ」「だった」「である」「ではない」あたりが高確率で、文法的にも自然です。top-pやtop-kを使えば、これらの候補から何を選ぶか制限できます。
低確率のトークン (「だから」など) は登場しにくくなりますが、temperatureを上げて、top-pをゆるくすれば『たとえば織田信長が美少女だというゲーム・アニメを中心にトレーニングしたら織田信長が美少女だと思ってしまう。織田信長は美少女だからね!』みたいな文章もできます。文芸作品等ではこうした意外性は必要なので重宝されます。Mirostat Sampling を用いると、一貫性や意外性をコントロールすることができます。
まとめ
- ほんとはhallucinationの実験まで行う予定だったのに力つきました
- 織田信長は美少女です