15
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

langchainとDatabricksで(私が)学ぶRAG : シリーズ一覧 & 準備編

Last updated at Posted at 2023-11-05

シリーズ一覧

増えてきたので、ショートカットを入れておきます。


番外編

導入

LLMで様々な処理を実現するためのフレームワークとしてlangchainが人気ですが、langchainの公式DocやBlogでは様々なRAG(Retrieval Augmented Generation)について紹介されています。

例えば、これら。

また、先日Langchain Templateがアナウンスされました。

この中でも多くのRAGに関するテンプレートが公開されています。

LLMの活用において、RAGは重要な仕組です。そのため、RAGの改善・工夫については様々なテクニックが生まれており、改めてちゃんと身に着ける必要性を感じています。

というわけで、Databricks上で実践しながらRAG関連で学んだことを書いていこうと思います。一応シリーズとして、ゆっくりやっていく予定。

今回は、始める前の準備編です。

補足:基本方針

EmbeddingモデルのFine-tuningが必要となる改善は行わない、もしくは後回しにする予定です。
対象は主にRetrieverに関するものや、Queryの変換などを行うものを中心的にやっていこうと思っています。

準備

RAGを題材にするので、検索対象の文書などを今回は準備します。
今後、同じデータを流用していく予定です。

Step1. データの取得

まずは必要なモジュールをインストール。

%pip install langchain pypdf databricks-feature-engineering

dbutils.library.restartPython()

適当な長い文章を利用したいため、今回は「経済産業省の令和5年概算契約書」のPDFファイルを使用することにします。

以下のようにPDFファイルを取得し、各ページごとにテキストデータを分割取得します。

from langchain.document_loaders import PyPDFLoader

pdf_url = "https://www.meti.go.jp/information_2/downloadfiles/r5gaisan-2_format.pdf"
loader = PyPDFLoader(pdf_url)
pages = loader.load()

Step2. データの加工

不要な内容の除去など、前処理を実行します。
今回は余分な文字の除去と対象ページを限定します。
その上で、全てのページを一つの文章にまとめます。

import pandas as pd
import pyspark.sql.functions as F
import pyspark.sql.types as T

pdf = pd.DataFrame([[d.page_content, d.metadata] for d in pages], columns=["page_content", "metadata"])

df = spark.createDataFrame(pdf)

# 別紙類を除くため、15ページまでを使う
df = df.withColumn("page", F.col("metadata.page").cast(T.IntegerType()))
df = df.filter("page < 15")

# 各ページの先頭にR5-G-2と表記されるため、除去
df = df.withColumn("page_content", F.regexp_replace(F.col("page_content"), "R5-G-2", ""))

# 単一行の文字列データに変換
df = df.withColumn("doc_id", F.lit(1))
df = df.withColumn("doc_name", F.lit("委託契約書フォーマット(概算契約)"))
df = df.groupBy("doc_id", "doc_name").agg(F.collect_list("page_content").alias("page_content"))
df = df.withColumn("page_content", F.concat_ws("", "page_content"))
display(df)

df.createOrReplaceTempView("doc")

結果は以下のような1行だけのデータとなります。

image.png

Step3. データを特徴量として保管

加工済みデータをDatabricksの特徴量として保管します。
(通常のSparkテーブルや別の形で保管してももちろんOKです)
保管先のtraining.llmというスキーマは事前に作成済とします。

from databricks.feature_engineering import FeatureEngineeringClient

fe = FeatureEngineeringClient()

df = spark.table("doc")
feature_name = "training.llm.sample_doc_features"

doc_table = fe.create_table(
    name=feature_name,
    primary_keys="doc_id",
    schema=df.schema,
    description="Sample Document",
)

fe.write_table(name=feature_name, df=df, mode="merge")

保管が正しくできたか、読み出し。

df = fe.read_table(name=feature_name)

display(df)

image.png

できていますね。

Step4. langchain カスタムチャットモデルの準備

RAGで使うLLMはHuggingFace上のモデルをダウンロードして利用予定です。
そのために、以下のlangchainカスタムチャットモデルを利用します。

transformers_chat.py
from transformers import (
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    TextIteratorStreamer,
)
from threading import Thread
import asyncio

from typing import (
    Any,
    List,
    Union,
    Mapping,
    Tuple,
    Optional,
    Iterator,
    AsyncIterator,
)
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain.schema import (
    ChatGeneration,
    ChatResult,
)
from langchain.schema.output import ChatGenerationChunk

