LoginSignup
3
1

Oracle Database 23ai + LangChain + Chainlitで手軽にRAGアプリを実装

Last updated at Posted at 2024-05-14

今回は Chainlit を使い、お手軽にRAGアプリを実装してみます。
Chainlit はチャットボット形式のアプリを即興で作れるライブラリです。
こちらとVector Store対応した Oracle Database 23ai のFree版、そして LangChain を活用し、RAGアプリを作ります。
またEmbeddingモデルとLLMはローカルにダウンロードしたものを利用します。

実装するアプリのイメージですが、アプリにアクセスすると最初にRAGで使うPDFのアップロードを促します。
WS000003.JPG

アップロードするとテキスト抽出、チャンク分割、Embeddingが行われ、Oracle Databaseに保存されます。
その後LLMに対して質問を投げると、アップロードしたファイルをRAGとして活用し、回答を生成します。
WS000001.JPG

検証環境

  • VM: OCI VM.Standard3.Flex (3 OCPU、48GB RAM)
  • OS: Oracle Linux Server release 8.9

※すべて1つのサーバで実行します。

事前準備

Oracle Database 23ai Freeのインストール

Oracle Database 23ai Freeをインストールし、スキーマを作成します。
手順は下記をご参照ください。

また今回はベクトル索引として HNSW索引 を利用します。
HNSW索引の利用にはDBに対して初期化パラメータ vector_memory_size の設定が必要ですので、以下手順に従い設定します。

-- sysユーザでCDBにログイン
SQL> show parameter vector_memory_size

NAME     TYPE
------------------------------------ ---------------------------------
VALUE
------------------------------
vector_memory_size     big integer
0

SQL> ALTER SYSTEM SET vector_memory_size=1G SCOPE=SPFILE;
SQL> shutdown immediate
SQL> startup

ORACLE

Total System Global Area 1603726344 bytes
Fixed Size    5360648 bytes
Variable Size  335544320 bytes
Database Buffers  184549376 bytes
Redo Buffers    4530176 bytes
Vector Memory Area 1073741824 bytes

SQL> show parameter vector_memory_size

NAME     TYPE
------------------------------------ ---------------------------------
VALUE
------------------------------
vector_memory_size     big integer
1G

OSパッケージ、Pythonパッケージのインストール

LangChainやEmbeddingモデル、LLMの利用に必要な各インストール作業は、以下記事の「パッケージのインストール」を参照してください。
Oracle Database 23ai + LangChainを使いローカル環境のみでRAGを試してみた

上記に加え Chainlit をインストールします。
対象のPython仮想環境を activate した状態で以下を実行します。

$ pip install chainlit

Chainlitへのアクセス設定

Chainlit により作成したアプリへインターネット経由でアクセスする場合、以下のようにファイアウォールの許可設定とポートフォワーディングの設定を追加します。
80番を8000番へポートフォワーディングしているのは、Chainlit を起動したときの待ち受けがデフォルト 8000 番ポートだからです。

$ sudo su -
# http通信を許可
$ firewall-cmd --add-service=http
# 80ポートへの通信を8000にフォワード
$ firewall-cmd --add-forward-port=port=80:proto=tcp:toport=8000
# 設定保存
$ firewall-cmd --runtime-to-permanent
# 設定リロード
$ firewall-cmd --reload
# 内容確認
$ firewall-cmd --list-all

上記に加え、OCIなどクラウド上で試す場合は、80番ポートに対してクラウド側のイングレス通信許可設定を追加します。

Chainlit のパスワード認証を設定

インターネット上に公開する想定のため、一応 Chainlit の機能を使ってパスワード認証を有効化しておきます。
以下コマンドを実行し、出力されたAuthトークンを .env ファイルに保存します。

$ chainlit create-secret
Copy the following secret into your .env file. Once it is set, changing it will logout all users with active sessions.
CHAINLIT_AUTH_SECRET="...(略)..."

Chainlit + LangChainを使ったアプリの実装

コード全体は以下のようになります。

import os
from dotenv import load_dotenv
load_dotenv()

import oracledb
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import oraclevs
from langchain_community.vectorstores.oraclevs import OracleVS
from langchain.chains import RetrievalQA
from langchain_community.llms.llamacpp import LlamaCpp
from langchain_core.prompts import PromptTemplate
from langchain.schema.runnable import Runnable
from langchain.schema.runnable.config import RunnableConfig
from langchain_community.vectorstores.utils import DistanceStrategy

import chainlit as cl

# Oracle DB接続情報
user = os.environ['username']
pwd = os.environ['password']
dsn = os.environ['service']

# chainlitアカウント情報
chainlit_user = os.environ['chainlit_user']
chainlit_pwd = os.environ['chainlit_pwd']

