LangchainでFunction Callingに対応したAgentにMemoryやToolを組み合わせ,Streamlitで動作するチャットボットを作成しました.
このコードではユーザーのプロフィール作成を手伝ってくれるチャットボットを作成しています.ユーザーのプロフィール(name
,favorites
,dislikes
)を引数に受け取るToolを定義しこのToolを呼び出すように指示することで,ユーザーのプロフィールを聞き出してくれるようになります.
このプログラム実行したい場合は,README.md
を参考にしてください.OPENAI APIキーは.env
ファイルに記入します.
カスタムToolの作成
実際のToolは次のようになっています.このToolがFunction Callingで呼び出されます.
class GenerateProfileSchema(BaseModel):
name: str = Field(..., description="name of user")
favorites: str = Field(..., description="favorite of user")
dislikes: str = Field(..., description="dislikes of user")
class GenerateProfileTool(BaseTool):
name = "generate_profile"
description = "Use this tool to generate a profile."
args_schema: Type[GenerateProfileSchema] = GenerateProfileSchema
def _run(self, name, favorites, dislikes):
print("generating profile")
return {
"name": name,
"favorites": favorites,
"dislikes": dislikes,
}
async def _arun(self, name, favorites, dislikes):
raise NotImplementedError("Not implemented yet")
GenerateProfileSchema
で引数を定義し,GenerateProfileTool
で実際の処理を定義します.
GenerateProfileSchema
ではpydantic
を使って引数を定義します.
GenerateProfileTool
で実際に行う処理を定義します.このコードではほとんど何もしていませんが,実際にはここで他のAPIやデータベースとやり取りするようなイメージです.
GenerateProfileTool
には_run
と_arun
の2つのメソッドがあり,_run
は同期処理,_arun
は非同期処理を定義します.今回は同期処理のみを利用していますが,その場合でも非同期処理を定義しておかないといけないようなので,NotImplementedError
を返すようにしています.
ちなみにLangChainではなくopenai
ライブラリを使うとしたらTool(というよりFunction)は次のように定義します.
{
"name": "generate_profile",
"description": "Use this tool to generate a profile.",
"parameters": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "name of user",
},
"favorites": {
"type": "string",
"description": "favorite of user",
},
"dislikes": {
"type": "string",
"description": "dislikes of user",
},
},
"required": ["name", "favorites", "dislikes"],
},
}
Memoryで会話の履歴を保持する
会話の履歴を保持するためにConversationBufferMemory
を使っています.またStreamlitを使っているので,st.session_state
を使ってmemory
を保存しています.
def get_memory():
# get memory from session state if it exists
memory = st.session_state.get("memory", None)
# initialize memory if it doesn't exist
if memory is None:
memory = ConversationBufferMemory(memory_key="memory", return_messages=True)
return memory
LangChainのStreamlit Callback
チャットボットの発言をストリーミング表示させるために,LangChainのStreamlit用Callbackを使っています.
st_cb = StreamlitCallbackHandler(st.empty())
response = agent.run(prompt, callbacks=[st_cb])