5
1

LLMのRAGにSparkを活用してみる

Posted at

実験メモのような内容です。
無駄に長い。。。

導入

LLMの利用において、外部知識をモデルに与えるためには、SFTなどのFine tuningする方法や、プロンプトとして必要な情報を与える方法があります。
後者の方法として、Retrieval Augmented Generation (RAG) はかなり一般化してきているように思います。

※ RAGについては下記リンク先をどうぞ。

最近のLLMはコンテキスト長も大幅に増えてきており、RAGはもっと発展していく流れになると考えています。

個人的に困っているのはVectorStoreのデータ管理。
様々なVectoreStoreが公開されていますが、既存のデータ管理システムとは別個にテキストデータを管理することになるので、ちょっと面倒です。
特にDatabricksはUnity Catalogなど非常に優秀なマネジメント機構が備わっていますし、Sparkなどの分散処理の仕組が強力です。

というわけで、なるべくSparkの仕組を活かしつつRAGを実装する方法を模索してみます。
近いうちにLakehouse AIが実装されて不要になりそうな気もしますが、試しということで。

検証はDatabricks on AWSで行っています。使用したクラスタはg4dn.xlargeのシングルノード、DBRは13.3 LTS MLです。

今回やること

  • LLMにおける長期記憶の領域をSpark(というかDelta Lake)側で管理する
  • セマンティック検索を担う部分は、従来のVectorDB(今回はFAISS)で行う
  • セマンティック検索で得られた結果(チャンクデータ)をそのまま使うのではなく、プロンプトに埋め込むデータは長期記憶のデータを引っ張ってくるようにする

というわけでやってみよう。

Step1. 準備

今回は以下のdolly-15k-jaを利用します。
通常はモデルの学習用に使われるデータだと思いますが、これをRAGの中で検索するデータとして使います。

まず、必要になりそうなモジュールをインストールします。
※ 最終的に使わないものがいくつか出ましたが、一旦残しておきます。

%pip install -U -qq transformers accelerate ctranslate2 langchain faiss-cpu sentencepiece fugashi unidic-lite einops SentencePiece bitsandbytes

dbutils.library.restartPython()

いくつかのモジュールと定数を定義しておきます。
今回は、Embedding用モデルとして https://huggingface.co/intfloat/multilingual-e5-large を使用します。また、事前にUnityCatalogのボリュームに保管しています。

データはtrainingというカタログ・llmというスキーマを作ってそこに保管することにしました。
中間テーブルをいくつか作るので、そのためのテーブル名を定義しておきます。

import pyspark.sql.functions as F
from pyspark.sql.functions import pandas_udf
import pyspark.sql.types as T
from pyspark.sql import DataFrame

from typing import Any, Iterator
from functools import reduce

import pandas as pd

from langchain.embeddings import HuggingFaceEmbeddings


EMBEDDING_MODEL_NAME = "/Volumes/埋め込み用モデルのパス/multilingual_e5_large"

RAW_TABLE = "training.llm.databricks_dolly_15k_ja"
CHUNKED_TABLE = "training.llm.databricks_dolly_15k_ja_chunked"
EMBEDD_DOCUMENT_TABLE = "training.llm.databricks_dolly_15k_ja_embedd_documents"
LONGTERM_MEMORY_TABLE = "training.llm.databricks_dolly_15k_ja_longterm_memory"

最後にHugginface Hubからデータセットをダウンロードして、保管します。

dataset_id = "kunishou/databricks-dolly-15k-ja"
from datasets import load_dataset

def save_dataset(dataset_id:str) -> DataFrame:
    dataset = load_dataset(dataset_id)
    df = spark.createDataFrame(dataset["train"])

    df.write.mode("overwrite").saveAsTable(RAW_TABLE)

    return df

save_dataset(dataset_id)

こんな感じのデータがテーブルに保管されます。

image.png

Step2. データのチャンク化と埋め込み(Embedding)

  • 長文データをチャンク化
  • 各チャンクに対する埋め込み(Embedding)の作成

を行います。

チャンクの作成

langchainのtext_splitterを作って、200文字を目途に分割します。
JapaneseCharacterTextSplitterというカスタムSplitterを作って、句読点なども句切り文字になるようにしています。

あとはpandas_udfを使って、Spark DataFrameのテーブルデータを上記クラスでチャンク化します。
dolly-15k-jaデータは、文字情報としてinputinstructionoutputの3列あります。
これらを全て検索対象とすることにし、全てチャンク化後、最終的に一つの列内にチャンクの配列としてデータを保持させています。

from langchain.text_splitter import RecursiveCharacterTextSplitter

class JapaneseCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """日本語用のTextSplitter。句読点も句切り文字に含める
    参考:https://www.sato-susumu.com/entry/2023/04/30/131338
    """

    def __init__(self, **kwargs: Any):
        separators = ["\n\n", "\n", "", "", " ", ""]
        super().__init__(separators=separators, **kwargs)

@pandas_udf("array<string>")
def split_as_chunk(texts:pd.Series) -> pd.Series:

    chunk_size = 200
    chunk_overlap = 10

    # トークン化ユーティリティのインスタンスを作成
    text_splitter = JapaneseCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )

    return texts.map(text_splitter.split_text)

def create_chunk_data(table:str) -> DataFrame:

    # チャンクデータ列の作成
    df = spark.table(table)

    df = df.withColumn("_input_chunk", split_as_chunk("input"))
    df = df.withColumn("_inst_chunk", split_as_chunk("instruction"))
    df = df.withColumn("_output_chunk", split_as_chunk("output"))

    df = df.withColumn("chunks", F.concat("_input_chunk", "_inst_chunk", "_output_chunk"))
    df = df.drop("_input_chunk", "_inst_chunk", "_output_chunk")

    df.write.mode("overwrite").saveAsTable(CHUNKED_TABLE)

    return spark.table(CHUNKED_TABLE)

create_chunk_data(RAW_TABLE)

こんな感じのチャンク配列データがデータフレーム内に追加されます。

image.png

埋め込み

埋め込み用モデルを使って、先ほどチャンク化したデータを基にEmbeddingを作成します。
※ 今回はpandas_udfを使って行っていますが、一度Pandasのデータフレームに変換した後にEmbedding処理した方が速いと思います。
この方式だと埋め込み生成に20分以上かかりました。

# 一度に処理するレコードのMAXサイズを設定。GPUのメモリとの兼ね合いを考慮して、ここの制御は注意。
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")

@pandas_udf("array<array<double>>")
def embed_array(chunks: pd.Series) -> pd.Series:

    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    emb = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        model_kwargs={"device":device},
    )
    return chunks.map(lambda c: emb.embed_documents(c))

def create_embedding(table:str) -> DataFrame:

    df = spark.table(table)
    df = df.withColumn("embeddings", embed_array("chunks"))

    df.write.mode("overwrite").option("overwriteSchema", True).saveAsTable(EMBEDD_DOCUMENT_TABLE)

    return spark.table(EMBEDD_DOCUMENT_TABLE)

df = create_embedding(CHUNKED_TABLE)
display(df)

こんな感じのベクトルデータの配列カラムが追加されます。

image.png

最後に、embedding_pairsというチャンクとそのEmbeddingをペアにした列を作ります。
これでデータ加工は終了です。

def create_longterm_memory(table:str) -> DataFrame:

    df = spark.table(table)
    df = df.withColumn("embedding_pairs", F.arrays_zip("chunks", "embeddings"))

    df.write.mode("overwrite").option("overwriteSchema", True).saveAsTable(LONGTERM_MEMORY_TABLE)

    return spark.table(LONGTERM_MEMORY_TABLE)

df = create_longterm_memory(EMBEDD_DOCUMENT_TABLE)
display(df)

追加した列はこんな感じ。

image.png

このように、検索する対象のテキストデータとそのチャンク、ベクトルデータを同一のテーブル内で管理します。以後、検索対象データが増える際には、このテーブルにデータを追加していくことで、全体データをSpark(Deltalake)内で管理します。

ただし、このままではセマンティック検索ができません。
この後、必要なデータを切り出してVectorStoreに格納・セマンティック検索をする流れでRAGを実行していきます。

Step3. Spark DataFrameからRetrieverを作成

  • Spark DataFrame -> VectorStore -> langchainのRetriever作成

を行います。

まず、langchainのBaseRetrieverを継承したカスタムRetrieverクラスを作成します。
これはSpark DataFrameを基にFAISSのVectorStoreにベクトルデータを格納し、それをlangchainのRetrieverとして使えるようにするクラスです。
※ かなり手抜きなので、async対応までしていません。注意。

また、get_relevant_documentsをする際に、VectorStoreのチャンクをそのままDocumentで返すのではなく、大本のDataFrameのデータ(context列)を参照しなおすようにしています。
これによって、細切れのチャンクデータをプロンプトに入れるのではなく、全文を挿入するなどの対応を柔軟にできるようになります。
※ このような方法について、何かの論文・記事に記載があったと思うのですが場所を失念していしまいました。。。

from typing import Any, Iterator, List

from langchain.schema import BaseRetriever
from langchain.vectorstores import FAISS
from langchain.vectorstores.base import VectorStoreRetriever
from langchain.embeddings.base import Embeddings
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import Field


