こんにちは、dev_catです。
前回は LangMem + LangGraph を使って、GPT-4oと一緒に“記憶するAI”を作ってみたのですが…
今回はさらに、「賢くなるAI」 を作ってみました!
🎯 この記事でやること
- LangMem の
prompt_optimizer
を使って、自己改善するAIを作る - 過去の会話 から、プロンプトを進化させる
- 実際に FastAPI + GPT-4o + LangMem で動かすAPIを公開!
🌱 なぜプロンプト最適化?
通常、AIチャットの「キャラ設定」は system プロンプトで固定されがち。
でも…
✨ユーザーとの会話をもとに、AIが自分をアップデートしてくれたら?
それを実現できるのが prompt_optimizer
。
特に今回は、Gradientモード を使って「改善を積み重ねる」ような動きを実装しました。
🛠 使用ライブラリ
-
langmem
- メモリ&プロンプト最適化 -
openai
- GPT-4o API -
langgraph
- ストアとしてPostgreSQL活用 -
FastAPI
- エンドポイント構築
🧠 コア部分:プロンプト最適化の実装
from langmem import create_prompt_optimizer
optimizer = create_prompt_optimizer(
"openai:gpt-4o",
kind="gradient",
config={
"max_reflection_steps": 8,
"min_reflection_steps": 5
}
)
🤖 過去会話とフィードバックを渡すだけ!
if past_conversation:
trajectories = [
(past_conversation, {
"feedback": "Make your tone casual and playful, like you're talking to a close friend. Avoid sounding too formal."
})
]
better_prompt = await optimizer.invoke({
"trajectories": trajectories,
"prompt": base_prompt
})
else:
better_prompt = base_prompt
✅ 得られた better_prompt を system メッセージとして使う!
messages = [
{"role": "system", "content": better_prompt},
*past_conversation,
{"role": "user", "content": user_input}
]
🔁 一連の処理の流れ
- DBからユーザーの過去会話を取得(LangMem + LangGraph)
- それを元にプロンプト最適化(Gradientモード)
- GPT-4o に渡して返答生成(必要なら LangMemのツール呼び出しも)
- 応答とユーザー発言をまたメモリに保存
🧪 補足:LangMemの Gradient モードとは?
- 反映(Reflection) を最大8ステップまで繰り返して改善
- 過去会話の「フィードバック」をもとに、systemプロンプトを最適化
- 回数やステップ数を調整することで柔軟なプロンプト制御が可能
📌 まとめ
この prompt_optimizer
を組み込むだけで、
AIが自分の“話し方”を過去の会話を通じて学んでくれるようになります!
ちょっとした工夫で、キャラが自然に成長していく体験が作れるのが楽しいです 🧠💬
🔗 もし動かしてみたい時は!
こちらの記事に沿って、環境構築をしていただいて、
https://qiita.com/dev-cat/items/59b01885416fae92c477
以下をrooterに追加し、main.pyでエンドポイント設定してもらえれば動きます!
chat_optimize_gradient.py
from typing import Any, Dict, List
from langgraph.store.memory import InMemoryStore
from openai import OpenAI
from langmem import create_manage_memory_tool, create_search_memory_tool
from pydantic import BaseModel
from fastapi import APIRouter, FastAPI
from langgraph.store.postgres import AsyncPostgresStore
from langmem import create_prompt_optimizer
import asyncio
import os
import json
import uuid
import time
optimizer = create_prompt_optimizer(
"openai:gpt-4o",
kind="gradient", # gradient モードでプロンプト改善
config={
"max_reflection_steps": 8,
"min_reflection_steps": 5
}
)
router = APIRouter()
store = None
store_cm = None
async def init_store():
global store
global store_cm
POSTGRES_URL = os.getenv("POSTGRES_URL")
max_attempts = 5
delay_sec = 2
for attempt in range(max_attempts):
try:
print("Initializing store...")
print(f"Try {attempt + 1}/{max_attempts}")
store_cm = AsyncPostgresStore.from_conn_string(POSTGRES_URL)
store = await store_cm.__aenter__()
await store.setup()
print("Store initialized!")
return
except Exception as e:
print(f"Attempt {attempt + 1} failed: {e}")
if attempt < max_attempts - 1:
print(f"Retrying in {delay_sec} seconds...")
await asyncio.sleep(delay_sec)
else:
print("All attempts failed.")
@router.on_event("startup")
async def startup_event():
await init_store()
class ChatRequest(BaseModel):
user_id: str
user_message: str
@router.post("/chat_optimize_gradient")
async def chat(request: ChatRequest):
global store
if store is None:
return {"error": "Store is not initialized"}
user_id = request.user_id
user_message = request.user_message
namespace = ("memories", user_id)
# ツール生成
memory_tools = [
create_manage_memory_tool(namespace=namespace, store=store),
create_search_memory_tool(namespace=namespace, store=store),
]
# 過去会話を取得
try:
user_memories = await store.asearch(namespace, limit=1000)
print("user_memories取得成功:", user_memories)
except Exception as e:
print("user_memories取得失敗:", e)
return {"error": f"user_memories取得失敗: {str(e)}"}
if user_memories:
print("user_memories[0] の中身:", user_memories[0])
else:
print("user_memories は空")
past_conversation = []
for memory in user_memories:
try:
role = memory.value.get("role")
content = memory.value.get("content")
timestamp = memory.value.get("timestamp", 0)
print(f"memory.value: role={role}, content={repr(content)}, timestamp={timestamp}")
if role in ["user", "assistant"] and content:
past_conversation.append({
"role": role,
"content": content,
"timestamp": timestamp
})
else:
print("スキップ対象のメモリ(不正なroleまたは空のcontent)")
except Exception as e:
print("memory.valueの処理中に例外発生:", e)
# past_conversation = sorted(past_conversation, key=lambda x: x["timestamp"])[-3:]
past_conversation = sorted(past_conversation, key=lambda x: x["timestamp"])
# プロンプト最適化付きチャット応答
result = await run_agent(
tools=memory_tools,
user_input=user_message,
past_conversation=past_conversation,
)
# ユーザー発言と応答を記録
await store.aput(
namespace=("memories", user_id),
key=str(uuid.uuid4()), # もしくは任意の一意なキー
value={
"role": "user",
"content": user_message,
"timestamp": time.time()
}
)
await store.aput(
namespace=("memories", user_id),
key=str(uuid.uuid4()),
value={
"role": "assistant",
"content": result,
"timestamp": time.time()
}
)
# 直近メモリを返す
new_memories = await store.asearch(namespace)
serialized_memories = [m.dict() for m in new_memories]
return {
"message": result,
"memories": serialized_memories
}
async def execute_tool(tools_by_name: Dict[str, Any], tool_call: Dict[str, Any]) -> str:
tool_name = tool_call["function"]["name"]
print(f"Trying to execute tool: {tool_name}")
if tool_name not in tools_by_name:
return f"Error: Tool {tool_name} not found"
tool = tools_by_name[tool_name]
try:
args = tool_call["function"]["arguments"]
print(f"Tool args: {args}")
if asyncio.iscoroutinefunction(tool.invoke):
result = await tool.invoke(args)
else:
result = await asyncio.to_thread(tool.invoke, args)
print(f"Tool result: {result}")
return str(result)
except Exception as e:
print(f"Tool error: {e}")
return f"Error executing {tool_name}: {str(e)}"
async def run_agent(tools: List[Any], user_input: str, past_conversation: List[Dict[str, str]] = [], base_prompt: str = "You're an energetic and friendly AI assistant who speaks casually like a close friend, responds in a helpful and engaging way to any topic, and can throw in jokes or fun expressions to make conversations more lively.", max_steps: int = 5) -> str:
client = OpenAI()
tools_by_name = {tool.name: tool for tool in tools}
openai_tools = [
{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.tool_call_schema.model_json_schema(),
},
}
for tool in tools
]
# プロンプト最適化
try:
if past_conversation:
trajectories = [(past_conversation, {"feedback": "Make your tone casual and playful, like you're talking to a close friend. Avoid sounding too formal."})]
better_prompt = await optimizer.invoke({
"trajectories": trajectories,
"prompt": base_prompt
})
else:
better_prompt = base_prompt
except Exception as e:
print(f"Prompt optimization failed: {e}")
better_prompt = base_prompt
messages = [
{"role": "system", "content": better_prompt},
*past_conversation,
{"role": "user", "content": user_input}
]
for step in range(max_steps):
response = client.chat.completions.create(
model="gpt-4o",
messages=messages,
tools=openai_tools if step < max_steps - 1 else [],
tool_choice="auto",
temperature=0.8,
top_p=0.95
)
message = response.choices[0].message
tool_calls = message.tool_calls
if not tool_calls:
return message.content
messages.append(
{"role": "assistant", "content": message.content, "tool_calls": tool_calls}
)
for tool_call in tool_calls:
tool_result = await execute_tool(tools_by_name, tool_call.model_dump())
messages.append({
"role": "tool",
"tool_call_id": tool_call.id,
"content": tool_result,
})
return "Reached maximum number of steps"
asyncio.create_task(init_store())
🙋♂️ コメント歓迎!
質問・指摘・アイデア大歓迎です!
いいね・ストックいただけたら最高に励みになります 🙌