# ベクトルデータ格納先テーブル
table_name = "ovs"

# embeddingモデルの読み込み
embedding_model = HuggingFaceEmbeddings(
    model_name="intfloat/multilingual-e5-large"
)

# モデルのパス
model_path = "/home/oracle/lctest/vicuna-13b-v1.5.Q4_K_M.gguf"

# モデルの読み込み
llm = LlamaCpp(
    model_path=model_path,
    n_ctx=2048,
    max_tokens=4096,
    temperature=0.3
)

# Chainlitを使ったパスワード認証の設定
@cl.password_auth_callback
def auth_callback(username: str, password: str):
    # Fetch the user matching username from your database
    # and compare the hashed password with the value stored in the database
    if (username, password) == (chainlit_user, chainlit_pwd):
        return cl.User(
            identifier=chainlit_user, metadata={"role": "user", "provider": "credentials"}
        )
    else:
        return None

@cl.on_chat_start
async def on_chat_start():
    # ファイルアップロードの処理
    files = None
    while files is None:
        # chainlitの機能に、ファイルをアップロードさせるメソッドがある。
        files = await cl.AskFileMessage(
            # ファイルの最大サイズ
            max_size_mb=20,
            # ファイルをアップロードさせる画面のメッセージ
            content="PDFを選択してください。",
            # PDFファイルを指定する
            accept=["application/pdf"],
            # タイムアウトなし
            raise_on_timeout=False,
        ).send()    
     
    # アップロードされたファイルの読み込み
    loader = PyPDFLoader(files[0].path)
    
    # テキスト抽出
    pages = loader.load_and_split()

    # テキスト分割
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=40)
    docs = text_splitter.split_documents(pages)

    # 分割結果を一時保存
    lc_docs = []
    for doc in docs:
        lc_docs.append(Document(page_content=doc.page_content.replace("\n", ""), metadata={'source': files[0].name}))
    
    # Oracle Vector Databaseに接続
    connection = oracledb.connect(user = user, password = pwd, dsn = dsn) 
    
    # DBに保存
    ovs = OracleVS.from_documents(
        lc_docs,
        embedding_model,
        client=connection,
        table_name=table_name,
        distance_strategy=DistanceStrategy.COSINE,
    )

    # ベクトル索引 (HNSW索引) 作成
    oraclevs.create_index(connection, ovs, params={"idx_name": "hnsw_idx1", "idx_type": "HNSW"})
 
    await cl.Message(content=f"`{files[0].name}` の準備が完了しました。").send()
        
    # プロンプトテンプレートの定義
    template = """以下の文脈を利用して、最後の質問に簡潔に答えてください。答えがわからない場合は、わからないと答えてください。
{context}
質問: {question}
回答(日本語):"""

    # プロンプトの設定
    question_prompt = PromptTemplate(
        template=template, # プロンプトテンプレートをセット
        input_variables=["question"] # プロンプトに挿入する変数
    )

    # RAG設定
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=ovs.as_retriever(search_kwargs={'k': 5}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": question_prompt}
    )

    # セッション登録
    cl.user_session.set("runnable", qa)

@cl.on_message
async def on_message(message: cl.Message):
    # セッション情報から設定を読み込み
    runnable=cl.user_session.get("runnable")
    
    # Chainlit設定
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True,
        answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    
    cb.answer_reached=True
    
    # 回答生成
    res=await runnable.ainvoke(message.content, callbacks=[cb])
    
    # 回答表示
    await cl.Message(content=f"\nAnswer:\n"+res['result']).send()
    
    # 引用元表示
    await cl.Message(content=f"\nSource:\n"+str(res['source_documents'])).send()

主要なパートについて説明します。

# Oracle DB接続情報
user = os.environ['username']
pwd = os.environ['password']
dsn = os.environ['service']

# chainlitアカウント情報
chainlit_user = os.environ['chainlit_user']
chainlit_pwd = os.environ['chainlit_pwd']

上記は .env ファイルに記載されている変数を読み込んでいます。
そのため、事前に .env ファイルに対応する変数を定義しておく必要があります。

# Chainlitを使ったパスワード認証の設定
@cl.password_auth_callback
def auth_callback(username: str, password: str):
    # Fetch the user matching username from your database
    # and compare the hashed password with the value stored in the database
    if (username, password) == (chainlit_user, chainlit_pwd):
        return cl.User(
            identifier=chainlit_user, metadata={"role": "user", "provider": "credentials"}
        )
    else:
        return None

