LoginSignup
12
15

ChunkLlamaによる追加学習なしのLLMコンテキスト拡張を試す

Last updated at Posted at 2024-06-01

概要

LLMのコンテキスト長は、基本的にモデル学習時の系列長の長さに制限されます。これに対し、モデルの元々のコンテキスト長よりも大きなコンテキスト長を実現する技術がいわゆるコンテキスト拡張です。
多くのコンテキスト拡張の手法では長い系列長のデータを使った追加の学習が必要となります。必要な学習量の大小には差がありますが、そもそも学習のための機器の準備や設定、データセットの用意など一般ユーザにとってはハードルが高いものになります。
これに対し、ChunkLlamaという手法では追加の学習を必要とせずコンテキスト拡張を実現します。この手法をMistral-7bベースのモデルに対して適用し、推論や様々なテストを試しました。

目次

  1. ChunkLlamaについて
  2. Mistralベースのモデルで試す
  3. PPL(Perplexity)を測ってみる
  4. Passkey Retrieval Testをしてみる
  5. まとめ

ChunkLlamaについて

ChunkLlamaは、Dual Chunk Attention(DCA)という機構の導入により、追加の学習なしでコンテキスト長の拡張を可能にする手法です。
Dual Chunk Attention(DCA)は、入力シーケンスをある長さのchunkに分割し、それらに対して以下3つのアテンションメカニズムを適用することで長いコンテキストの情報を失わないまま効率的に処理する方法です。

  1. Intra-Chunk Attention: 同じチャンク内のトークン間の相対的な位置情報を処理するためのメカニズムで、これによって長い入力シーケンスに対するpplを低く保つことを可能にするもの。
  2. Inter-Chunk Attention: 異なるチャンクのトークン間の相対的な位置情報を処理するためのメカニズムで、これによって長い入力シーケンスに対する情報の保持を可能にするもの。
  3. Successive-Chunk Attention: 連続するチャンク間でのトークンの関係性を詳細にとらえるためメカニズムで、Inter-Chunk Attentionの特殊なケースのようなもの。

より詳細な内容については元論文をご確認ください。
また、github上に実装のコードも上がっています。現在はLlama2、Llama3、Mistral、Mixtral、Qwenへの実装例が上がっています。
image.png
(ChunkLlamaと他のコンテキスト拡張技術のPPLベースの比較)

image.png
(intra/inter/successive attentionの3つの手法によるAblation結果。それぞれの特徴が現れていて面白い)

Mistralベースのモデルで試す

今回は私が以前作成したMistralベースのモデルであるAratako/Ninja-v1-RP-expressiveでChunkLlama(MistralなのでChunkMistral)を試してみたいと思います。

ChunkLlama(ChunkMistral)の適用方法

適用方法はgithubのREADMEに載っており、非常に簡単です。
以下のように、replace_with_chunkmistral関数をimportして、モデルロード前にその関数を呼び出すという2行の追加で適用することが出来ます。

from transformers import AutoTokenizer, AutoModelForCausalLM
from chunkllama_attn_replace import replace_with_chunkmistral

# replace_with_chunkmistral関数を実行
# pretraining_lengthにはそのモデルの元のコンテキストサイズを指定。今回は4096
replace_with_chunkmistral(pretraining_length=4096)

