5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

langchainとDatabricksで(私が)学ぶRAG : ColBERTによるReranking

Posted at

導入

私が学ぶRAGの実質10回目です。二桁分もネタが続くとは思っていませんでした。
シリーズ一覧はこちら
今回はColBERTによるRerankです。

これは何?

RAGにおいて、Retrieverが取得してきた複数の内容を、クエリに対する関連度(or 回答の妥当性)の高い順番に並び替えることをRerankingと言います。
例えば、セマンティック検索の結果とキーワード検索の結果の両方を取得し、それら全体で関連度の高いもの上位3つを抜き出すことで、関連度の高い文書だけをLLMに渡すような際に用います。

ColBERTはBERTベースの高速で正確な検索モデルです。
ColBERTのようなモデルを容易に使えるようにするためのライブラリとしてRAGatouilleというものがあり、今回はlangchainとRAGatouilleを使ってRerankingを実践してみます。

主として、以下の内容のウォークスルーに近いものとなります。

※ RAGatouille自体はColBERTのようなモデルを使ってRAGを簡単に実現する仕組なので、Rerankだけでなく通常のRetrieverとして利用することもできるようです。

検証はDatabricks on AWS上で実施しました。
DatabricksのDBRは14.1 ML、GPUクラスタ(g4dn.xlarge)上で動作を確認しています。

Step0. パッケージインストール

使うパッケージをインストールします。
今回はRAGatouilleとColBERT(今回はJaColBERT)を利用するために必要なパッケージを追加で入れています。
なお、AutoAWQはCUDA 11.8版をWheelを指定してインストールしています。
(DatabricksのDBR ML 14台はCUDAのバージョンが11.8のため)

%pip install -U -qq transformers accelerate langchain faiss-cpu ragatouille fugashi unidic_lite 
# CUDA 11.8用
%pip install https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp310-cp310-linux_x86_64.whl
%pip install "databricks-feature-engineering"

dbutils.library.restartPython()

Step1. Document Loading

準備編で保管した特徴量を取得します。
読み込んだデータは、docsという名前のビューから参照できるようにします。

from databricks.feature_engineering import FeatureEngineeringClient

fe = FeatureEngineeringClient()

feature_name = "training.llm.sample_doc_features"
df = fe.read_table(name=feature_name)

df.createOrReplaceTempView("docs")

Step2. Splitting

ごくごく単純なText Splitterを使って長文データをチャンキングします。
ベーシックなRAGのときと同様です。

from typing import Any
import pandas as pd
from pyspark.sql.functions import pandas_udf
from langchain.text_splitter import RecursiveCharacterTextSplitter


class JapaneseCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """句読点も句切り文字に含めるようにするためのスプリッタ"""

    def __init__(self, **kwargs: Any):
        separators = ["\n\n", "\n", "", "", " ", ""]
        super().__init__(separators=separators, **kwargs)


@pandas_udf("array<string>")
def split_text(texts: pd.Series) -> pd.Series:

    # 適当なサイズとオーバーラップでチャンク分割する
    text_splitter = JapaneseCharacterTextSplitter(chunk_size=200, chunk_overlap=40)
    return texts.map(lambda x: text_splitter.split_text(x))


# チャンキング
df = spark.table("docs")
df = df.withColumn("chunk", split_text("page_content"))

# Pandas Dataframeに変換し、チャンクのリストデータを取得
pdf = df.select("chunk").toPandas()
texts = list(pdf["chunk"][0])

Step3. Storage

チャンクデータに埋め込み(Embedding)を行い、ベクトルストアへデータを保管します。

まずはEmbedding用のモデルをロード。
以下のモデルをダウンロードしたものを利用します。

import torch
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_path = "/Volumes/training/llm/model_snapshots/models--intfloat--multilingual-e5-large"

embedding = HuggingFaceEmbeddings(
    model_name=embedding_path,
    model_kwargs={"device": device},
)

