1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

langchainとDatabricksで(私が)学ぶRAG : 半構造化データに対するRAG②

Last updated at Posted at 2023-12-03

導入

以下の続きです。

シリーズ一覧はこちら。

前回はPDFファイルを読み込み、表のサマリを作るなどの加工を行ってから特徴量テーブルの保管するところまで実践しました。
今回はそのデータを使ってRetrieverを作成し、RAGを実行するChainを作成します。

Step 6. Embedding Model

ここからはRAG処理の主要部分になります。

まずはベクトル検索用にEmbedding(埋め込み)モデルをロードします。
以下のモデルをダウンロードしたものを利用しました。

import torch
from langchain.embeddings.huggingface import HuggingFaceEmbeddings


device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_path = "/Volumes/training/llm/model_snapshots/models--intfloat--multilingual-e5-large"

embedding = HuggingFaceEmbeddings(
    model_name=embedding_path,
    model_kwargs={"device": device},
)

Step 7. MultiVectorRetriever

今回のポイントです。

ベクトル検索のためのストアを作成するために、前回作成したデータを特徴量テーブルから取得し、一部加工します。

from databricks.feature_engineering import FeatureEngineeringClient
from langchain.schema.document import Document

fe = FeatureEngineeringClient()
feature_name = "training.llm.sample_doc_features3"

df = fe.read_table(name=feature_name)

# table_sumary列をリスト型に変換した上で、chunk列と合成
df = df.withColumn(
    "chunk",
    F.when(F.col("chunk").isNotNull(), F.col("chunk")).otherwise(
        F.split(F.col("table_summary"), ",")
    ),
)

# ユニークなIDを振りなおす
df = (
    df.withColumn("chunk", F.explode("chunk"))
    .withColumn("index", F.monotonically_increasing_id())
    .withColumn("element_id", F.concat_ws("-", "element_id", "index"))
)

加工したデータを使って、「検索に使うベクトルストアを作成するためのDocumentリスト」と「LLMに与えるデータを格納するためのストア用データリスト」を作成します。

# 検索に使うベクトルストアを作成するためのデータ
pdf1 = df.select("element_id", "chunk").toPandas()

# LLMに与えるデータを格納するためのストア用データ
pdf2 = (
    df.withColumn(
        "data",
        F.when(F.col("type") == F.lit("CompositeElement"), F.col("chunk")).otherwise(
            F.col("text")
        ),
    )
    .select("element_id", "data")
    .distinct()
    .toPandas()
)

id_key = "element_id"
# ベクトルストアを作成するためのデータをDocumentのリストに変換
summary_chunks = [
    Document(page_content=row["chunk"], metadata={id_key: row[id_key]})
    for index, row in pdf1.iterrows()
]

# LLMに与えるデータをIDとデータのタプルリストに変換
summary_raw_data = [(row[id_key], row["data"]) for index, row in pdf2.iterrows()]

作成したデータを使ってベクトルストアとInMemorySotreの2種を作成し、それを使ってMultiVectorRetrieverを作成します。

from langchain.storage import InMemoryStore
from langchain.vectorstores import FAISS
from langchain.retrievers.multi_vector import MultiVectorRetriever

# The vectorstore to use to index the child chunks
vectorstore = FAISS.from_documents(summary_chunks, embedding)

# The storage layer for the parent documents
store = InMemoryStore() 

# The retriever (empty to start)
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
)

retriever.docstore.mset(summary_raw_data)

Step 8. Chain

ようやく準備が整いましたので、RAGのChainを作成しましょう。

まずはLLMを読み込みます。
前回同様、以下のモデルを事前にダウンロードして利用しました。

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_chat import ChatHuggingFaceModel

model_path = "/Volumes/training/llm/model_snapshots/models--TheBloke--openchat_3.5-AWQ"

generator = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)

RAGを実行するためのChainを作成します。
プロンプトテンプレートはLangchain Templatesのものをそのまま流用しました。
それ以外は通常のRAGのChainと同様です。

from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain.prompts import ChatPromptTemplate
from langchain.prompts.chat import (
    AIMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

retriever = vectorstore.as_retriever()

template = """Answer the question based only on the following context, which can include text and tables:
{context}
Question: {question}
"""
prompt = ChatPromptTemplate.from_messages(
    [
        HumanMessagePromptTemplate.from_template(template),
        AIMessagePromptTemplate.from_template(""),
    ]
)

chat_model = ChatHuggingFaceModel(
    generator=generator,
    tokenizer=tokenizer,
    human_message_template="GPT4 Correct User: {}<|end_of_turn|>",
    ai_message_template="GPT4 Correct Assistant: {}",
    repetition_penalty=1.2,
    temperature=0.1,
    max_new_tokens=1024,
)

chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | chat_model
    | StrOutputParser()
)

実行。

表の内容を伴わない質問をしてみます。

for s in chain.stream("この契約において知的財産権はどのような扱いなのか?"):
    print(s, end="", flush=True)
出力
この契約において、知的財産権は、知的財産基本法第2条第2項に定義された知的財産権、知的財産権を受ける権利、ノウハウその他の秘密情報を含むという意味で扱われています。また、乙が既に所有又は管理していた知的財産権(乙知的財産権)を乙が納入物に使用した場合、甲は当該乙知的財産権を、仕様書記載の「目的」のため、仕様書の「納入物」の項に記載した利用方法に従って、本契約終了後も期間の制限なく、また追加の対価を支払うことなしに自ら使用し、又は第三者に使用させることができます。

表に含まれることも聞いてみます。

for s in chain.stream("印刷工程について説明してください"):
    print(s, end="", flush=True)
このテキストは環境に配慮した印刷工程に関する要件をまとめたものです主な要件はデジタル化率の高い工程銀回収印刷版の再利用VOC対策熱風乾燥印刷のVOC処理装置の設置損紙のリサイクル率の高い工程省エネルギー活動騒音振動抑制表面加工の環境に配慮した製本加工の高いリサイクル率を行うことです

できてる・・・のか?

まとめ

というわけで、半構造化データを対象としたMult Vetor RetrieverによるRAGでした。
データの処理が非常に大変だったという所感です。

また、もう少しわかりやすい結果になるPDFを使えばよかったなと反省しています。

画像も検索対象にすると広範なユースケースに使えると思います。
ただし、その場合はLlavaのような画像を取り扱えるマルチモーダルモデルが必要になるため、このシリーズでやるかは不明。いつかはチャレンジしてみたいですが。

次回は少し手軽なRAG実践とかしたいと思っています。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?