# その後、通常通りmodelとtokenizerをロード
tokenizer = AutoTokenizer.from_pretrained("Aratako/Ninja-v1-RP-expressive", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Aratako/Ninja-v1-RP-expressive", attn_implementation="flash_attention_2", trust_remote_code=True, torch_dtype=torch.bfloat16)

# この後通常通りに推論

現状提供されているコードはHuggingface Transformers向けの実装のみなのでそれ以外の推論エンジンに適用するには別途カスタマイズが必要ですが、バックエンドがHuggingface Transformersのものには簡単に実装出来ます。
例えばoobabooga/text-generation-webuiにChunkLlamaを適用したい場合、modules/models.pyの中のhuggingface_loader関数内に上記のreplace_with_chunk**の呼び出しの1文を追加するだけで適用可能です。

実際に推論してみる

ChunkMistralを適用したモデルと適用していないモデルに、長いコンテキストの入力からの回答生成を行わせてみます。今回は魔法少女まどか☆マギカのWikipedia記事を使います。
Wikipediaの該当ページの先頭から「作品テーマについて」セクションまで内容、合計11315トークンを入力します。

プロンプトは以下のようなものです。

<s> あなたは誠実で優秀な日本人のアシスタントです。

USER: 以下の情報を元に、『魔法少女まどか☆マギカ』という作品についてその特徴などを含め簡潔に説明してください。

{Wikipediaの文章}
ASSISTANT: 
  • ChunkMistralなしの出力
    4kまでの系列長にしか対応していないため、完全に出力が崩壊しています。

■■■■■■■■■■■■■■■ ■■■■■■■■■■■■■■ *■■■■■■■■■■■■■■■■■■ ■ ■■■■■■ ■■■■ ■■ ■■■ ■■■■■■■■■■■■ ⦳ à ■■■■■■■ ■■■■■■■■■■■■■■ ■■■ ■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■  ■■■ ■■■■■■■■■■■■■■■■■■■■■■■ ⼓㛺でより■■ ■■■■ ■■■■■■■■■■■■■ ■

■■■■■■■■■■■■■ ■■■■■■■■■■■■■■■■■■■■■■■■■■■ ■■■■■ ■■■■■■■■■ ■■■■■■ ■■■■■■■■■■■■■■■■ ■  ■■■■■■ ■ ■■ ■■■■ ■■■■■■■■■■■■■■■■■■ ■■■■■ ■■■■■■ ■■■■■■■■■■■■■■■■■■■■■■■■■ ■■■■■■■■■ ■■■■■■ ■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ ■■■■■■ ■■■■■■■■■■■■■ ■■■

  • ChunkMistralありの出力
    長い入力に対しても問題なく適当な回答を生成できています。

この作品は、魔法少女というジャンルに属しながら、魔法少女の約束事に新たな解釈を持ち込んだものであり、魔法少女たちが直面する難題や、彼女たちの成長、そして世界の真実についての探求が描かれる。また、キャラクターの心理や感情、そしてそれぞれの立場や関係性が複雑に絡み合うことによって、単なる魔法少女ものではなく、深い人間ドラマとしての側面も持っている。本作品は、そうした複雑な要素を持つ物語を通して、視聴者に対して多くのことを語りかけているのである。

PPL(Perplexity)を測ってみる

PPL(Perplexity)はモデルが次に来る単語をどれだけうまく予測できるかを示すような指標です。これが高い場合、与えたテキストに対して上手く予測が出来ていないことを示唆します。モデルが扱える系列長を超えた場合、一般的にPPLは大きく上昇します。
githubで配布されているサンプルコードを元に、系列長を増やしながらPPLを計測してみます。
リポジトリ内のpplフォルダにサンプルコードがあるので、これを元に日本語テキストのtokenizeとそれに対するPPL計測をします。今回はWikiText-JAのTest_Data_F.txtを利用しました。
test_ppl.pyを以下のように編集して簡易的に実行しています。

  • 読み込み部分でLlama2ではなくAratako/Ninja-v1-RP-expressiveを読み込むように変更
  • 対象テキストの全部分ではなく最初の100回のみ計測するように変更

結果は以下の通りになりました。

PPLの測定結果

Sequence Length 4096 8192 16384 32768 65536
PPL (without chunkmistral) 3.681 7.938 23.687 49.859 84.068
PPL (with chunkmistral) 3.679 3.532 3.394 3.251 3.332

65k tokens程度の非常に長い系列長に対しても低いPPLを維持できていることが分かります。

Passkey Retrieval Testをしてみる

PPLの計測で長いコンテキストに対しても次単語の予測精度が安定しているという事の示唆は得られましたが、これだけでは本当にその長いコンテキストの内容を記憶しているか・それを取り出せるかは分かりません。
これを評価するためのテストがPasskey Retrieval Testです。これは、非常に長い無駄なテキストの中のランダムな場所に特定のテキスト(passkey)を埋め込み、それを抽出させるような指示をしてpasskeyを抽出できるかをテストするものです。
これが出来た場合、単に長いコンテキストを処理しているだけでなく、その長いコンテキストの中の過去の特定の情報を記憶し上手く取り出せていることを示唆します。
こちらもリポジトリ内のpasskeyフォルダにサンプルコードがあるので、これを元にテストを行います。テストの際はモデルロード部分だけ変更しました。

結果は以下の通りになりました。

Passkey Retrieval Testの結果

Sequence Length 1982 3894 7990 16182 32566 65334
acc (without chunkmistral) 1.0 1.0 0.0 0.0 0.0 0.0
acc (with chunkmistral) 1.0 1.0 1.0 1.0 1.0 0.98

65k tokens程度の非常に長い系列長の入力に対しても上手く情報を取り出せていることが分かります。

まとめ

今回はChunkLlamaによるコンテキスト長の拡張を試しました。結果はかなり良好で、追加学習なしで実現可能であることも考えるとかなり有用に感じます。
残念ながら現状実装例があるのはHuggingface Transformersを使ったもののみですが、issueをみるとvLLMへの統合も進められているようなので、今後そのような推論エンジンにも実装されるとより使いやすくなると思います。

12
15
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
15