はじめに
2023月9月初旬に、Llama2ベースで128kトークンが利用できるYarn-Llama-2-128kモデル(以下、Yarn-llama2)が発表されました。OpenAIのAPI以外で、エンベディングによく利用されるMultilingual-e5は、512トークンが上限となっており、論文などの長文を丸ごとエンベディングをしようと思うと、上限トークン数をオーバーすることになります。
本記事では、128Kトークンが利用できるYarn-llama2で、論文のエンベディングができるのか試してみます。
環境
- Google Colab Pro A100
エンベディングする論文
今回はテスト用の論文として、Yarn-llama2モデルが関係するYaRNの論文(以下、Yarn論文)、Llama2の論文(以下、Llama2論文)、そして系統が全く異なる論文として、2014年の記事で、最も引用された論文と紹介されていた溶液中のタンパク質の量を決定する分析法である、Oliver Lowry氏の"PROTEIN MEASUREMENT WITH THE FOLIN PHENOL REAGENT"(以下、Lowry論文)、を選定しました。エンベディングした各論文のベクトルで、類似度を判定してみます。期待としては、Yarnの論文を基準として、類似度が高い順に、Llama2論文、Lowry論文、となることです。
とりあえずYarn-llama2を使ってみる
とりあえずYarn-llama2を使って、論文の要約を試してみます。必要なライブラリをpipでインストールします。PDFを読み込むために、pdfminer.sixを利用します。
!pip install transformers tokenizers accelerate sentencepiece
!pip install pdfminer.six
!pip install flash-attn --no-build-isolation
!pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary
!pip install pdfminer.six
Yarn-llama2は、Rotary Position Embeddingsが特徴であり、効率的にテキストウィンドウを拡張することができるとされています。また、公開されたモデルには、attentionにFlash Attention2という、ハードウェアの特徴を生かしてメモリアクセスを改善し、計算を高速化する技術も使用されています。利用にするためにはAmpere以上のGPUが求められ、今回はColab ProでA100を利用して実施します。先人の記述に従い、Yarn論文を要約してもらいます。
from transformers import AutoTokenizer
import transformers
import torch
from pdfminer.high_level import extract_text
model_name = "NousResearch/Yarn-Llama-2-7b-128k"
tokenizer = AutoTokenizer.from_pretrained(model_name)
pipeline = transformers.pipeline(
"text-generation",
model=model_name,
device_map="auto",
torch_dtype=torch.float16,
trust_remote_code=True
)
documents = extract_text("Yarn.pdf")
text = documents.replace("\n","")
print(len(text))
question="I am going to summarize the academic contribution of this paper in the following statement."
sequences = pipeline(
f"I am going to read the following academic paper. \n\n {text} \n\n {question}\n",
do_sample=True,
top_k=10,
num_return_sequences=1,
eos_token_id=tokenizer.eos_token_id,
max_length=20000,
)
回答
I am going to summarize the academic contribution of this paper in the following statement.
The paper introduces a new method for context window extension of transformer-based LLMs, which can be used to extend the context window size of LLMs without fine-tuning.
The paper also introduces a new method for fine-tuning the context window size of LLMs.
The paper also introduces a new method for extrapolating the context window size of LLMs.
一応要約らしいことをしようとしているようです。huggingfaceのmodel cardを見ると、this is a pretrained base modelと書かれているので、適切なデータセットで、ファインチューニングすれば、改善されると思います。
論文をエンベディングする
AutoModelでのパラメータ読み込み
Yarn-llama2を使ってエンベディングをします。 今回は、エンベディングされたベクトルに、AutoModel
で推論した時のlast_hidden_state
を利用します。デフォルトのままAutoModel
で、パラメータをロードするとエラーが出るため、config
にAutoModel
で読み込みたいクラス名の記述を追加して、AutoModel.from_pretrained
でパラメータをロードします。
from transformers import AutoModel, AutoConfig
model_name = "NousResearch/Yarn-Llama-2-7b-128k"
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config.auto_map['AutoModel'] = "NousResearch/Yarn-Llama-2-7b-128k--modeling_llama_together_yarn.LlamaModel"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_base = AutoModel.from_pretrained(model_name, config=config, torch_dtype=torch.bfloat16, device_map='auto', trust_remote_code=True)
論文のエンベディング
今回利用した論文のトークン数は
- Yarn論文: 11940
- Llama2論文: 79583
- Lowry論文: 10898
です。残念ながら、Llama2論文の約80kトークンは、Google Colab A100のメモリに乗らず、tokenizerの最大トークン数は、15kにしています。この時点で、Yarn-llama2を利用する意味のほとんどは失っているのですが、このまま続けます。それぞれの論文はPDFであり、pdfminer.sixを使用して本文を読み込みます。
from pdfminer.high_level import extract_text
yarn_text = extract_text("Yarn.pdf").replace("\n","")
llama2_text = extract_text("Llama2.pdf").replace("\n", "")
lowry_text = extract_text("Lowry.pdf").replace("\n", "")
改行を削除する程度のクリーニングはしています。不要な文章を削除するなどの処理はしていません。
エンベディングのコードは、multilingua-e5のUsageを参考にしています。エンベディングに使用するlast_hidden_state
変数は、[バッチ数, トークン数, 隠れ層数]
という次元になっており、トークン数は、文献毎に異なるため、トークン数の次元で平均します。その結果、各論文は4096次元のベクトルにエンベディングされます。エンベディングした論文同士の類似度は、コサイン類似度で判定します。数値の大きいもの同士が、類似度が大きいと判定されます。
import torch.nn.functional as F
from torch import Tensor
import numpy as np
#ref:https://huggingface.co/intfloat/multilingual-e5-large
def average_pool(last_hidden_states: Tensor,
attention_mask: Tensor) -> Tensor:
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
embeddings = []
for text in [yarn_text, llama2_text, lowly_text]:
with torch.no_grad():
token = tokenizer(text,max_length=15000,return_tensors='pt')
output = model_base(**token.to(model_base.device))
embedding_temp = average_pool(output.last_hidden_state, token['attention_mask'])
embedding_temp = F.normalize(embedding_temp, p=2, dim=1)
embeddings.append(embedding_temp.float().cpu().numpy()[0])
embeddings = np.array(embeddings)
score = (embeddings @ embeddings.T) * 100 # コサイン類似度
print(score)
Yarn論文と、ほかの2つの論文のコサイン類似度は以下になります。
論文 | 類似度 |
---|---|
Llama2 | 91 |
Lowry | 57 |
期待通り、Yarn論文とLlama2論文の類似度が大きいと判定されました。
おわりに
128kトークンが利用可能であるYarn-llama2-128kですが、今回は残念ながら128kを試すことができず、最大15kのトークン数でのエンベディングになりました。128kトークンで入力しようとすると、マルチGPUで実施する必要があるのでしょう。
エンベディングした各論文は、Yarn論文とLlama2論文は意味的に近く、Lowry論文は遠いと判断されました。Llama2論文は、Yarn論文の引用文献なので、当たり前の結果ですが、今回の論文の中では、一応類似度は妥当に判定できているようです。ただ、huggingfaceにあるエンベディングベンチマークのリーダーボードを見ると、トップは2023年9月時点でbert系であり、transformer decoder系のYarn-llama2を利用した、この記事のエンベディング方法に意味があるかは、もう少し数をこなして考える必要があると思います。