今回は Chainlit を使い、お手軽にRAGアプリを実装してみます。
Chainlit はチャットボット形式のアプリを即興で作れるライブラリです。
こちらとVector Store対応した Oracle Database 23ai のFree版、そして LangChain を活用し、RAGアプリを作ります。
またEmbeddingモデルとLLMはローカルにダウンロードしたものを利用します。
実装するアプリのイメージですが、アプリにアクセスすると最初にRAGで使うPDFのアップロードを促します。
アップロードするとテキスト抽出、チャンク分割、Embeddingが行われ、Oracle Databaseに保存されます。
その後LLMに対して質問を投げると、アップロードしたファイルをRAGとして活用し、回答を生成します。
検証環境
- VM: OCI VM.Standard3.Flex (3 OCPU、48GB RAM)
- OS: Oracle Linux Server release 8.9
※すべて1つのサーバで実行します。
事前準備
Oracle Database 23ai Freeのインストール
Oracle Database 23ai Freeをインストールし、スキーマを作成します。
手順は下記をご参照ください。
- Oracle Database 23ai Freeインストール
-
アプリに使うスキーマの作成
- ネットワークACL設定、ディレクトリ・オブジェクト作成はSkipしてください。
また今回はベクトル索引として HNSW索引 を利用します。
HNSW索引の利用にはDBに対して初期化パラメータ vector_memory_size
の設定が必要ですので、以下手順に従い設定します。
-- sysユーザでCDBにログイン
SQL> show parameter vector_memory_size
NAME TYPE
------------------------------------ ---------------------------------
VALUE
------------------------------
vector_memory_size big integer
0
SQL> ALTER SYSTEM SET vector_memory_size=1G SCOPE=SPFILE;
SQL> shutdown immediate
SQL> startup
ORACLE
Total System Global Area 1603726344 bytes
Fixed Size 5360648 bytes
Variable Size 335544320 bytes
Database Buffers 184549376 bytes
Redo Buffers 4530176 bytes
Vector Memory Area 1073741824 bytes
SQL> show parameter vector_memory_size
NAME TYPE
------------------------------------ ---------------------------------
VALUE
------------------------------
vector_memory_size big integer
1G
OSパッケージ、Pythonパッケージのインストール
LangChainやEmbeddingモデル、LLMの利用に必要な各インストール作業は、以下記事の「パッケージのインストール」を参照してください。
Oracle Database 23ai + LangChainを使いローカル環境のみでRAGを試してみた
上記に加え Chainlit をインストールします。
対象のPython仮想環境を activate した状態で以下を実行します。
$ pip install chainlit
Chainlitへのアクセス設定
Chainlit により作成したアプリへインターネット経由でアクセスする場合、以下のようにファイアウォールの許可設定とポートフォワーディングの設定を追加します。
80番を8000番へポートフォワーディングしているのは、Chainlit を起動したときの待ち受けがデフォルト 8000 番ポートだからです。
$ sudo su -
# http通信を許可
$ firewall-cmd --add-service=http
# 80ポートへの通信を8000にフォワード
$ firewall-cmd --add-forward-port=port=80:proto=tcp:toport=8000
# 設定保存
$ firewall-cmd --runtime-to-permanent
# 設定リロード
$ firewall-cmd --reload
# 内容確認
$ firewall-cmd --list-all
上記に加え、OCIなどクラウド上で試す場合は、80番ポートに対してクラウド側のイングレス通信許可設定を追加します。
Chainlit のパスワード認証を設定
インターネット上に公開する想定のため、一応 Chainlit の機能を使ってパスワード認証を有効化しておきます。
以下コマンドを実行し、出力されたAuthトークンを .env
ファイルに保存します。
$ chainlit create-secret
Copy the following secret into your .env file. Once it is set, changing it will logout all users with active sessions.
CHAINLIT_AUTH_SECRET="...(略)..."
Chainlit + LangChainを使ったアプリの実装
コード全体は以下のようになります。
import os
from dotenv import load_dotenv
load_dotenv()
import oracledb
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import oraclevs
from langchain_community.vectorstores.oraclevs import OracleVS
from langchain.chains import RetrievalQA
from langchain_community.llms.llamacpp import LlamaCpp
from langchain_core.prompts import PromptTemplate
from langchain.schema.runnable import Runnable
from langchain.schema.runnable.config import RunnableConfig
from langchain_community.vectorstores.utils import DistanceStrategy
import chainlit as cl
# Oracle DB接続情報
user = os.environ['username']
pwd = os.environ['password']
dsn = os.environ['service']
# chainlitアカウント情報
chainlit_user = os.environ['chainlit_user']
chainlit_pwd = os.environ['chainlit_pwd']
# ベクトルデータ格納先テーブル
table_name = "ovs"
# embeddingモデルの読み込み
embedding_model = HuggingFaceEmbeddings(
model_name="intfloat/multilingual-e5-large"
)
# モデルのパス
model_path = "/home/oracle/lctest/vicuna-13b-v1.5.Q4_K_M.gguf"
# モデルの読み込み
llm = LlamaCpp(
model_path=model_path,
n_ctx=2048,
max_tokens=4096,
temperature=0.3
)
# Chainlitを使ったパスワード認証の設定
@cl.password_auth_callback
def auth_callback(username: str, password: str):
# Fetch the user matching username from your database
# and compare the hashed password with the value stored in the database
if (username, password) == (chainlit_user, chainlit_pwd):
return cl.User(
identifier=chainlit_user, metadata={"role": "user", "provider": "credentials"}
)
else:
return None
@cl.on_chat_start
async def on_chat_start():
# ファイルアップロードの処理
files = None
while files is None:
# chainlitの機能に、ファイルをアップロードさせるメソッドがある。
files = await cl.AskFileMessage(
# ファイルの最大サイズ
max_size_mb=20,
# ファイルをアップロードさせる画面のメッセージ
content="PDFを選択してください。",
# PDFファイルを指定する
accept=["application/pdf"],
# タイムアウトなし
raise_on_timeout=False,
).send()
# アップロードされたファイルの読み込み
loader = PyPDFLoader(files[0].path)
# テキスト抽出
pages = loader.load_and_split()
# テキスト分割
text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=40)
docs = text_splitter.split_documents(pages)
# 分割結果を一時保存
lc_docs = []
for doc in docs:
lc_docs.append(Document(page_content=doc.page_content.replace("\n", ""), metadata={'source': files[0].name}))
# Oracle Vector Databaseに接続
connection = oracledb.connect(user = user, password = pwd, dsn = dsn)
# DBに保存
ovs = OracleVS.from_documents(
lc_docs,
embedding_model,
client=connection,
table_name=table_name,
distance_strategy=DistanceStrategy.COSINE,
)
# ベクトル索引 (HNSW索引) 作成
oraclevs.create_index(connection, ovs, params={"idx_name": "hnsw_idx1", "idx_type": "HNSW"})
await cl.Message(content=f"`{files[0].name}` の準備が完了しました。").send()
# プロンプトテンプレートの定義
template = """以下の文脈を利用して、最後の質問に簡潔に答えてください。答えがわからない場合は、わからないと答えてください。
{context}
質問: {question}
回答(日本語):"""
# プロンプトの設定
question_prompt = PromptTemplate(
template=template, # プロンプトテンプレートをセット
input_variables=["question"] # プロンプトに挿入する変数
)
# RAG設定
qa = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=ovs.as_retriever(search_kwargs={'k': 5}),
return_source_documents=True,
chain_type_kwargs={"prompt": question_prompt}
)
# セッション登録
cl.user_session.set("runnable", qa)
@cl.on_message
async def on_message(message: cl.Message):
# セッション情報から設定を読み込み
runnable=cl.user_session.get("runnable")
# Chainlit設定
cb = cl.AsyncLangchainCallbackHandler(
stream_final_answer=True,
answer_prefix_tokens=["FINAL", "ANSWER"]
)
cb.answer_reached=True
# 回答生成
res=await runnable.ainvoke(message.content, callbacks=[cb])
# 回答表示
await cl.Message(content=f"\nAnswer:\n"+res['result']).send()
# 引用元表示
await cl.Message(content=f"\nSource:\n"+str(res['source_documents'])).send()
主要なパートについて説明します。
# Oracle DB接続情報
user = os.environ['username']
pwd = os.environ['password']
dsn = os.environ['service']
# chainlitアカウント情報
chainlit_user = os.environ['chainlit_user']
chainlit_pwd = os.environ['chainlit_pwd']
上記は .env
ファイルに記載されている変数を読み込んでいます。
そのため、事前に .env
ファイルに対応する変数を定義しておく必要があります。
# Chainlitを使ったパスワード認証の設定
@cl.password_auth_callback
def auth_callback(username: str, password: str):
# Fetch the user matching username from your database
# and compare the hashed password with the value stored in the database
if (username, password) == (chainlit_user, chainlit_pwd):
return cl.User(
identifier=chainlit_user, metadata={"role": "user", "provider": "credentials"}
)
else:
return None
このパートは Chainlit の機能を使ったパスワード認証の設定です。
.env
ファイルから読み込んだ chainlit_user
、chainlit_pwd
を使い、アカウント照合を行います。
もちろんアカウント情報をDBに保存しておき、そちらの情報と照合するような実装も可能です。
@cl.on_chat_start
async def on_chat_start():
# ファイルアップロードの処理
files = None
while files is None:
# chainlitの機能に、ファイルをアップロードさせるメソッドがある。
files = await cl.AskFileMessage(
# ファイルの最大サイズ
max_size_mb=20,
# ファイルをアップロードさせる画面のメッセージ
content="PDFを選択してください。",
# PDFファイルを指定する
accept=["application/pdf"],
# タイムアウトなし
raise_on_timeout=False,
).send()
# アップロードされたファイルの読み込み
loader = PyPDFLoader(files[0].path)
# テキスト抽出
pages = loader.load_and_split()
# テキスト分割
text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=40)
docs = text_splitter.split_documents(pages)
# 分割結果を一時保存
lc_docs = []
for doc in docs:
lc_docs.append(Document(page_content=doc.page_content.replace("\n", ""), metadata={'source': files[0].name}))
# Oracle Vector Databaseに接続
connection = oracledb.connect(user = user, password = pwd, dsn = dsn)
# DBに保存
ovs = OracleVS.from_documents(
lc_docs,
embedding_model,
client=connection,
table_name=table_name,
distance_strategy=DistanceStrategy.COSINE,
)
# ベクトル索引 (HNSW索引) 作成
oraclevs.create_index(connection, ovs, params={"idx_name": "hnsw_idx1", "idx_type": "HNSW"})
await cl.Message(content=f"`{files[0].name}` の準備が完了しました。").send()
こちらは最初にチャットへアクセスしたときに、PDFファイルのアップロードを促す処理です。
またアップロードされたPDFファイルの内容をEmbeddingし、Oracle Databaseに結果を保存する流れとなっています。
最後に HNSW索引 を作成し、問題なく完了したら「~.pdf の準備が完了しました。」というメッセージを出力します。
なおこの処理では、PDFファイルがアップロードされる度にテーブルを作り直してデータを挿入しています。
# プロンプトテンプレートの定義
template = """以下の文脈を利用して、最後の質問に簡潔に答えてください。答えがわからない場合は、わからないと答えてください。
{context}
質問: {question}
回答(日本語):"""
# プロンプトの設定
question_prompt = PromptTemplate(
template=template, # プロンプトテンプレートをセット
input_variables=["question"] # プロンプトに挿入する変数
)
# RAG設定
qa = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=ovs.as_retriever(search_kwargs={'k': 5}),
return_source_documents=True,
chain_type_kwargs={"prompt": question_prompt}
)
# セッション登録
cl.user_session.set("runnable", qa)
PDFアップロード処理が無事に完了したら、LangChainの機能を使いRAGの設定を行っています。
最後にインスタンス化したRAG用のオブジェクトをセッションに登録しています。
@cl.on_message
async def on_message(message: cl.Message):
# セッション情報から設定を読み込み
runnable=cl.user_session.get("runnable")
# Chainlit設定
cb = cl.AsyncLangchainCallbackHandler(
stream_final_answer=True,
answer_prefix_tokens=["FINAL", "ANSWER"]
)
cb.answer_reached=True
# 回答生成
res=await runnable.ainvoke(message.content, callbacks=[cb])
# 回答表示
await cl.Message(content=f"\nAnswer:\n"+res['result']).send()
# 引用元表示
await cl.Message(content=f"\nSource:\n"+str(res['source_documents'])).send()
このパートはユーザがチャットに投稿すると呼び出される処理を記述しています。
セッションに登録してあるRAG用オブジェクトを読み込み、投稿された質問をLLMに連携して回答を生成しています。
また生成された回答を、引用元の情報と共にチャット上へ出力しています。
実際に動作確認してみます。
読み込ませるドキュメントには下記Oracle公式のブログ記事を使います。
Oracle Database 23aiを発表: 提供開始
こちらを pdf
ファイルとして保存し、チャットからアップロードします。
「Oracle Database 23aiの主要な新機能は何ですか?」という質問をしてみます。
「Oracle Database 23aiが重視する製品コンセプトは何ですか?」という質問をしてみます。
今回は日本語を扱う上で品質の良いEmbeddingモデル intfloat/multilingual-e5-large
を使っているためか、かなり的を得た回答になっています。
以上が Oracle Database 23ai + LangChain + Chainlit を使ったRAGアプリの実装です。
画面を作るのは割と手間に思われるかもしれませんが、Chainlit を使えば手軽にRAGアプリが作れてしまいますね。