内容
Amazon Bedrockのナレッジベースを活用し、RAG(Retrieval-Augmented Generation)システムを構築します。Streamlitを用いたチャットアプリケーションを想定し、以下のような質問に回答できるシステムを構築します。ユーザーが回答を得た後、その情報の正確性や信頼性を確認したい場面があると思います。この記事では、比較的簡単に回答の出典元情報を表示する機能を追加する方法を記載します。
前提
事前に以下の記事のように、RAGで使用するデータをチャンク分割をしておきます。
このように分割したファイルをS3にアップロード後、同期処理を行い、ベクトルDBにデータを格納します。(下記は分割したファイルの例になります。)
チャット画面では、以下のように回答とあわせて出典元を表示する構成を作成してみます。出典元として表示されるのは、ヒットしたチャンク情報が含まれている元データのファイル名(S3に保存されたファイル名)です。
仕組み
boto3のretrieve_and_generate_stream
APIを使用して、ストリーミング形式で回答を受け取ります。
このAPIのレスポンスでは、次のような構造のデータが返ってきます。
location,s3location
内のuri
が取得出来れば良さそうです。
- output: 回答の一部が含まれる
- citation: 引用情報
- textResponsePart: 引用に基づいた回答部分
- retrievedReferences: 引用元チャンク情報
- content: 引用されたチャンクの内容
- location: チャンクの元ファイルの場所
{
"stream": [
{
"output": {
"text": "回答の一部..."
}
},
{
"output": {
"text": "回答の続き..."
}
},
{
"citation": {
"textResponsePart": {
"text": "回答の一部...引用に基づいた発言..."
},
"retrievedReferences": [
{
"content": {
"text": "引用元チャンク情報"
},
"location": {
"s3Location": {
"uri": "s3://bucket/path/to/source.pdf"
}
}
}
]
}
}
]
}
コード
下記の箇所は変数やBedrockに投げるプロンプト情報などを定義した後にretrieve_and_generate_stream
を使用してBedrockに問合せを行っているところです。結果はstream
に格納されます。
import streamlit as st
import boto3
import json
REGION = "ap-northeast-1"
MODEL_ARN = "arn:aws:bedrock:ap-northeast-1::foundation-model/anthropic.claude-3-haiku-20240307-v1:0"
KNOWLEDGEBASE_ID = '******'
st.set_page_config(page_title="テスト")
st.markdown("<p style='text-align: center; font-size: 36px; font-weight: bold;'>テスト</p>", unsafe_allow_html=True)
bedrock_agent = boto3.client("bedrock-agent-runtime", region_name=REGION)
user_question = "テスト用の質問"
context_prompt = f"""
以下はユーザーからの質問です:
<question>
{user_question}
</question>
"""
response = bedrock_agent.retrieve_and_generate_stream(
input={"text": context_prompt},
retrieveAndGenerateConfiguration={
"type": "KNOWLEDGE_BASE",
"knowledgeBaseConfiguration": {
"knowledgeBaseId": KNOWLEDGEBASE_ID,
"modelArn": MODEL_ARN,
}
}
)
stream = response.get("stream")
stream
に格納された応答結果から情報を取り出します。まずoutput
の中に回答が入っているのでこの結果をストリーム出力します。また出典元が格納されたcitation
キーがある場合はcitations
のリストに追加しておきます。
bot_response = ""
citations = []
if stream:
with st.chat_message("assistant"):
bot_area = st.empty()
for event in stream:
if "output" in event:
bot_response += event['output']['text']
bot_area.markdown(bot_response)
if "citation" in event:
citations.append(event["citation"])
s3Location
キーの中のurlから出典元情報を取り出します。
if citations:
sources = []
for citation in citations:
for ref in citation["retrievedReferences"]:
if "location" in ref and "s3Location" in ref["location"]:
uri = ref["location"]["s3Location"].get("uri")
if uri:
sources.append(uri)
重複する出典元がある場合は一つにまとめた後、s3://********
というパス部分を削除します。※s3://********
にはナレッジベースのデータソースを格納しているS3のパスが入ります。
unique_sources = list(set(sources))
cleaned_sources = [s.replace("s3://********/", "") for s in unique_sources]
ストリーム出力の後に出典元が表示されます。
if cleaned_sources:
sources_md = "\n\n* 出典元:\n" + "\n".join(f"- {s}" for s in cleaned_sources)
with st.chat_message("assistant"):
st.markdown(sources_md)