FAISSでベクトルストアを作成します。

vectorstore = FAISS.from_texts(texts, embedding)

Step4. ColBERT(JaColBERT) Model Loading

ここから今回のポイント。

ColBERTモデルをロードします。

日本語用のColBERTモデルJaColBERTが以下に公開されていますので、こちらを利用します。

JaColBERTは日本語の埋め込み・検索性能において、multilingual-e5-baseと同程度の性能のようです。
詳しくは上記リンク先を確認ください。また、下記の記事でも解説・実験されていました。

では、RAGatouilleを使ってロード処理を実行。

from ragatouille import RAGPretrainedModel

RAG = RAGPretrainedModel.from_pretrained("bclavie/JaColBERT")

# ColBERT(v2)オリジナルはこちら。
# RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")

RAGに組み込む前に、JaColBERTを使ったRerankingを試してみます。

Step3.で作成したvectorstoreからベースとなるRetrieverを取得し、そこからJaColBERTで関連度を再計算させ、上位3件の結果を取得するようにしてみます。

from langchain.retrievers import ContextualCompressionRetriever

# FAISS vectorstoreから類似文書を10件取得するRetriever
base_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})

# 上記のRetrieverの結果から、JaColBERTで類似度計算を行い、上位3件を取得するretrieverを作成
compression_retriever = ContextualCompressionRetriever(
    base_compressor=RAG.as_langchain_document_compressor(k=3),
    base_retriever=base_retriever,
)

# テスト実行
compressed_docs = compression_retriever.get_relevant_documents(
    "この契約において知的財産権はどのような扱いなのか?"
)

compressed_docs
出力
[Document(page_content='しに、納入物の利用 に必要な範囲で、前項の第三者の知的財産権を 自由かつ対価の追加\n支払なしに 使用し、又は第三者に使用させることができる。  \n4 委託業務の遂行中に納入物に関して乙(甲の同意を得て一部を再委託する場合は再委\n託先を含む。) が新たに知的財産権(以下 「新規知的財産権」という。)を取得した場\n合には、乙は、その詳細を書面にしたものを納入物に添付して甲に提出するものとする。', metadata={'relevance_score': 21.140625}),
 Document(page_content='り本契約が終了した後で あっても、なおその効力を有する。  \n \n(著作権等の帰属)  \n第28条  納入物に係る 著作権(著 作権法第2 7条及び第28条の権利を含む。ただし、\n本契約締結日現 在、乙、乙以外 の委託事業参加 者又は第三者の権利 対象となっているも\nのを除く。以下同じ。)は、委 託金額以外の追加支払なしに、その発生と同時に乙から', metadata={'relevance_score': 20.25}),
 Document(page_content='新規知的財産権は 約定の委託金額以外の追加支払なしに、納入物の引渡しと同時に乙か\nら甲に譲渡され、甲単独に帰属する。  \n5 前項の規定にかかわらず、著作権等について は第28条の定めに従う。  \n6 乙は、本契約終了後であっても、知的財産権の取 扱いに関する本契約 の約定を 自ら遵\n守し、及び第7条第1項 の再委託先に遵守させ ることを 約束する。', metadata={'relevance_score': 19.359375})]

内容の妥当性はさておき、JaColBERTを使った関連度上位3つを取得できました。
また、metadata内にrelevance_scoreが追加されています。

では、これを使ってRAGのパイプラインを組んでみましょう。

Step5. LLM Preparation

RAGのパイプラインを構成するLLMをロードします。
以下のモデルを事前にダウンロードして利用しました。

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_chat import ChatHuggingFaceModel

model_path = "/Volumes/training/llm/model_snapshots/models--TheBloke--openchat-3.5-0106-AWQ"

generator = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)

Step6. Chain preparation