class SparkRetriever(BaseRetriever):
    """Sparkを活用したVectorStoreRetrieverの拡張"""

    retriever: VectorStoreRetriever = None
    search_kwargs: dict = Field(default_factory=dict)
    search_type: str = "similarity"

    data: DataFrame
    """ 長期記憶を保持しいているSpark DataFrame """

    embeddings: Embeddings
    """ 埋め込みを生成するモデル """

    index_col: str = "index"
    """ DataFrameにおけるindex(id)列名 """

    embeddings_pairs_col: str = "embedding_pairs"
    """ DataFrameにおけるchunkとembedding情報を含んだ列名 """

    context_col: str = "context"
    """ Promptに入れるためのコンテキスト情報を含んだ列名 """

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def __init__(
        self,
        data,
        embeddings,
        search_type="similarity",
        search_kwargs=Field(default_factory=dict),
    ):
        super().__init__(
            data=data,
            embeddings=embeddings,
            search_type=search_type,
            search_kwargs=search_kwargs,
        )
        self.retriever = self._rebuild_retriever(
            df=data,
            embeddings=embeddings,
            search_type=search_type,
            search_kwargs=search_kwargs,
            index_col=self.index_col,
            embeddings_pairs_col=self.embeddings_pairs_col,
        )

    def _rebuild_retriever(
        self,
        df: DataFrame,
        embeddings: Embeddings,
        search_type: str,
        search_kwargs,
        index_col: str = "index",
        embeddings_pairs_col: str = "embedding_pairs",
    ):
        """指定されたDataframeからFAISSでVectorStoreRetriverを再ビルドする"""

        df = df.select(
            index_col, F.explode(embeddings_pairs_col).alias(embeddings_pairs_col)
        )
        pdf = df.toPandas()

        metadata = pdf[[index_col]].to_dict(orient="records")
        text_embedding_pairs = list(
            pdf[embeddings_pairs_col].map(lambda x: (x["chunks"], x["embeddings"]))
        )

        faiss = FAISS.from_embeddings(
            text_embeddings=text_embedding_pairs,
            metadatas=metadata,
            embedding=embeddings,
        )

        return faiss.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:

        docs = []
        if self.retriever:
            self.retriever.search_kwargs = self.search_kwargs
            docs = self.retriever.get_relevant_documents(query)
            indices_df = spark.createDataFrame(
                set([(d.metadata[self.index_col],) for d in docs]), [self.index_col]
            ).distinct()

            # 長期記憶データからコンテンツを取得
            df = self.data.join(indices_df, [self.index_col], "inner")
            df = df.select(self.index_col, self.context_col)
            pdf = df.toPandas()

            def create_doc(d):
                return Document(
                    page_content=d[self.context_col], metadata={"id": d[self.index_col]}
                )

            docs = list(pdf.apply(create_doc, axis=1))
        return docs

    def reset_dataframe(self, data: DataFrame):
        self.data = data
        self.retriever = self._rebuild_retriever(
            df=data,
            embeddings=self.embeddings,
            index_col=self.index_col,
            embeddings_pairs_col=self.embeddings_pairs_col,
        )

これの何がいいかというと、例えば「特定のカテゴリだけのRetrieverを準備する」のように参照範囲を容易に限定することができたり、適切なアクセス権があるデータのみ参照できるようになるなど、データマネジメントを行いやすくなります。

デメリットはクラス生成単位時にベクトルデータを作り直すので、データ量によっては時間がかかることでしょうか。

では、このクラスを使ってretrieverのインスタンスを作成します。
セマンティック検索の結果として得られる内容は、instruction列とoutput列を結合したものにしました。

import torch

def create_spark_retriever():

    device = "cuda" if torch.cuda.is_available() else "cpu"
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        model_kwargs={"device":device},
    )

    df = spark.table(LONGTERM_MEMORY_TABLE)
    df = df.withColumn("context", F.concat_ws("\n", F.col("instruction"), F.col("output")))    
    df = df.cache()

    return SparkRetriever(data=df, embeddings=embeddings, search_kwargs={"k":3})

retriever = create_spark_retriever()

Step4. RetrievalQAを実装

実際に作成したRetrieverを使って、RetrievalQAを実行してみましょう。
LLMは下のリンク先で作ったVicuna v1.5のCTranslate2変換モデルを利用します。
(そのためgenerator/tokenizerやCtranslate2StreamLLMの実装部分は割愛)

prompt_template = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
USER: 文脈に含まれる内容を使って、質問に答えなさい。

### 質問:
{question}

