もっといいやり方がある気もしますが。。。
導入
以下の記事でMLflow Tracingを簡単に試してみました。
特にLangChain処理のトレースはAuto Loggingによって簡単に実現できました。
それ以外のフレームワークではどうでしょうか?
最近だと私はLangGraphで処理を実装することが多く、LangGraphのノードや条件エッジ、またその内部のLangChain処理が記録されると便利です。
というわけで、LangGraphだとどのようにMLflow Tracingを使うのが良さそうか、探ってみました。
これがベストかは自信がありませんが、例ということで。
試験環境はDatabricks on AWS、DBRは15.2MLです。
Step1. パッケージインストール
MLflowとLangChain/LangGraphの最新パッケージをインストール。
%pip install -U langchain==0.2.5 langchain-community==0.2.5 mlflow-skinny[databricks]==2.14.1 langgraph==0.1.1
dbutils.library.restartPython()
Step2. Auto Loggingの無効化/トレーシングの有効化
最初にmlflow.langchain
のauto loggingを無効化します。
有効化するとLangChainの処理を自動トレーシングしてくれるようになるのですが、LangGraphとの併用で問題が起きたので使用を断念。
そのため、以降の記録はマニュアル指定で行います。
import mlflow
# LangChain Autologgingは無効化
mlflow.langchain.autolog(disable=True)
# MLflow Tracingを有効化。無効化する場合はdisable()を呼び出すこと。
mlflow.tracing.enable()
# mlflow.tracing.disable()
Step3. グラフの作成
サンプルとして、こちらのチュートリアルの内容を改造して、簡単なグラフ処理を定義します。
MLflow Tracingを使うためのポイントは以下2カ所です。
- 各ノードにMLflow fluent APIである
@mlflow.trace
デコレータを設定 - LangChainの
invoke
でコールバックとしてMlflowLangchainTracer
のインスタンスを設定
これによって、ノード処理とLangChainの処理の両方を適切な粒度で記録できます。
from typing_extensions import TypedDict
from langgraph.graph import StateGraph
from langchain_community.chat_models.databricks import ChatDatabricks
from langchain_core.output_parsers import StrOutputParser
from mlflow.langchain.langchain_tracer import MlflowLangchainTracer
# グラフの状態
class State(TypedDict):
messages: list
output: str
# LangChainの簡易Chainを作成
# LLMはDatabricks Model Servingのモデルを利用しています。ここは好きなLLMでOK。
llm = ChatDatabricks(
target_uri="databricks",
endpoint="mistral-7b-instruct-v03-endpoint",
temperature=0.1,
)
chain = llm | StrOutputParser()
# グラフノードの定義
@mlflow.trace(span_type="node")
def set_instruction(state: State):
sys_prompt = [
("user", "あなたは優秀なAIアシスタントです。指示に的確に回答してください。"),
("assistant", "わかりました!"),
]
return {"messages": sys_prompt + [state["messages"]]}
@mlflow.trace(span_type="node")
def chatbot(state: State):
return {
"output": [
chain.invoke(
state["messages"],
config={
"callbacks": [MlflowLangchainTracer()]
}, # CallbackにMlflowLangchainTracerを仕込むことで記録
)
]
}
# LangGraphによるグラフ構築
graph_builder = StateGraph(State)
graph_builder.add_node("set_instruction", set_instruction)
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_edge("set_instruction", "chatbot")
graph_builder.set_entry_point("set_instruction")
graph_builder.set_finish_point("chatbot")
graph = graph_builder.compile()
グラフ処理としては以下のような流れになります。
from IPython.display import Image, display
try:
display(Image(graph.get_graph().draw_mermaid_png()))
except Exception:
pass
Step4. グラフの実行/トレーシング
では作成したグラフを実行し、処理が記録されるか確認してみます。
ここのポイントは、mlflow.start_span
を使ってグラフの実行を囲むことで、処理全体をトレース対象としていることです。
(さらにmlflow.start_run
で囲むことで、トレース結果をMLFlow Run単位で管理できます。ただし、必須ではありません)
# グラフの実行
with mlflow.start_run():
user_input = "こんにちは!"
with mlflow.start_span(
"graph",
span_type="AGENT",
attributes={"user_input": user_input},
) as graph_span:
result = graph.invoke({"messages": ("user", user_input)})
print(result)
実行結果は以下のようになります。
graph
を最上位のスパンとして、その下に各ノードの実行結果、そしてノード内のChain処理が記録されました。
ノードを選択すると、ノードのインプットとアウトプットが記録されていることも容易に確認できます。
LangChainの個々の処理も確認できます。
うーん、便利。
トレースのレイテンシも体感かなり低く、非常に使いやすい印象です。
まとめ
LangGraphとMLflow Tracingを組み合わせてトレースしてみました。
当初、Auto Loggingとの両立を模索していたのですがうまくいかず、このような形となりました。
とはいえ、ノードは@mlflow.trace
で修飾するだけ、LangChain側もコールバックを設定するだけなので、そこまで難しいわけではないかと思います。
LangGraphを使う際は各ノードがどのように動作したかを把握することが非常に大事です。
MLflow Tracingと組み合わせて使うことで開発・運用共に効率的になることが期待できます。
願わくば、LangGraphの公式Auto Loggingが実装されると嬉しいなあ。