5
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

最大128kトークン利用できるYarn-Llama-2 で論文をエンベディングしてみる

Last updated at Posted at 2023-09-17

はじめに

2023月9月初旬に、Llama2ベースで128kトークンが利用できるYarn-Llama-2-128kモデル(以下、Yarn-llama2)が発表されました。OpenAIのAPI以外で、エンベディングによく利用されるMultilingual-e5は、512トークンが上限となっており、論文などの長文を丸ごとエンベディングをしようと思うと、上限トークン数をオーバーすることになります。

本記事では、128Kトークンが利用できるYarn-llama2で、論文のエンベディングができるのか試してみます。

環境

  • Google Colab Pro A100

エンベディングする論文

今回はテスト用の論文として、Yarn-llama2モデルが関係するYaRNの論文(以下、Yarn論文)、Llama2の論文(以下、Llama2論文)、そして系統が全く異なる論文として、2014年の記事で、最も引用された論文と紹介されていた溶液中のタンパク質の量を決定する分析法である、Oliver Lowry氏の"PROTEIN MEASUREMENT WITH THE FOLIN PHENOL REAGENT"(以下、Lowry論文)、を選定しました。エンベディングした各論文のベクトルで、類似度を判定してみます。期待としては、Yarnの論文を基準として、類似度が高い順に、Llama2論文、Lowry論文、となることです。

とりあえずYarn-llama2を使ってみる

とりあえずYarn-llama2を使って、論文の要約を試してみます。必要なライブラリをpipでインストールします。PDFを読み込むために、pdfminer.sixを利用します。

!pip install transformers tokenizers accelerate sentencepiece
!pip install pdfminer.six
!pip install flash-attn --no-build-isolation
!pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary
!pip install pdfminer.six

Yarn-llama2は、Rotary Position Embeddingsが特徴であり、効率的にテキストウィンドウを拡張することができるとされています。また、公開されたモデルには、attentionにFlash Attention2という、ハードウェアの特徴を生かしてメモリアクセスを改善し、計算を高速化する技術も使用されています。利用にするためにはAmpere以上のGPUが求められ、今回はColab ProでA100を利用して実施します。先人の記述に従い、Yarn論文を要約してもらいます。

from transformers import AutoTokenizer
import transformers
import torch
from pdfminer.high_level import extract_text

model_name = "NousResearch/Yarn-Llama-2-7b-128k"
tokenizer = AutoTokenizer.from_pretrained(model_name)
pipeline = transformers.pipeline(
    "text-generation",
    model=model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True
)

documents = extract_text("Yarn.pdf")
text = documents.replace("\n","")
print(len(text))

question="I am going to summarize the academic contribution of this paper in the following statement."
sequences = pipeline(
    f"I am going to read the following academic paper. \n\n {text} \n\n {question}\n",
    do_sample=True,
    top_k=10,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    max_length=20000,
)

回答

I am going to summarize the academic contribution of this paper in the following statement.
 The paper introduces a new method for context window extension of transformer-based LLMs, which can be used to extend the context window size of LLMs without fine-tuning.
 The paper also introduces a new method for fine-tuning the context window size of LLMs.
 The paper also introduces a new method for extrapolating the context window size of LLMs.

一応要約らしいことをしようとしているようです。huggingfaceのmodel cardを見ると、this is a pretrained base modelと書かれているので、適切なデータセットで、ファインチューニングすれば、改善されると思います。

論文をエンベディングする

AutoModelでのパラメータ読み込み

Yarn-llama2を使ってエンベディングをします。 今回は、エンベディングされたベクトルに、AutoModelで推論した時のlast_hidden_stateを利用します。デフォルトのままAutoModelで、パラメータをロードするとエラーが出るため、configAutoModelで読み込みたいクラス名の記述を追加して、AutoModel.from_pretrainedでパラメータをロードします。

from transformers import AutoModel, AutoConfig

model_name = "NousResearch/Yarn-Llama-2-7b-128k"
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config.auto_map['AutoModel'] = "NousResearch/Yarn-Llama-2-7b-128k--modeling_llama_together_yarn.LlamaModel"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_base = AutoModel.from_pretrained(model_name, config=config, torch_dtype=torch.bfloat16, device_map='auto', trust_remote_code=True)

論文のエンベディング

今回利用した論文のトークン数は

  • Yarn論文: 11940
  • Llama2論文: 79583
  • Lowry論文: 10898

です。残念ながら、Llama2論文の約80kトークンは、Google Colab A100のメモリに乗らず、tokenizerの最大トークン数は、15kにしています。この時点で、Yarn-llama2を利用する意味のほとんどは失っているのですが、このまま続けます。それぞれの論文はPDFであり、pdfminer.sixを使用して本文を読み込みます。

from pdfminer.high_level import extract_text
yarn_text = extract_text("Yarn.pdf").replace("\n","")
llama2_text = extract_text("Llama2.pdf").replace("\n", "")
lowry_text = extract_text("Lowry.pdf").replace("\n", "")

改行を削除する程度のクリーニングはしています。不要な文章を削除するなどの処理はしていません。

エンベディングのコードは、multilingua-e5のUsageを参考にしています。エンベディングに使用するlast_hidden_state変数は、[バッチ数, トークン数, 隠れ層数]という次元になっており、トークン数は、文献毎に異なるため、トークン数の次元で平均します。その結果、各論文は4096次元のベクトルにエンベディングされます。エンベディングした論文同士の類似度は、コサイン類似度で判定します。数値の大きいもの同士が、類似度が大きいと判定されます。

import torch.nn.functional as F
from torch import Tensor
import numpy as np

#ref:https://huggingface.co/intfloat/multilingual-e5-large
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

embeddings = []
for text in [yarn_text, llama2_text, lowly_text]:
	with torch.no_grad():
    	token = tokenizer(text,max_length=15000,return_tensors='pt')
    	output = model_base(**token.to(model_base.device))
    	embedding_temp = average_pool(output.last_hidden_state, token['attention_mask'])
    	embedding_temp = F.normalize(embedding_temp, p=2, dim=1)
    	embeddings.append(embedding_temp.float().cpu().numpy()[0])
embeddings = np.array(embeddings)
score = (embeddings @ embeddings.T) * 100 # コサイン類似度
print(score)

Yarn論文と、ほかの2つの論文のコサイン類似度は以下になります。

論文 類似度
Llama2 91
Lowry 57

期待通り、Yarn論文とLlama2論文の類似度が大きいと判定されました。

おわりに

128kトークンが利用可能であるYarn-llama2-128kですが、今回は残念ながら128kを試すことができず、最大15kのトークン数でのエンベディングになりました。128kトークンで入力しようとすると、マルチGPUで実施する必要があるのでしょう。

エンベディングした各論文は、Yarn論文とLlama2論文は意味的に近く、Lowry論文は遠いと判断されました。Llama2論文は、Yarn論文の引用文献なので、当たり前の結果ですが、今回の論文の中では、一応類似度は妥当に判定できているようです。ただ、huggingfaceにあるエンベディングベンチマークのリーダーボードを見ると、トップは2023年9月時点でbert系であり、transformer decoder系のYarn-llama2を利用した、この記事のエンベディング方法に意味があるかは、もう少し数をこなして考える必要があると思います。

参考

5
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?