### 文脈:
{context}
ASSISTANT: """
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)

llm = Ctranslate2StreamLLM(generator=generator, tokenizer=tokenizer)

retriever.search_kwargs={"k":3}

chain_type_kwargs = {"prompt": PROMPT, "verbose":True}
retrieval_qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, chain_type_kwargs=chain_type_kwargs)

# 比較用
prompt_template2 = "USER: {question}\ASSISTANT: "
PROMPT2 = PromptTemplate(
    template=prompt_template2, input_variables=["question"]
)
non_retrieval_qa = LLMChain(
    llm=llm,
    prompt=PROMPT2
)

いくつか、dolly-15k-jaに含まれる内容について質問してみます。
違いが分かるように、Retrieverを使わないケースと合わせて聞いてみます。(1個目の回答がRetriever無し、2個目があり)

query = "ブドウの栽培方法は?"

print("No Retrive: ", non_retrieval_qa.predict(question=query))
print("--------------------------------")
print("Retrive: ", retrieval_qa.run(query))
print("--------------------------------")
output

No Retrive:  ブドウの栽培方法には、以下のようなものがあります。

1. 地中海性ブドウ(ヴェルメット):夏の日差しが強く、気候が乾燥する地域で育成されます。
2. 大陸性ブドウ(ピノ・ノワール、カベルネ・ソーヴィニョン、メルロー、シャルドネなど):比較的寒い地域で育成されます。
3. アイスワイン用ブドウ(リースリング、グリューラー、サウザン・クランプなど):冷涼な気候で育成されます。
4. オーガニック・バイオダイナミック栽培:化学肥料や農薬を使
--------------------------------
Retrive:  ブドウの栽培方法は、湿気の少ない砂地を好むため、水やりの後は必ず地面を乾燥させることが重要です。また、持続的に生産できる量よりも多くの果実を育てることがあるため、春先に積極的に剪定することがおすすめです。剪定の方法としては、昨年伸びた芽を10~15個ほど残したまま2~4本のつるを残し、今年の芽が伸びるようにジョイントを残しておきます。そして、この長さのつるをトレリスにワイヤリングして、秋にブドウを楽しむことができます。
--------------------------------

Retriveした方が、栽培方法について回答していますね。

query = "カップル・リトリート・ガーデンとは?"

print("No Retrive: ", non_retrieval_qa.predict(question=query))
print("--------------------------------")
print("Retrive: ", retrieval_qa.run(query))
print("--------------------------------")

output
No Retrive:  カップル・リトリート・ガーデン(Couple Retreat Garden)とは、カップル向けの静かでロマンティックな庭園を意味します。これらの庭園は、カップルが一緒に過ごすために設計されており、美しい景色や落葉樹などの自然要素を取り入れています。

カップル・リトリート・ガーデンには、通常、以下のような要素が含まれます:

1. 座敷やベンチ:カップルがゆっくりと話し合いや読書を楽しむことができる場所。
2. 花壇:色鮮やかな花々や香り高いハーブなどを育てている庭。
3. 小川や池
--------------------------------
Retrive:  カップル・リトリート・ガーデンは、中国の蘇州にある庭園で、1874年に造園された歴史があり、現在は世界遺産にも登録されています。この庭園は、東西に分かれた2つの部分からなっており、古典的な庭園としては珍しい構成を持っています。また、庭園内には多くの建物や道教の塔があり、運河から船でアクセスすることができます。

ガーデニングとは、花やハーブ、野菜などの植物を育てるために、敷地を整備し、手入れを行うことです。

デートのアイデアとしては、新し
--------------------------------

Retriveしたほうはdolly-15k-jaに含まれる内容で回答されています。

query = "テレビ番組「フレンズ」には誰が出演している?"

print("No Retrive: ", non_retrieval_qa.predict(question=query))
print("--------------------------------")
print("Retrive: ", retrieval_qa.run(query))
print("--------------------------------")

output
No Retrive:  申し訳ありませんが、私は2021年9月までの情報しか持っていないため、現在放送中のテレビ番組「フレンズ」についてはお答えできません。ただし、過去に放送された同名の番組に関しては、日本のお笑いコンビ・フレンドリー(村上純・濱口優)が出演していました。
--------------------------------
Retrive:  テレビ番組「フレンズ」には、ジェニファー・アニストン(Rachel Green)、コートニー・コックス(Monica Geller)、リサ・クドロー(Phoebe Buffay)、マット・ルブラン(Joey Tribiani)、マシュー・ペリー(Chandler Bing)が出演しています。
--------------------------------

これもRetriveしたほうはdolly-15k-jaに含まれる内容で回答されています。

まとめ

というわけで、Spark DataFrameに格納したデータから、チャンク化・埋め込み、Retrieverの作成までやってみました。
RAGは非常に重要なテクニックであり、またこれがDWHとして構築しているテーブルからそのままRetriverを生成できると管理上非常に便利だと思ったのでやってみました。

ざっと作った内容なので不備等あると思いますので、何かあれば指摘ください。

5
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
1