LoginSignup
8
11

LangChainのReACTについて中身を確認したメモ

Posted at

今回調べたこと

LangChainの公式ドキュメントにもある、以下の参考コードのようなReAct(Agent)を用いた生成について。
ReActの論文などで何となくやっていることは分かっているつもりですが、裏で何が起こってるか曖昧な部分があったので、ソースコードを元に実際の処理を確認しました。

公式ドキュメント

ソースコード

ReAct論文

参考コード

Agent使用例
from langchain import OpenAI
from langchain import Wikipedia
from langchain.agents.react.base import DocstoreExplorer

# 使用する言語モデルの定義
llm = OpenAI(
    openai_api_key=apikey,
    temperature=0
)
# 使用するtool群の定義
docstore=DocstoreExplorer(Wikipedia())
tools = [
    Tool(
        name="Search",
        func=docstore.search,
        description='search wikipedia'
    ),
    Tool(
        name="Lookup",
        func=docstore.lookup,
        description='lookup a term in wikipedia'
    )
]
# 使用するagentの定義
docstore_agent = initialize_agent(
    tools, 
    llm, 
    agent="react-docstore", 
    verbose=True,
    max_iterations=3
)
# 質問(query)に対する応答を生成
query = ""   # 質問内容
docstore_agent(query)

Agentの初期化

まずは使用するagentの初期化initialize_agent()の部分について。

initialize_agent()
def initialize_agent(
    tools: Sequence[BaseTool],
    llm: BaseLanguageModel,
    agent: Optional[AgentType] = None,
    callback_manager: Optional[BaseCallbackManager] = None,
    agent_path: Optional[str] = None,
    agent_kwargs: Optional[dict] = None,
    **kwargs: Any,
) -> AgentExecutor:
    #...省略
        agent_cls = AGENT_TO_CLASS[agent]
        agent_kwargs = agent_kwargs or {}
        agent_obj = agent_cls.from_llm_and_tools(
            llm, tools, callback_manager=callback_manager, **agent_kwargs
        )
    #...省略
    return AgentExecutor.from_agent_and_tools(
        agent=agent_obj,
        tools=tools,
        callback_manager=callback_manager,
        **kwargs,
    )

やっているのは関数名の通り、agentの初期化。
引数のagent名(例だと"react-docstore")に対応するAgentクラスをagent_clsでインスタンス化。
Agent.from_llm_and_tools()ではAgentクラスのcreate_prompt()_get_default_output_parser()でpromptとparserを用意して、前で定義されているtoolsと共にagentの中身を確定。
最後にAgentExecutor.from_agent_and_tools()でprompt, parser, toolsを元にAgentExecutorクラスを返していました。

Agent.from_llm_and_tools()
    #...省略
    llm_chain = LLMChain(
        llm=llm,
        prompt=cls.create_prompt(tools),   # prompt設定
        callback_manager=callback_manager,
    )
    tool_names = [tool.name for tool in tools]
    _output_parser = output_parser or cls._get_default_output_parser()   #parser設定
    return cls(
        llm_chain=llm_chain,
        allowed_tools=tool_names,
        output_parser=_output_parser,
        **kwargs,
    )
AgentExecutor.from_agent_and_tools()
    @classmethod
    def from_agent_and_tools(
        cls,
        agent: Union[BaseSingleActionAgent, BaseMultiActionAgent],
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        **kwargs: Any,
    ) -> AgentExecutor:
        """Create from agent and tools."""
        return cls(
            agent=agent, tools=tools, callback_manager=callback_manager, **kwargs
        )

Agentによる生成の流れ

確認対象であるdocstore_agent(query)の部分です。
AgentExecutorクラスを呼び出して生成をしているので、__call()__で生成しているはず。
AgentExecutorには__call()__が無かったので、子クラスからChain.__call()__が呼び出されていました。

Chain.__call__()
def __call__(
    self,
    inputs: Union[Dict[str, Any], Any],
    return_only_outputs: bool = False,
    callbacks: Callbacks = None,
) -> Dict[str, Any]:
    #...省略
    inputs = self.prep_inputs(inputs)
    #...省略
    try:
        outputs = (
            self._call(inputs, run_manager=run_manager)
            if new_arg_supported
            else self._call(inputs)
        )
    #...省略
    return self.prep_outputs(inputs, outputs, return_only_outputs)

細かい部分を無視すればself.prep_inputs()で入力を処理して、self._call()がoutputsになるみたい。
(self._call()についてはChainクラスではなく、AgentExecutorクラスのもの)
最終的に__call()__が返しているのはself.prep_outputs(inputs, outputs, return_only_outputs)