このパートは Chainlit の機能を使ったパスワード認証の設定です。
.env ファイルから読み込んだ chainlit_userchainlit_pwd を使い、アカウント照合を行います。
もちろんアカウント情報をDBに保存しておき、そちらの情報と照合するような実装も可能です。

@cl.on_chat_start
async def on_chat_start():
    # ファイルアップロードの処理
    files = None
    while files is None:
        # chainlitの機能に、ファイルをアップロードさせるメソッドがある。
        files = await cl.AskFileMessage(
            # ファイルの最大サイズ
            max_size_mb=20,
            # ファイルをアップロードさせる画面のメッセージ
            content="PDFを選択してください。",
            # PDFファイルを指定する
            accept=["application/pdf"],
            # タイムアウトなし
            raise_on_timeout=False,
        ).send()    
     
    # アップロードされたファイルの読み込み
    loader = PyPDFLoader(files[0].path)
    
    # テキスト抽出
    pages = loader.load_and_split()

    # テキスト分割
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=40)
    docs = text_splitter.split_documents(pages)

    # 分割結果を一時保存
    lc_docs = []
    for doc in docs:
        lc_docs.append(Document(page_content=doc.page_content.replace("\n", ""), metadata={'source': files[0].name}))
    
    # Oracle Vector Databaseに接続
    connection = oracledb.connect(user = user, password = pwd, dsn = dsn) 
    
    # DBに保存
    ovs = OracleVS.from_documents(
        lc_docs,
        embedding_model,
        client=connection,
        table_name=table_name,
        distance_strategy=DistanceStrategy.COSINE,
    )

    # ベクトル索引 (HNSW索引) 作成
    oraclevs.create_index(connection, ovs, params={"idx_name": "hnsw_idx1", "idx_type": "HNSW"})
 
    await cl.Message(content=f"`{files[0].name}` の準備が完了しました。").send()

こちらは最初にチャットへアクセスしたときに、PDFファイルのアップロードを促す処理です。
またアップロードされたPDFファイルの内容をEmbeddingし、Oracle Databaseに結果を保存する流れとなっています。
最後に HNSW索引 を作成し、問題なく完了したら「~.pdf の準備が完了しました。」というメッセージを出力します。
なおこの処理では、PDFファイルがアップロードされる度にテーブルを作り直してデータを挿入しています。

    # プロンプトテンプレートの定義
    template = """以下の文脈を利用して、最後の質問に簡潔に答えてください。答えがわからない場合は、わからないと答えてください。
{context}
質問: {question}
回答(日本語):"""

    # プロンプトの設定
    question_prompt = PromptTemplate(
        template=template, # プロンプトテンプレートをセット
        input_variables=["question"] # プロンプトに挿入する変数
    )

    # RAG設定
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=ovs.as_retriever(search_kwargs={'k': 5}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": question_prompt}
    )

    # セッション登録
    cl.user_session.set("runnable", qa)

PDFアップロード処理が無事に完了したら、LangChainの機能を使いRAGの設定を行っています。
最後にインスタンス化したRAG用のオブジェクトをセッションに登録しています。

@cl.on_message
async def on_message(message: cl.Message):
    # セッション情報から設定を読み込み
    runnable=cl.user_session.get("runnable")
    
    # Chainlit設定
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True,
        answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    
    cb.answer_reached=True
    
    # 回答生成
    res=await runnable.ainvoke(message.content, callbacks=[cb])
    
    # 回答表示
    await cl.Message(content=f"\nAnswer:\n"+res['result']).send()
    
    # 引用元表示
    await cl.Message(content=f"\nSource:\n"+str(res['source_documents'])).send()

このパートはユーザがチャットに投稿すると呼び出される処理を記述しています。
セッションに登録してあるRAG用オブジェクトを読み込み、投稿された質問をLLMに連携して回答を生成しています。
また生成された回答を、引用元の情報と共にチャット上へ出力しています。

実際に動作確認してみます。
読み込ませるドキュメントには下記Oracle公式のブログ記事を使います。
Oracle Database 23aiを発表: 提供開始
こちらを pdf ファイルとして保存し、チャットからアップロードします。
WS000000.JPG

「Oracle Database 23aiの主要な新機能は何ですか?」という質問をしてみます。
WS000001.JPG

「Oracle Database 23aiが重視する製品コンセプトは何ですか?」という質問をしてみます。
WS000002.JPG

今回は日本語を扱う上で品質の良いEmbeddingモデル intfloat/multilingual-e5-large を使っているためか、かなり的を得た回答になっています。

以上が Oracle Database 23ai + LangChain + Chainlit を使ったRAGアプリの実装です。
画面を作るのは割と手間に思われるかもしれませんが、Chainlit を使えば手軽にRAGアプリが作れてしまいますね。

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1