導入
最近、生成AI(LLM)を使った実務的なAIエージェントって何だろうとボンヤリ考えたりします。
以下に代表されるDatabricksのドキュメントにはAIエージェントの例も記載されていますが、アイディアとしてはもっと色々なものがあるはず。
ただ、自分自身でもっと作ってみないとイメージも湧きづらいなと思ったので、以下のGithubリポジトリで紹介されているGenAI AgentのチュートリアルをDatabricks上でウォークスルーしながら学んでいこうと思います。
そのまま動かすのも芸がないので、できるだけDatabricksで再利用可能な形にエージェントを実装したらどうなるか?という形に修正してみます。
また本来はMosaic AI Agent Frameworkをフル活用するべきですが日本リージョンでは利用できない機能がまだ多いため、そのあたりは避けてやっていこうと思います。
今回は最初のチュートリアルであるSimple Conversational Agentです。
Simple Conversational Agent概要
ノートブックの上段を邦訳して引用。
コンテキスト認識を持つ会話エージェントの構築
概要
このチュートリアルでは、複数の対話にわたってコンテキストを維持する会話エージェントを作成するプロセスを概説します。現代のAIフレームワークを使用して、より自然で一貫性のある会話を行うエージェントを構築します。
動機
多くのシンプルなチャットボットはコンテキストを維持する能力が欠けており、断片的でフラストレーションのたまるユーザー体験を引き起こします。このチュートリアルは、会話の前の部分を覚えて参照できる会話エージェントを実装することで、その問題を解決することを目指しています。これにより、全体的な対話の質が向上します。
主要コンポーネント
- 言語モデル: 応答を生成するコアAIコンポーネント。
- プロンプトテンプレート: 会話の構造を定義します。
- 履歴マネージャー: 会話の履歴とコンテキストを管理します。
- メッセージストア: 各会話セッションのメッセージを保存します。
メソッドの詳細
環境の設定
必要なAIフレームワークを設定し、適切な言語モデルにアクセスできるようにします。これが会話エージェントの基盤となります。
チャット履歴ストアの作成
複数の会話セッションを管理するシステムを実装します。各セッションは一意に識別可能で、それぞれのメッセージ履歴に関連付けられます。
会話構造の定義
以下を含むテンプレートを作成します:
- AIの役割を定義するシステムメッセージ
- 会話履歴のプレースホルダー
- ユーザーの入力
この構造はAIの応答を導き、会話全体の一貫性を維持します。会話チェーンの構築
プロンプトテンプレートと言語モデルを組み合わせて基本的な会話チェーンを作成します。このチェーンを履歴管理コンポーネントでラップし、会話履歴の挿入と取得を自動的に処理します。
エージェントとの対話
エージェントを使用するには、ユーザー入力とセッション識別子を使用して呼び出します。履歴マネージャーは適切な会話履歴を取得し、プロンプトに挿入し、各対話後に新しいメッセージを保存します。
結論
この会話エージェントのアプローチには以下の利点があります:
- コンテキスト認識: エージェントは会話の前の部分を参照できるため、より自然な対話が可能です。
- シンプルさ: モジュラー設計により実装が簡単です。
- 柔軟性: 会話構造の変更や異なる言語モデルへの切り替えが容易です。
- スケーラビリティ: セッションベースのアプローチにより、複数の独立した会話を管理できます。
この基盤をもとに、さらにエージェントを強化することができます:
- より高度なプロンプトエンジニアリングの実装
- 外部知識ベースとの統合
- 特定のドメインに特化した機能の追加
- エラーハンドリングと会話修復戦略の組み込み
コンテキスト管理に焦点を当てることで、この会話エージェントの設計は基本的なチャットボット機能を大幅に改善し、より魅力的で役立つAIアシスタントへの道を開きます。
LangChainのチュートリアルでもよく見られる会話履歴を保持したチャットの仕組となります。
エージェント?というところもあると思いますが、まずはここからスタートしていきます。
実装と実行
では、Databricks上で実装していきます。
ノートブックを作成してパッケージをインストール。
主にLangChain関連とMlflow最新版をインストールします。
%pip install -q -U langchain-core==0.3.13 langchain-databricks==0.1.1 langchain_community==0.3.3
%pip install -q -U typing-extensions
%pip install -q -U "mlflow-skinny[databricks]==2.17.1"
dbutils.library.restartPython()
MLflowのカスタムチャットモデルとして、エージェントを実装します。
チャットモデル固有の処理が多くを占めており、今回の会話エージェントの処理は主に_chat
メソッドの部分です。
ここでLangChainを使った会話履歴管理付の応答処理を実装しています。
%%writefile "./simple_conversational_agent.py"
import uuid
from typing import List, Optional, Dict
import mlflow
from mlflow.pyfunc import ChatModel
from mlflow.models import set_model
from mlflow.types.llm import (
ChatResponse,
ChatMessage,
ChatParams,
ChatChoice,
)
from langchain_databricks import ChatDatabricks
from langchain.memory import ChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
class SimpleConversationalAgent(ChatModel):
"""
LangChainとMLflowを使用してチャット履歴を管理し、予測を行うシンプルな会話エージェント。
属性:
models (dict): モデル構成を格納する辞書。
models_config (dict): 追加のモデル構成を格納する辞書。
store (dict): チャット履歴を格納する辞書。
"""
def __init__(self):
"""
空のmodels、models_config、およびstoreでSimpleConversationalAgentを初期化します。
"""
self.models = {}
self.models_config = {}
self.store = {}
def load_context(self, context):
"""
エージェントのコンテキストをロードし、モデル構成を設定します。
引数:
context: モデル構成を含むコンテキスト。
"""
self.models = context.model_config.get("models", {})
def predict(
self, context, messages: List[ChatMessage], params: Optional[ChatParams] = None
) -> ChatResponse:
"""
指定されたチャットメッセージとパラメータに対する応答を予測します。
引数:
context: 予測のためのコンテキスト。
messages (List[ChatMessage]): チャットメッセージのリスト。
params (Optional[ChatParams]): チャットのためのオプションのパラメータ。
戻り値:
ChatResponse: チャットモデルからの応答。
"""
with mlflow.start_span(name="Agent Prediction") as root_span:
root_span.set_inputs(messages)
attributes = {**params.to_dict(), **self.models_config, **self.models}
root_span.set_attributes(attributes)
endpoint = self._get_model_endpoint("agent")
input = messages[-1].content
response = self._chat(input, endpoint, params.to_dict())
output = ChatResponse(
choices=[
ChatChoice(
index=0,
message=ChatMessage(
role="assistant",
content=response.content,
),
)
],
usage=response.response_metadata,
model=endpoint,
)
root_span.set_outputs(output)
return output
@mlflow.trace(name="Chat")
def _chat(self, input, endpoint: str, params: dict):
"""
モデルとのチャットインタラクションを処理します。
引数:
input: 入力メッセージの内容。
endpoint (str): モデルのエンドポイント。
params (dict): チャットのためのパラメータ。
戻り値:
チャットモデルからの応答。
"""
session_id = params.get("metadata", {}).get("session_id", str(uuid.uuid4()))
prompt = ChatPromptTemplate.from_messages(
[
("system", self._get_system_prompt("agent")),
MessagesPlaceholder(variable_name="history"),
("human", "{input}"),
]
)
chat_model = ChatDatabricks(endpoint=endpoint, **params)
chain = prompt | chat_model
chain_with_history = RunnableWithMessageHistory(
chain,
self._get_chat_history,
input_messages_key="input",
history_messages_key="history",
)
return chain_with_history.invoke(
{"input": input},
config={"configurable": {"session_id": session_id}},
)
def _get_chat_history(self, session_id: str):
"""
指定されたセッションIDのチャット履歴を取得します。
引数:
session_id (str): セッションID。
戻り値:
ChatMessageHistory: セッションのチャット履歴。
"""
if session_id not in self.store:
self.store[session_id] = ChatMessageHistory()
return self.store[session_id]
def _get_system_prompt(self, role: str) -> str:
"""
指定された役割のシステムプロンプトを取得します。
引数:
role (str): システムプロンプトを取得する役割。
戻り値:
str: システムプロンプト。
"""
role_config = self.models.get(role, {})
return role_config.get("instruction", "You are a helpful assistant.")
def _get_model_endpoint(self, role: str) -> str:
"""
指定された役割のモデルエンドポイントを取得します。
引数:
role (str): モデルエンドポイントを取得する役割。
戻り値:
str: モデルエンドポイント。
"""
role_config = self.models.get(role, {})
return role_config.get("endpoint")
set_model(SimpleConversationalAgent())
次にカスタムチャットモデルをMLflowに保管・登録します。
モデル設定として、LLMのエンドポイント名を指定していますが、Llama 3.2 3Bモデルを用いたMosaic AI Model Servingエンドポイントを指定しています。
Llama 3.2のエンドポイント作成については以下の記事がわかりやすいです。
import mlflow
# Databricks Unity Catalogを利用してモデル管理
mlflow.set_registry_uri("databricks-uc")
model_config = {
"models": {
"agent": {
"endpoint": "llama_v3_2_3b_instruct_endpoint",
"instruction": "You are a helpful AI assistant.",
},
},
}
input_example = {
"messages": [
{
"role": "user",
"content": "こんにちは!",
}
]
}
registered_model_name = "training.llm.simple_conversational_agent"
with mlflow.start_run():
model_info = mlflow.pyfunc.log_model(
"model",
python_model="simple_conversational_agent.py",
model_config=model_config,
input_example=input_example,
registered_model_name=registered_model_name,
)
モデルが無事保管できたら、ロードして実際に使ってみます。
import mlflow
from mlflow import MlflowClient
# モデルをロード
client = MlflowClient()
versions = [
mv.version for mv in client.search_model_versions(f"name='{registered_model_name}'")
]
agent = mlflow.pyfunc.load_model(f"models:/{registered_model_name}/{versions[0]}")
# ロードしたモデルを使った簡易チャット関数
def chat(session_id, input):
result = agent.predict(
{
"messages": [
{
"role": "user",
"content": input,
}
],
"temperature": 1.0,
"max_tokens": 1000,
"metadata": {"session_id": session_id},
}
)
print("USER:", input)
print("AI:", result.get("choices")[0].get("message").get("content"))
# 実際に履歴付きでチャットしてみる
session_id = "001"
print(f"--- {session_id=} conversation ---")
chat(session_id, "Hello! How are you?")
chat(session_id, "What was my previous message?")
chat(session_id, "Pls infer my next message.")
print()
# session_idごとに会話履歴が管理されることを確認
session_id = "002"
print(f"--- {session_id=} conversation ---")
chat(session_id, "What was my previous message?")
--- session_id='001' conversation ---
USER: Hello! How are you?
AI: Hello! I'm doing well, thank you for asking! I'm a large language model, so I don't have feelings or emotions like humans do, but I'm always happy to help and assist with any questions or tasks you may have. How about you? How's your day going so far?
USER: What was my previous message?
AI: Your previous message was "Hello! How are you?"
USER: Pls infer my next message.
AI: Based on our conversation so far, I'm going to take a guess that your next message will likely be another question or inquiry related to our conversation, possibly something like "How are you doing?" or "Is there something I can help you with?"
--- session_id='002' conversation ---
USER: What was my previous message?
AI: This is the start of our conversation, so there is no previous message to refer to. What can I help you with today?
日本語に翻訳すると以下の通り。
--- session_id='001' 会話 ---
ユーザー: こんにちは!お元気ですか?
AI: こんにちは!私は元気です。ご質問ありがとうございます!私は大規模な言語モデルなので、人間のような感情や感覚はありませんが、いつでもお手伝いや質問にお答えすることができます。あなたはどうですか?今日はどんな一日を過ごしていますか?
ユーザー: 私の前のメッセージは何でしたか?
AI: あなたの前のメッセージは「こんにちは!お元気ですか?」でした。
ユーザー: 次のメッセージを推測してください。
AI: これまでの会話に基づいて、次のメッセージはおそらく会話に関連する別の質問や問い合わせになると思います。例えば、「あなたはどうですか?」や「何かお手伝いできることはありますか?」のようなものです。
--- session_id='002' 会話 ---
ユーザー: 私の前のメッセージは何でしたか?
AI: これは私たちの会話の始まりなので、参照する前のメッセージはありません。今日は何をお手伝いできますか?
履歴付きでLLMと会話できていますね。
まとめ
GenAI Agentsの最初のチュートリアルをMLflowのカスタムモデルを使うように魔改造して実行してみました。
これによってMLflowによるモデル(エージェント)管理やMosaic AI Model ServingによるDatabricks上でのデプロイを行うことができます。
ちなみに会話履歴をメモリ上に保持するチュートリアルとなっているため、エンドポイントにデプロイする際は注意が必要です。実際にはストレージなどへ永続化するような変更が必要でしょう。
今回は非常に簡単なものですが、自分の勉強のためにもエージェント周りのチュートリアルウォークスルーは継続していこうと思います。記事として続けていくかは・・・やる気次第。