まずself.prep_inputs()についてはPrompt Templateに従った入力になっているかの確認でした。(コードは省略)
出力を生成しているself._call()の部分は以下の通り。

AgentExecutor._call()
def _call(
    self,
    inputs: Dict[str, str],
    run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
    """Run text through and get agent response."""
    # Construct a mapping of tool name to tool for easy lookup
    name_to_tool_map = {tool.name: tool for tool in self.tools}
    # We construct a mapping from each tool to a color, used for logging.
    color_mapping = get_color_mapping(
        [tool.name for tool in self.tools], excluded_colors=["green"]
    )
    intermediate_steps: List[Tuple[AgentAction, str]] = []
    # Let's start tracking the number of iterations and time elapsed
    iterations = 0
    time_elapsed = 0.0
    start_time = time.time()
    # We now enter the agent loop (until it returns something).
    while self._should_continue(iterations, time_elapsed):
        next_step_output = self._take_next_step(
            name_to_tool_map,
            color_mapping,
            inputs,
            intermediate_steps,
            run_manager=run_manager,
        )
        if isinstance(next_step_output, AgentFinish):
            return self._return(
                next_step_output, intermediate_steps, run_manager=run_manager
            )

        intermediate_steps.extend(next_step_output)
        if len(next_step_output) == 1:
            next_step_action = next_step_output[0]
            # See if tool should return directly
            tool_return = self._get_tool_return(next_step_action)
            if tool_return is not None:
                return self._return(
                    tool_return, intermediate_steps, run_manager=run_manager
                )
        iterations += 1
        time_elapsed = time.time() - start_time
    output = self.agent.return_stopped_response(
        self.early_stopping_method, intermediate_steps, **inputs
    )
    return self._return(output, intermediate_steps, run_manager=run_manager)

ようやくメインの部分にたどり着いた感じです。大まかに見ると以下の3ブロックと考えられます。

  1. 前準備:toolsの辞書作成やcallback、iter数などの設定。
  2. メイン:agentのloop
  3. 例外処理:loopから外れた時の処理

基本的にはself._take_next_step()で生成して、その型を確認。
もしFinishだったのなら出力を返して終了し、Actionだったなら、intermediate_stepsに途中経過として生成した内容を追加し、loop処理の先頭へ。
指定したiter数や経過時間を超えた場合は、3へ移行。その時点の結果を出力するような感じ?
(if len(next_step_output)の部分はaction_inputが生成されない特殊パターンへの対応?)
次はself._take_next_step()の詳細を確認。

生成部分の詳細

self._take_next_step()
def _take_next_step(
    self,
    name_to_tool_map: Dict[str, BaseTool],
    color_mapping: Dict[str, str],
    inputs: Dict[str, str],
    intermediate_steps: List[Tuple[AgentAction, str]],
    run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
    """Take a single step in the thought-action-observation loop.
    Override this to take control of how the agent makes and acts on choices.
    """
    try:
        # Call the LLM to see what to do.
        output = self.agent.plan(
            intermediate_steps,
            callbacks=run_manager.get_child() if run_manager else None,
            **inputs,
        )
    #...省略
    # If the tool chosen is the finishing tool, then we end and return.
    if isinstance(output, AgentFinish):
        return output
    actions: List[AgentAction]
    if isinstance(output, AgentAction):
        actions = [output]
    else:
        actions = output
    result = []
    for agent_action in actions:
        if run_manager:
            run_manager.on_agent_action(agent_action, color="green")
        # Otherwise we lookup the tool
        if agent_action.tool in name_to_tool_map:
            tool = name_to_tool_map[agent_action.tool]
            return_direct = tool.return_direct
            color = color_mapping[agent_action.tool]
            tool_run_kwargs = self.agent.tool_run_logging_kwargs()
            if return_direct:
                tool_run_kwargs["llm_prefix"] = ""
            # We then call the tool on the tool input to get an observation
            observation = tool.run(
                agent_action.tool_input,
                verbose=self.verbose,
                color=color,
                callbacks=run_manager.get_child() if run_manager else None,
                **tool_run_kwargs,
            )
        else:
            #...省略
        result.append((agent_action, observation))
    return result

まず、self.agent.plan()の部分について。
self.agentは最初のinitialize_agent()の部分で定義されている。(今回の場合はReActDocstoreAgentクラス)
最終的に呼び出されているのは子クラスにあるAgent.plan()で以下の通りです。

Agent.plan()
def plan(
    self,
    intermediate_steps: List[Tuple[AgentAction, str]],
    callbacks: Callbacks = None,
    **kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
    """Given input, decided what to do.
    Args:
        intermediate_steps: Steps the LLM has taken to date, along with observations
        callbacks: Callbacks to run.
        **kwargs: User inputs.
    Returns:
        Action specifying what tool to use.
    """
    full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
    full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
    return self.output_parser.parse(full_output)

最初のself.get_full_inputsintermediate_stepsから過去の履歴も含めた入力を作成。
次にself.llm_chain.predict()では最終的に以下の関数(LLMChain._call())を実行していました。

LLMChain._call()とgenerate()
def _call(
    self,
    inputs: Dict[str, Any],
    run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
    response = self.generate([inputs], run_manager=run_manager)
    return self.create_outputs(response)[0]

def generate(
    self,
    input_list: List[Dict[str, Any]],
    run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
    """Generate LLM result from inputs."""
    prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
    return self.llm.generate_prompt(
        prompts, stop, callbacks=run_manager.get_child() if run_manager else None
    )

やっとここでLLMの生成部分に到達…
self.prep_prompts()でagentに応じたpromptを生成し、self.llm.generate_prompt()でLLMからの出力を得ています。
(引数にあるstopはソースを見る限り、生成時のstop wordみたいな感じ?)

self.agent.plan()returnの部分では、agentごとに定義されたoutput_parserが使用されています。
具体的にはActionのprefix「Action: 」が生成文の最後にあるか確認し、「Action: 」以降の文章action_strを抽出。
今回のagnetでは「Action: Search[〇〇]」や「Action: Lookup[〇〇]」といった形でaction_strが抽出されるので、次の行動内容actionと入力内容action_inputを抽出。
actionがFinishかどうかで出力の内容を変えることで後の型判定に利用しています。
下記が今回のoutput_parserの中身。

output_parser
class ReActOutputParser(AgentOutputParser):
    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        action_prefix = "Action: "
        if not text.strip().split("\n")[-1].startswith(action_prefix):
            raise OutputParserException(f"Could not parse LLM Output: {text}")
        action_block = text.strip().split("\n")[-1]

        action_str = action_block[len(action_prefix) :]
        # Parse out the action and the directive.
        re_matches = re.search(r"(.*?)\[(.*?)\]", action_str)
        if re_matches is None:
            raise OutputParserException(
                f"Could not parse action directive: {action_str}"
            )
        action, action_input = re_matches.group(1), re_matches.group(2)
        if action == "Finish":
            return AgentFinish({"output": action_input}, text)
        else:
            return AgentAction(action, action_input, text)

以上がself._take_next_step()の前半部分。
この後はif isinstanceで生成した内容がActionなのか、Finishなのかを確認し、Finishならそのまま出力を返して、Actionなら更に処理を進めてToolを利用する部分となります。

Tool利用部分

Toolを使ってActionを行っている部分は以下の通り。

self._take_next_step() 抜粋
        # Otherwise we lookup the tool
        if agent_action.tool in name_to_tool_map:
            tool = name_to_tool_map[agent_action.tool]
            return_direct = tool.return_direct
            color = color_mapping[agent_action.tool]
            tool_run_kwargs = self.agent.tool_run_logging_kwargs()
            if return_direct:
                tool_run_kwargs["llm_prefix"] = ""
            # We then call the tool on the tool input to get an observation
            observation = tool.run(
                agent_action.tool_input,
                verbose=self.verbose,
                color=color,
                callbacks=run_manager.get_child() if run_manager else None,
                **tool_run_kwargs,
            )

正規表現で抽出したactionname_to_tool_map(1番上のagent使用例で定義したtoolsnameのlist)に入っているか確認し、入っていればtool.run()observationを得ています。(こっちは1番上で定義したtoolsfuncを使用)
tool.run()の部分は、最初に定義したtoolfuncを動かしています。
今回で言えばdocstore=DocstoreExplorer(Wikipedia())なので、以下のような動きとなっていました。

  • search:pythonのwikipediaライブラリを使って、入力を元にした検索を実施。
    (中身を見ていると修正次第でURLも出力できそう?)
  • lookup:指定単語のある文章を検索(ctrl+Fのイメージ)

ここの部分を工夫することで、色んな作業をLLMが出来るようになっているんですかね?
以下に個人的に気になった記事を残します。

最後に

かなり雑なメモになっており、表現や内容の間違いもあるかもしれません。
気になる点がありましたら、ご指摘の程よろしくお願い致します。

8
11
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
11