はじめに
LangChain 1.0がリリースされました。
詳細は以下の公式Blogを確認ください。
主な変化点は以下のドキュメントにも掲載されています。
ドキュメント含めて大きく変更されており、またαリリースの際よりも組み込みMiddlewareも増えてましたので、個人的に気になる内容をピックアップしてDatabricks Free Edition上で検証していきます。
この記事では、以下の「エージェントメモリのカスタマイズ」を扱います。
エージェントは内部状態(短期記憶)を持ちますが、LangChainのcreate_agentで作成されるエージェントは標準だとメッセージの履歴のみを保持します。
一方、エージェント処理によっては異なる構造の情報を保持したいこともあります。
そのために、状態のスキーマ(保持する情報の定義)を変更できるようになっています。
方法としては、ドキュメントに記載があるように、以下の2種があるようです。
- ミドルウェアを使う方法
-
create_agentのパラメータとしてスキーマを指定する方法
前者が推奨されている方法のようですので、ミドルウェアを使ったスキーマ定義の変更を行ってみます。
エージェントメモリのカスタマイズ
まずはノートブックを作成して必要なパッケージをインストール。
TracingのためにMLflowもインストールします。
%pip install -U langchain>=1.0.0 langchain_openai>=1.0.0 mlflow
%restart_python
次に利用するモデルをセットアップします。
利用するモデルは任意でOKですが、今回はdatabricks-gpt-oss-20bを利用しました。
from langchain.chat_models import init_chat_model
import mlflow
mlflow.langchain.autolog()
creds = mlflow.utils.databricks_utils.get_databricks_host_creds()
model = init_chat_model(
"openai:databricks-gpt-oss-20b",
api_key=creds.token,
base_url=creds.host + "/serving-endpoints",
)
次にエージェントメモリ(状態)のスキーマを変更・利用するミドルウェアを作成します。
今回はサンプルに則って、モデルの呼び出し回数をカウントし、それを内部状態として保持するミドルウェアにしました。
変更の流れとしては、以下のようになります。
-
AgentStateをカスタムした状態用クラスを定義 - ミドルウェア内にて
state_schemaを指定
from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents.middleware import AgentState, AgentMiddleware
from typing_extensions import NotRequired
from typing import Any
# カスタム
class CustomState(AgentState):
model_call_count: NotRequired[int]
class CallCounterMiddleware(AgentMiddleware[CustomState]):
state_schema = CustomState # ここで状態スキーマを変更!
def before_model(self, state: CustomState, runtime) -> dict[str, Any] | None:
count = state.get("model_call_count", 0)
if count > 10:
return {"jump_to": "end"}
return None
def after_model(self, state: CustomState, runtime) -> dict[str, Any] | None:
return {"model_call_count": state.get("model_call_count", 0) + 1}
最後に上記ミドルウェアを組み込んだ単純なエージェントを作成・実行します。
from langchain.agents import create_agent
from langchain.agents.structured_output import ToolStrategy
from pprint import pprint
def get_weather(city: str) -> str:
"""指定した都市の天気を取得します。"""
return f"It's always sunny in {city}!"
agent = create_agent(
model=model,
tools=[get_weather],
middleware=[CallCounterMiddleware()],
)
input = {
"messages": [{"role": "user", "content": "SFの天気は?"}]
}
agent.invoke(input)
実行結果は以下のようになりました。
{'messages': [HumanMessage(content='SFの天気は?', additional_kwargs={}, response_metadata={}, id='7ec28573-baae-4476-b309-98c3f09521a7'),
AIMessage(content='', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 71, 'prompt_tokens': 139, 'total_tokens': 210, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_provider': 'openai', 'model_name': 'gpt-oss-20b-080525', 'system_fingerprint': None, 'id': 'chatcmpl_c58ab0d4-de12-4a34-b86d-e1eae1126f24', 'finish_reason': 'tool_calls', 'logprobs': None}, id='lc_run--1d8da3d7-47a9-46cf-a0ef-d1d04a096eb4-0', tool_calls=[{'name': 'get_weather', 'args': {'city': 'San Francisco'}, 'id': 'call_cabe0267-35e4-445a-b03b-2ab54479359a', 'type': 'tool_call'}], usage_metadata={'input_tokens': 139, 'output_tokens': 71, 'total_tokens': 210, 'input_token_details': {}, 'output_token_details': {}}),
ToolMessage(content="It's always sunny in San Francisco!", name='get_weather', id='fd51ad12-c848-4bb1-b153-2c2f66d0da83', tool_call_id='call_cabe0267-35e4-445a-b03b-2ab54479359a'),
AIMessage(content='San\u202fFrancisco\u202f云はとても澄みやすいです。現在の天候は晴れで、日中は軽く風が吹く程度。黄昏前からちょっとした雲が広がる可能性もありますが、全体的に快晴が予想されます。2024年10月30日現在の情報です。', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 104, 'prompt_tokens': 178, 'total_tokens': 282, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_provider': 'openai', 'model_name': 'gpt-oss-20b-080525', 'system_fingerprint': None, 'id': 'chatcmpl_26148fd9-1d39-4eb2-a406-a013433ceb41', 'finish_reason': 'stop', 'logprobs': None}, id='lc_run--2c083bba-507e-48ad-b65a-fa02555329f0-0', usage_metadata={'input_tokens': 178, 'output_tokens': 104, 'total_tokens': 282, 'input_token_details': {}, 'output_token_details': {}})],
'model_call_count': 2}
出力内容にカスタムした状態のmodel_call_countが含まれているのがわかります。
MLflow Tracingの結果は以下になります。
設定したMiddlewareが呼ばれて、model_call_count状態が更新されていることがわかりますね。
このようにミドルウェアを通じてメモリの定義をカスタマイズすることができます。
まとめ
LangChainエージェントのエージェントメモリのカスタマイズする方法を試してみました。
もうミドルウェアが全てを解決する、という感じにカスタム系は全部ミドルウェアが推奨されていますね。
LangChain v1の最大の変化点はミドルウェアなのかも。
