導入
LangGraphのバージョンが気づいたら0.1
となっていました。
(正確には、この記事執筆時点で0.1.4
がリリースされています)
日本語まとめはこちらの記事を参照ください。
LangGraphはエージェント処理を記述する点で非常に優れたフレームワークだと感じており、個人的にRAGを構成する際によく利用しています。(例えば、下記の記事)
こういったLangGraphのグラフを、さらにAPIエンドポイントとして公開することができると運用においてさらに便利です。
ただ、今までやってみたことがなかったので、Databricksモデルサービングを使ってAPIエンドポイントをデプロイしてみよう、というのがこの記事の主旨です。
Step1. パッケージインストール
LangGraphやLangChainなど、必要なパッケージをインストール。
%pip install -U langgraph==0.1.4 langchain==0.2.6 langchain-community==0.2.6 mlflow-skinny[databricks]==2.14.1 pydantic==2.7.4
dbutils.library.restartPython()
Step2. デプロイ用のグラフを作成
こちらの記事で作成したグラフをほとんど再利用して、試験用のグラフを作成します。
グラフ処理としてはシンプルで、ユーザからの入力に対してLLMからの回答を返すだけです。
(ノードを複数にしたかったので、最初に単純な応答を追加する処理を加えるようにしています)
前記事との違いは、State
クラスをPydanticを使ったものに変更したぐらい。
また、処理中で使っているLLMのエンドポイントはこちらの記事で作成したものを流用しています。
from typing_extensions import TypedDict
from langgraph.graph import StateGraph
from langchain_community.chat_models.databricks import ChatDatabricks
from langchain_core.output_parsers import StrOutputParser
from mlflow.langchain.langchain_tracer import MlflowLangchainTracer
from langchain_core.messages import ChatMessage
import operator
from typing import Annotated, Sequence, Optional
from langchain_core.messages import BaseMessage
from pydantic.v1 import BaseModel
# グラフの状態(Pydanticを使って構築)
class State(BaseModel):
input: str
messages: Optional[Sequence[BaseMessage]] = None
output: Optional[Sequence[str]] = None
# LangChainでLLMの結果を文字列で返す単純なChainを作成
llm = ChatDatabricks(
target_uri="databricks",
endpoint="mistral-7b-instruct-v03-endpoint",
temperature=0.1,
)
chain = llm | StrOutputParser()
# グラフノードの定義
@mlflow.trace(span_type="node")
def init_instruction(state: State):
"""デフォルトのインストラクションを仕込む"""
sys_prompt = [
ChatMessage(role="user", content="あなたは優秀なAIアシスタントです。指示に的確に回答してください。"),
ChatMessage(role="assistant", content="わかりました!"),
]
return {"messages": sys_prompt + [ChatMessage(role="user", content=state.input)]}
@mlflow.trace(span_type="node")
def chatbot(state: State):
"""Chatbotの推論結果を返す"""
return {
"output": [
chain.invoke(
state.messages,
config={
"callbacks": [MlflowLangchainTracer()]
}, # CallbackにMlflowLangchainTracerを仕込むことで記録
)
]
}
# LangGraphによるグラフ構築
graph_builder = StateGraph(State)
graph_builder.add_node("init_instruction", init_instruction)
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_edge("init_instruction", "chatbot")
graph_builder.set_entry_point("init_instruction")
graph_builder.set_finish_point("chatbot")
Step3. MLflow Pyfuncカスタムモデルを作成
ここからがポイントです。
現状、LangGraph用のフレーバーがMLflowにはないため、Pyfuncカスタムモデルを作成し、MLflowに登録します。
処理としては、LangGraphのビルダーをインスタンス生成時に渡しておき、load_context
でグラフをビルドするようにしています。
※ ビルド済みのグラフを最初から渡すと、正しく記録・ロードできませんでした。
import pandas as pd
class LangGraphModel(mlflow.pyfunc.PythonModel):
def __init__(self, graph_builder):
self.graph_builder = graph_builder
self.graph = None
def load_context(self, context):
# グラフを作成
self.graph = self.graph_builder.compile()
def predict(self, context, model_input, params=None):
# MLflow Tracingを有効化
with mlflow.start_span("graph", span_type="AGENT") as graph_span:
# グラフへの入力を取得
user_input = model_input["input"][0]
# mlflow tracingに入力を記録
graph_span.set_inputs({"input": user_input})
# グラフを実行
result = self.graph.invoke({"input": user_input})
result_state = State(**result).dict(exclude_unset=True)
# mlflow tracingに結果を記録
graph_span.set_outputs({"output_state": result_state})
return {"output_state": result_state}
Step4. MLflowへロギング
作成したカスタムモデルを使ってMLflowへ登録します。
まず、Signatureと入力サンプルを作成。
今回はinfer_signature
を使って、入力/出力サンプルを基に生成します。
from mlflow.models.signature import infer_signature
model_input = {"input": "Hello!"}
model_output = {
"output_state": {
"input": "こんにちは!",
"messages": [
{"content": "あなたは優秀なAIアシスタントです。指示に的確に回答してください。", "role": "user"},
{"content": "わかりました!", "role": "assistant"},
{"content": "こんにちは!", "role": "user"},
],
"output": [" こんにちは!私はAIであり、人間と同じように話すことができます。どうぞよろしくお願いします。"],
}
}
signature = infer_signature(model_input, model_output)
input_example = model_input
print("--- Signature ---")
print(signature)
print("--- Input Example ---")
print(input_example)
--- Signature ---
inputs:
['input': string (required)]
outputs:
['output_state': {input: string (required), messages: Array({content: string (required), role: string (required)}) (required), output: Array(string) (required)} (required)]
params:
None
--- Input Example ---
{'input': 'Hello!'}
次に準備したSignature等を使って、MLflowにロギング。
このとき、必要な依存関係も指定します。
合わせて、モデルサービングでデプロイできるようにUnity Catalogへモデルを登録しておきます。
(今回はtraining.langgraph.test_graph
へ登録)
import mlflow
mlflow.set_registry_uri("databricks-uc")
# ロギングするPyfuncカスタムモデルを作成
# 初期化時にLangGraphのグラフビルダーを渡している
model = LangGraphModel(graph_builder)
# Unity Catalogでのモデル登録場所
registered_model_name = "training.langgraph.test_graph"
# 追加の依存関係
extra_pip_requirements = [
"langgraph==0.1.4",
"langchain==0.2.6",
"langchain-community==0.2.6",
"pydantic==2.7.4",
]
# MLflowでのロギング
with mlflow.start_run() as run:
logged_model = mlflow.pyfunc.log_model(
artifact_path="model",
python_model=model,
extra_pip_requirements=extra_pip_requirements,
signature=signature,
input_example=input_example,
example_no_conversion=True,
registered_model_name=registered_model_name,
)
動作確認。
m = mlflow.pyfunc.load_model(logged_model.model_uri)
m.predict({"input": "Hello"})
{'output_state': {'input': 'Hello',
'messages': [{'content': 'あなたは優秀なAIアシスタントです。指示に的確に回答してください。',
'role': 'user'},
{'content': 'わかりました!', 'role': 'assistant'},
{'content': 'Hello', 'role': 'user'}],
'output': [' こんにちは!どうぞよろしくお願いします。']}}
動作しますね。
Step5. モデルサービングにデプロイ
MLflowへロギングしたモデルを利用して、Databricksモデルサービングにデプロイします。
今回はMLflow Deployments SDKを利用してエンドポイントを作成しました。
詳細は以下の公式ドキュメントを確認ください。
また、エンドポイント内から他のモデルサービングエンドポイントを呼び出せるように、環境変数としてDATABRICKS_HOST
とDATABRICKS_TOKEN
をシークレットから設定しています。
参考にする場合は事前にシークレットへホスト情報やAPIトークン情報を格納しておいてください。
from mlflow.deployments import get_deploy_client
client = get_deploy_client("databricks")
endpoint_name = "langgraph-endpoint"
base_name = "test-graph"
model_version = "1" # モデルのバージョンに合わせて要変更
endpoint = client.create_endpoint(
name=endpoint_name,
config={
"served_entities": [
{
"name": f"{base_name}-{model_version}",
"entity_name": registered_model_name,
"entity_version": model_version,
"workload_size": "Small",
"workload_type": "CPU",
"scale_to_zero_enabled": True,
"environment_vars": {
"DATABRICKS_HOST": "{{secrets/llm-serve/host}}",
"DATABRICKS_TOKEN": "{{secrets/llm-serve/api_token}}",
},
}
],
"traffic_config": {
"routes": [
{
"served_model_name": f"{base_name}-{model_version}",
"traffic_percentage": 100,
}
]
},
},
)
一定時間後、エンドポイントが作成されます。
作成後、次の処理で実際にAPIを実行してみましょう。
client = get_deploy_client("databricks")
response = client.predict(
endpoint=endpoint_name,
inputs={"input": "こんにちは!"},
)
response
{'output_state': {'input': 'こんにちは!',
'messages': [{'content': 'あなたは優秀なAIアシスタントです。指示に的確に回答してください。',
'role': 'user'},
{'content': 'わかりました!', 'role': 'assistant'},
{'content': 'こんにちは!', 'role': 'user'}],
'output': [' こんにちは!私はAIであり、人間と同じように話すことができます。どうぞよろしくお願いします。']}}
正常に実行できました。
まとめ
LangGraphで作成したグラフをMLflowでロギングし、Databricksモデルサービングを使ってAPIエンドポイントをデプロイしました。
これで作成したエージェント処理などをフロントエンドから簡単に利用することができます。
現状DatabricksモデルサービングエンドポイントはStreamingに未対応(たぶん)ですし、ユーザビリティの観点だと少し使いづらいところはありますが、エンドポイントのバージョン管理など、本番運用においては重要な機能が既に揃っています。
また、LangGraphは通常のLangChainを利用する際に比べて、個人的にかなり使いやすいフレームワークだと思います。
今後、Databricksで使う上でMosaic AI Agent Frameworkとの連携をどうできるかよくわかっていないのですが、うまくアーキテクチャを組めたらと思っています。