1
0

もっといいやり方がある気もしますが。。。

導入

以下の記事でMLflow Tracingを簡単に試してみました。

特にLangChain処理のトレースはAuto Loggingによって簡単に実現できました。

それ以外のフレームワークではどうでしょうか?
最近だと私はLangGraphで処理を実装することが多く、LangGraphのノードや条件エッジ、またその内部のLangChain処理が記録されると便利です。

というわけで、LangGraphだとどのようにMLflow Tracingを使うのが良さそうか、探ってみました。
これがベストかは自信がありませんが、例ということで。

試験環境はDatabricks on AWS、DBRは15.2MLです。

Step1. パッケージインストール

MLflowとLangChain/LangGraphの最新パッケージをインストール。

%pip install -U langchain==0.2.5 langchain-community==0.2.5 mlflow-skinny[databricks]==2.14.1 langgraph==0.1.1
dbutils.library.restartPython()

Step2. Auto Loggingの無効化/トレーシングの有効化

最初にmlflow.langchainのauto loggingを無効化します。
有効化するとLangChainの処理を自動トレーシングしてくれるようになるのですが、LangGraphとの併用で問題が起きたので使用を断念。
そのため、以降の記録はマニュアル指定で行います。

import mlflow

# LangChain Autologgingは無効化
mlflow.langchain.autolog(disable=True)

# MLflow Tracingを有効化。無効化する場合はdisable()を呼び出すこと。
mlflow.tracing.enable()
# mlflow.tracing.disable()

Step3. グラフの作成

サンプルとして、こちらのチュートリアルの内容を改造して、簡単なグラフ処理を定義します。
MLflow Tracingを使うためのポイントは以下2カ所です。

  • 各ノードにMLflow fluent APIである@mlflow.traceデコレータを設定
  • LangChainのinvokeでコールバックとしてMlflowLangchainTracerのインスタンスを設定

これによって、ノード処理とLangChainの処理の両方を適切な粒度で記録できます。

from typing_extensions import TypedDict
from langgraph.graph import StateGraph
from langchain_community.chat_models.databricks import ChatDatabricks
from langchain_core.output_parsers import StrOutputParser

from mlflow.langchain.langchain_tracer import MlflowLangchainTracer


# グラフの状態
class State(TypedDict):
    messages: list
    output: str

# LangChainの簡易Chainを作成
# LLMはDatabricks Model Servingのモデルを利用しています。ここは好きなLLMでOK。
llm = ChatDatabricks(
    target_uri="databricks",
    endpoint="mistral-7b-instruct-v03-endpoint",
    temperature=0.1,
)
chain = llm | StrOutputParser()

# グラフノードの定義

@mlflow.trace(span_type="node")
def set_instruction(state: State):
    sys_prompt = [
        ("user", "あなたは優秀なAIアシスタントです。指示に的確に回答してください。"),
        ("assistant", "わかりました!"),
    ]
    return {"messages": sys_prompt + [state["messages"]]}


@mlflow.trace(span_type="node")
def chatbot(state: State):
    return {
        "output": [
            chain.invoke(
                state["messages"],
                config={
                    "callbacks": [MlflowLangchainTracer()]
                },  # CallbackにMlflowLangchainTracerを仕込むことで記録
            )
        ]
    }

# LangGraphによるグラフ構築
graph_builder = StateGraph(State)
graph_builder.add_node("set_instruction", set_instruction)
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_edge("set_instruction", "chatbot")
graph_builder.set_entry_point("set_instruction")
graph_builder.set_finish_point("chatbot")
graph = graph_builder.compile()

グラフ処理としては以下のような流れになります。

from IPython.display import Image, display

try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception:
    pass

image.png

Step4. グラフの実行/トレーシング

では作成したグラフを実行し、処理が記録されるか確認してみます。

ここのポイントは、mlflow.start_spanを使ってグラフの実行を囲むことで、処理全体をトレース対象としていることです。
(さらにmlflow.start_runで囲むことで、トレース結果をMLFlow Run単位で管理できます。ただし、必須ではありません)

# グラフの実行
with mlflow.start_run():
    user_input = "こんにちは!"
    with mlflow.start_span(
        "graph",
        span_type="AGENT",
        attributes={"user_input": user_input},
    ) as graph_span:
        result = graph.invoke({"messages": ("user", user_input)})
        print(result)

実行結果は以下のようになります。
graphを最上位のスパンとして、その下に各ノードの実行結果、そしてノード内のChain処理が記録されました。

image.png

ノードを選択すると、ノードのインプットとアウトプットが記録されていることも容易に確認できます。

image.png

LangChainの個々の処理も確認できます。

image.png

うーん、便利。
トレースのレイテンシも体感かなり低く、非常に使いやすい印象です。

まとめ

LangGraphとMLflow Tracingを組み合わせてトレースしてみました。
当初、Auto Loggingとの両立を模索していたのですがうまくいかず、このような形となりました。

とはいえ、ノードは@mlflow.traceで修飾するだけ、LangChain側もコールバックを設定するだけなので、そこまで難しいわけではないかと思います。

LangGraphを使う際は各ノードがどのように動作したかを把握することが非常に大事です。
MLflow Tracingと組み合わせて使うことで開発・運用共に効率的になることが期待できます。

願わくば、LangGraphの公式Auto Loggingが実装されると嬉しいなあ。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0