LoginSignup
0
1

LangChainとStreamlitでFunction Calling対応のチャットボットを作る

Posted at

LangchainでFunction Callingに対応したAgentにMemoryやToolを組み合わせ,Streamlitで動作するチャットボットを作成しました.

このコードではユーザーのプロフィール作成を手伝ってくれるチャットボットを作成しています.ユーザーのプロフィール(namefavoritesdislikes)を引数に受け取るToolを定義しこのToolを呼び出すように指示することで,ユーザーのプロフィールを聞き出してくれるようになります.

このプログラム実行したい場合は,README.mdを参考にしてください.OPENAI APIキーは.envファイルに記入します.

カスタムToolの作成

実際のToolは次のようになっています.このToolがFunction Callingで呼び出されます.

tools.py
class GenerateProfileSchema(BaseModel):
    name: str = Field(..., description="name of user")
    favorites: str = Field(..., description="favorite of user")
    dislikes: str = Field(..., description="dislikes of user")


class GenerateProfileTool(BaseTool):
    name = "generate_profile"
    description = "Use this tool to generate a profile."
    args_schema: Type[GenerateProfileSchema] = GenerateProfileSchema

    def _run(self, name, favorites, dislikes):
        print("generating profile")
        return {
            "name": name,
            "favorites": favorites,
            "dislikes": dislikes,
        }

    async def _arun(self, name, favorites, dislikes):
        raise NotImplementedError("Not implemented yet")

GenerateProfileSchemaで引数を定義し,GenerateProfileToolで実際の処理を定義します.

GenerateProfileSchemaではpydanticを使って引数を定義します.

GenerateProfileToolで実際に行う処理を定義します.このコードではほとんど何もしていませんが,実際にはここで他のAPIやデータベースとやり取りするようなイメージです.
GenerateProfileToolには_run_arunの2つのメソッドがあり,_runは同期処理,_arunは非同期処理を定義します.今回は同期処理のみを利用していますが,その場合でも非同期処理を定義しておかないといけないようなので,NotImplementedErrorを返すようにしています.

ちなみにLangChainではなくopenaiライブラリを使うとしたらTool(というよりFunction)は次のように定義します.

{
  "name": "generate_profile",
  "description": "Use this tool to generate a profile.",
  "parameters": {
    "type": "object",
    "properties": {
      "name": {
        "type": "string",
        "description": "name of user",
      },
      "favorites": {
        "type": "string",
        "description": "favorite of user",
      },
      "dislikes": {
        "type": "string",
        "description": "dislikes of user",
      },
    },
    "required": ["name", "favorites", "dislikes"],
  },
}

Memoryで会話の履歴を保持する

会話の履歴を保持するためにConversationBufferMemoryを使っています.またStreamlitを使っているので,st.session_stateを使ってmemoryを保存しています.

main.py
def get_memory():
    # get memory from session state if it exists
    memory = st.session_state.get("memory", None)

    # initialize memory if it doesn't exist
    if memory is None:
        memory = ConversationBufferMemory(memory_key="memory", return_messages=True)

    return memory

LangChainのStreamlit Callback

チャットボットの発言をストリーミング表示させるために,LangChainのStreamlit用Callbackを使っています.

main.py
st_cb = StreamlitCallbackHandler(st.empty())
response = agent.run(prompt, callbacks=[st_cb])

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1