class ChatHuggingFaceModel(BaseChatModel):
    """HuggingFace transformers models.

    To use, you should have the ``transformers`` python package installed.

    Example:
        .. code-block:: python

            quantization_config = AwqConfig(
                bits=4,
                group_size=128,
                zero_point=True,
                version="gemm",
                backend="autoawq",
            )

            config = AutoConfig.from_pretrained(model_path)
            config.quantization_config = quantization_config.to_dict()
            model = AutoModelForCausalLM.from_pretrained(model_path, config=config, device_map="cuda")
            tokenizer = AutoTokenizer.from_pretrained(model_path)
    """

    generator: PreTrainedModel
    """ Transformers pretrained Model """
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
    """ Tokenizer """

    # メッセージテンプレート
    human_message_template: str = "USER: {}\n"
    ai_message_template: str = "ASSISTANT: {}"
    system_message_template: str = "{}"

    do_sample: bool = True
    max_new_tokens: int = 64
    repetition_penalty: float = 1.1
    temperature: float = 1
    top_k: int = 1
    top_p: float = 0.95
    # verbose: bool = False

    prompt_line_separator: str = "\n"

    @property
    def _llm_type(self) -> str:
        return "ChatHuggingFaceModel"

    def _format_message_as_text(self, message: BaseMessage) -> str:
        if isinstance(message, ChatMessage):
            message_text = f"{self.prompt_line_separator}{message.role.capitalize()}: {message.content}"
        elif isinstance(message, HumanMessage):
            message_text = self.human_message_template.format(message.content)
        elif isinstance(message, AIMessage):
            message_text = self.ai_message_template.format(message.content)
        elif isinstance(message, SystemMessage):
            message_text = self.system_message_template.format(message.content)
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_text

    def _format_messages_as_text(self, messages: List[BaseMessage]) -> str:
        return self.prompt_line_separator.join(
            [self._format_message_as_text(message) for message in messages]
        )

    def _generate_stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> TextIteratorStreamer:

        prompt = self._format_messages_as_text(messages)

        tokens = self.tokenizer(
            prompt,
            # add_special_tokens=True,
            return_tensors="pt",
        ).input_ids.cuda()

        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        # Generate output
        generation_kwargs = dict(
            inputs=tokens,
            do_sample=self.do_sample,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            max_new_tokens=self.max_new_tokens,
            streamer=streamer,
            num_return_sequences=1,
        )
        thread = Thread(target=self.generator.generate, kwargs=generation_kwargs)
        thread.start()

        return streamer

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:

        streamer = self._generate_stream(messages, stop, run_manager, **kwargs)

        text = ""
        count = 0

        for new_text in streamer:
            if not new_text:
                continue

            text += new_text
            count += 1

            if run_manager:
                run_manager.on_llm_new_token(
                    new_text,
                    verbose=self.verbose,
                )

        chat_generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(
            generations=[chat_generation],
            llm_output={"completion_tokens": count},
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        streamer = self._generate_stream(messages, stop, run_manager, **kwargs)

        text = ""
        count = 0

        for new_text in streamer:

            if not new_text:
                continue

            text += new_text
            count += 1

            # await asyncio.sleep(0)
            if run_manager:
                await run_manager.on_llm_new_token(
                    new_text,
                    verbose=self.verbose,
                )

        chat_generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(
            generations=[chat_generation],
            llm_output={"completion_tokens": count},
        )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:

        streamer = self._generate_stream(messages, stop, run_manager, **kwargs)

        for new_text in streamer:
            if not new_text:
                continue

            yield ChatGenerationChunk(message=AIMessageChunk(content=new_text))

            if run_manager:
                run_manager.on_llm_new_token(
                    new_text,
                    verbose=self.verbose,
                )

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Union[List[str], None] = None,
        run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:

        streamer = self._generate_stream(messages, stop, run_manager, **kwargs)

        for new_text in streamer:
            if not new_text:
                continue

            yield ChatGenerationChunk(message=AIMessageChunk(content=new_text))

            if run_manager:
                await run_manager.on_llm_new_token(
                    new_text,
                    verbose=self.verbose,
                )

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "generator": self.generator,
            "tokenizer": self.tokenizer,
        }

まとめ

以上、準備編でした。
次回から本気出す。たぶん。

15
18
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?