こんにちは。今日も通常運転です。シアトルじゃないです。
AIエージェントをBedrockで始めてみようという方向けの投稿です。
ポイント
ポイントはただ一つです
Return controlを使うこと
です!!!!!!
(いつの間にか「Return of control」じゃなくなってますね)
Bedrock Agentsがローンチされたタイミングでは、AIエージェントのツールに該当する処理をAWS Lambdaで実装する必要がありました。
ただ、Lambdaの呼び出しイベントやレスポンスの形式を気にしたり、試行錯誤しながら進めるには、Lambdaはしんどい。
Return controlを使うと、Tool部分を呼び出し元で実装できるので試行錯誤が容易です。
やってみよう
-
マネジメントコンソールでBedrockにアクセスし、「エージェント」を表示します
「エージェントを作成」をクリックします -
好きな名前をつけて「作成」をクリックします
-
エージェントビルダーという画面に遷移します
- モデルを選択(私はHaikuがお気に入り)
- エージェント向けの指示にAnthropicのシステムプロンプトを入力
(本当はもうちょっと考えて作るべきと思いますが、まずはこれで行きます)
画面上部の「保存」をクリックします(「保存して終了」ではありません)
-
続いて画面中央のアクショングループにある「追加」をクリックします
-
アクショングループ名は後で使うので、わかり易い名前(action-group-1)をつけます
-
Action group invocationは「Return control」を選択します
-
Web検索を行うツールを定義します
- Nameに「web_search」と入力
- Parametersに「search_query」を追加し、Descriptionを「検索キーワード」と入力、Requiredを「True」を選択
-
画面下部の「作成」をクリックします
-
エージェントビルダーの画面に戻るので「保存して終了」をクリックします
-
画面右側にある「準備」をクリックします
エージェントIDが後で必要になるのでメモします
これでAWS側の設定が完了です。
Streamlitで画面を作成
Streamlitで画面を作りましょう
-
ライブラリーをインストールします
Shellpip install boto3 requests streamlit
-
ここから始めます
import boto3 import streamlit as st st.title("Bedrock Agent") if prompt := st.chat_input(): with st.chat_message("user"): st.write(prompt)
-
toolを定義します
エージェントが使用するtoolを定義しますtoolsのキー(「action-group-1」と「web_search」)、web_search関数の引数「search_query」は、先程アクショングループで指定したものと合わせる必要があります
def web_search(search_query: str) -> dict: response = requests.post( "https://api.tavily.com/search", json={"query": search_query, "api_key": tavily_api_key}, ) return response.json() tools = { "action-group-1": {"web_search": web_search} }
「tavily_api_key」はこちらで取得してください。
-
エージェントIDとエイリアスを定義します
セッションIDは一意なIDを定義しますagent_id = "I3JWB4PYW5" agent_alias_id = "TSTALIASID" session_id = str(uuid.uuid4())
テスト実行の際のエイリアスは「TSTALIASID」です
-
Bedrock Agentを呼び出します
client = boto3.client("bedrock-agent-runtime") response = client.invoke_agent( agentId=agent_id, agentAliasId=agent_alias_id, sessionId=session_id, enableTrace=True, inputText=prompt, )
-
responseを巧みにパースします
result = tools[actionGroup][function](**parameter)
の部分で、Return controlの呼び出し情報をもとに、先程定義したToolを呼び出しています。ツールの呼び出し結果を「sessionState」にセットします。
sessionState = {} for event in response.get("completion"): if "returnControl" in event: returnControl: dict = event["returnControl"] invocation_id: str = returnControl["invocationId"] invocation_inputs: list = returnControl["invocationInputs"] sessionState["invocationId"] = invocation_id for input in invocation_inputs: function_invocation_input = input["functionInvocationInput"] actionGroup = function_invocation_input["actionGroup"] actionInvocationType = function_invocation_input["actionInvocationType"] function = function_invocation_input["function"] parameters = function_invocation_input["parameters"] parameter = {} for p in parameters: name = p["name"] type = p["type"] value = p["value"] parameter[name] = value result = tools[actionGroup][function](**parameter) returnControlInvocationResult = { "functionResult": { "actionGroup": actionGroup, "function": function, "responseBody": {"string": {"body": json.dumps(result)}}, } } sessionState["returnControlInvocationResults"] = [ returnControlInvocationResult ]
-
「sessionState」をセットしたリクエストを再度Bedrockに送信します
こうすることで、Toolの実行結果を受けての続きのエージェント処理が実行されますresponse = client.invoke_agent( agentId=agent_id, agentAliasId=agent_alias_id, sessionId=session_id, enableTrace=True, sessionState=sessionState, )
-
2度目のBedrock呼び出し結果から、最終回答を取得します
completion = "" for event in response.get("completion"): if "chunk" in event: chunk = event["chunk"] completion = completion + chunk["bytes"].decode()
これで、回答が取得できます
だいぶん駆け足での説明になりましたが、これでエージェントアプリが完成です。
実際には、Toolが複数回呼ばれたり、メッセージを画面に表示する処理も必要になります。
ソースコードの全体はこちらです。
import json
import uuid
import boto3
import requests
import streamlit as st
from markdownify import markdownify as md
from playwright.sync_api import sync_playwright
### 定数定義
agent_id = "I3JWB4PYW5"
agent_alias_id = "TSTALIASID"
tavily_api_key = "tvly-*****"
if "session_id" not in st.session_state:
st.session_state.session_id = str(uuid.uuid4())
session_id = st.session_state.session_id
### ツール定義
def web_search(search_query: str) -> dict:
response = requests.post(
"https://api.tavily.com/search",
json={"query": search_query, "api_key": tavily_api_key},
)
return response.json()
tools = {
"action-group-1": {"web_search": web_search}
}
### UI
st.title("Bedrock Agent")
if "messages" not in st.session_state:
st.session_state.messages = []
messages = st.session_state.messages
for message in messages:
with st.chat_message(message["role"]):
st.write(message["content"])
def event_loop(response):
completion = ""
sessionState = {}
for event in response.get("completion"):
if "trace" in event:
trace = event["trace"]
with st.chat_message("assistant"):
with st.expander("Trace", expanded=False):
st.write(trace)
if "returnControl" in event:
returnControl: dict = event["returnControl"]
invocation_id: str = returnControl["invocationId"]
invocation_inputs: list = returnControl["invocationInputs"]
with st.chat_message("assistant"):
with st.expander("Return Control", expanded=False):
st.write(returnControl)
sessionState["invocationId"] = invocation_id
for input in invocation_inputs:
function_invocation_input = input["functionInvocationInput"]
actionGroup = function_invocation_input["actionGroup"]
actionInvocationType = function_invocation_input["actionInvocationType"]
function = function_invocation_input["function"]
parameters = function_invocation_input["parameters"]
parameter = {}
for p in parameters:
name = p["name"]
type = p["type"]
value = p["value"]
parameter[name] = value
result = tools[actionGroup][function](**parameter)
returnControlInvocationResult = {
"functionResult": {
"actionGroup": actionGroup,
"function": function,
"responseBody": {"string": {"body": json.dumps(result)}},
}
}
with st.chat_message("user"):
with st.expander("Return Control Result", expanded=False):
st.write(returnControlInvocationResult)
sessionState["returnControlInvocationResults"] = [
returnControlInvocationResult
]
if "chunk" in event:
chunk = event["chunk"]
completion = completion + chunk["bytes"].decode()
return completion, sessionState
if prompt := st.chat_input():
messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.write(prompt)
with st.container(border=True):
client = boto3.client("bedrock-agent-runtime")
response = client.invoke_agent(
agentId=agent_id,
agentAliasId=agent_alias_id,
sessionId=session_id,
enableTrace=True,
inputText=prompt,
)
completion, sessionState = event_loop(response)
while len(sessionState) > 0:
response = client.invoke_agent(
agentId=agent_id,
agentAliasId=agent_alias_id,
sessionId=session_id,
enableTrace=True,
sessionState=sessionState,
)
completion, sessionState = event_loop(response)
messages.append({"role": "assistant", "content": completion})
with st.chat_message("assistant"):
st.write(completion)
こちらには、更にURLをクロールする処理も追加したものをおいていますのでよろしければ参照ください。