RAGのChainを作成します。
今回は比較のために、単純にRetrieverを使うケースと、Rerankするケースの2種類のChainを作成します。

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.chat import (
    AIMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain_core.runnables import RunnablePassthrough, RunnableLambda

# コンテキストを使って単純に質問回答させるプロンプトテンプレート
prompt_template = """Answer the question based only on the following context:
{context}

Question: {question}
"""

prompt = ChatPromptTemplate.from_messages(
    [
        HumanMessagePromptTemplate.from_template(prompt_template),
        AIMessagePromptTemplate.from_template(""),
    ]
)

chat_model = ChatHuggingFaceModel(
    generator=generator,
    tokenizer=tokenizer,
    human_message_template="GPT4 Correct User: {}<|end_of_turn|>",
    ai_message_template="GPT4 Correct Assistant: {}",
    repetition_penalty=1.2,
    temperature=0.1,
    max_new_tokens=400,
)

# Chainその1。通常のRetrieverを使ったRAG。関連度の高い5件を取得・利用。
simple_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
chain1 = (
    {"context": simple_retriever, "question": RunnablePassthrough()} 
    | prompt 
    | chat_model 
    | StrOutputParser()
)

# Chainその2。ColBERTによるRerankを使ったRAG。関連度の高い15件を取得した後、ColBERTで再計算した関連度上位5件を利用。
rerank_retriever = ContextualCompressionRetriever(
    base_retriever=vectorstore.as_retriever(search_kwargs={"k": 15}),
    base_compressor=RAG.as_langchain_document_compressor(k=5), 
)

chain2 = (
    {"context": rerank_retriever, "question": RunnablePassthrough()} 
    | prompt 
    | chat_model 
    | StrOutputParser()
)

これでRAGの準備完了です。

では動作比較してみましょう。

Step7. Run

まずは、Rerankを使わないRAGのChainを実行。

chain1.invoke("この契約において知的財産権はどのような扱いなのか?")
出力(見やすさのために改行を追加してます)
'この契約において、知的財産権は、乙が納入物に第三者の知的財産権を利用する場合には、
第1条第2項の規定に従い、乙の費用及び責任において当該第三者から本契約の履行及び本契約終了後の甲による知的財産権を、
仕様書記載の「目的」のため、仕様書の「納入物」の項に記載した利用方法に従い、本契約終了後であっても、
知的財産権の取扱いに関する本契約の約定を自ら遵守し、及び第7条第1項の再委託先に遵守させることを約束する。
また、新規知的財産権は約定の委託金額以外の追加支払なしに、納入物の引渡しと同時に乙から甲に譲渡され、
甲単独に帰属する。'

次にRerankを使ったRAGのChainを実行。

chain2.invoke("この契約において知的財産権はどのような扱いなのか?")
出力(見やすさのために改行を追加してます)
'この契約において、知的財産権は、委託業務及び納入物に関して、約定の委託金額以外の支払義務を負わない。
本契約終了後の納入物の利用についても同様とする。
委託金額には委託業務の遂行に必要な範囲で、前項の第三者の知的財産権を自由かつ対価の追加支払なしに使用し、
又は第三者に使用させることができる。
新規知的財産権は、約定の委託金額以外の追加支払なしに、納入物の引渡しと同時に乙から甲に譲渡され、
甲単独に帰属する。'

正直、どちらがより適切なのかは、この元文章をしっかり理解しているわけではないので判断できないのですが、Rerankの使用有無によって異なる内容が生成されました。

まとめ

RAGatouilleを使ったColBERTによるRerankingを実践してみました。
ColBERTの性能面は正直まだわからないところがあるのですが、multilingual-e5よりも軽量・高速に動作し、かつ精度も十分である可能性があります。
ローカルでの埋め込み・インデックス生成の選択肢としてよいかもしれません。

また冒頭にも記載したように、RAGatouilleはRerankerだけでなく、通常のRAG(インデックス生成)を簡単に実装できるようなので、継続的に触っていってみようと思います。

5
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?