
Building a "Self-Improving AI" with LangMem's prompt_optimizer


Hi, this is dev_cat.
Last time, I used LangMem + LangGraph to build a "remembering AI" together with GPT-4o...

This time, I went one step further and built an AI that gets smarter!

🎯 What this article covers

  • Build a self-improving AI with LangMem's prompt_optimizer
  • Evolve the system prompt based on past conversations
  • Actually run it as an API built with FastAPI + GPT-4o + LangMem!

🌱 Why prompt optimization?

Normally, an AI chat's "character" is fixed once in the system prompt.

But...

✨ What if the AI could update itself based on its conversations with the user?

That's exactly what prompt_optimizer enables.
In this article I use its Gradient mode, which stacks up improvements step by step.

🛠 Libraries used

  • langmem - memory & prompt optimization
  • openai - GPT-4o API
  • langgraph - PostgreSQL-backed store
  • FastAPI - endpoint construction
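As a rough guide, the stack can be installed like this (the exact package name for the Postgres store backend is my assumption; check the LangGraph docs for the right extra in your version):

```shell
pip install langmem openai fastapi uvicorn
# Postgres-backed store for langgraph (package name may vary by version)
pip install langgraph-checkpoint-postgres "psycopg[binary]"
```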

🧠 The core: implementing prompt optimization

from langmem import create_prompt_optimizer

optimizer = create_prompt_optimizer(
    "openai:gpt-4o",
    kind="gradient",
    config={
        "max_reflection_steps": 8,
        "min_reflection_steps": 5
    }
)

🤖 Just pass in past conversations and feedback!

if past_conversation:
    trajectories = [
        (past_conversation, {
            "feedback": "Make your tone casual and playful, like you're talking to a close friend. Avoid sounding too formal."
        })
    ]
    # The optimizer is a Runnable, so use ainvoke from async code
    better_prompt = await optimizer.ainvoke({
        "trajectories": trajectories,
        "prompt": base_prompt
    })
else:
    better_prompt = base_prompt

✅ Use the resulting better_prompt as the system message!

messages = [
    {"role": "system", "content": better_prompt},
    # Strip extra keys (e.g. timestamp) before sending to the chat API
    *[{"role": m["role"], "content": m["content"]} for m in past_conversation],
    {"role": "user", "content": user_input}
]

🔁 The overall flow

  1. Fetch the user's past conversation from the DB (LangMem + LangGraph)
  2. Optimize the prompt based on it (Gradient mode)
  3. Pass it to GPT-4o to generate a reply (calling LangMem tools if needed)
  4. Save the reply and the user's message back to memory
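The four steps above boil down to one request cycle per chat turn. Here is a minimal sketch of that cycle with stub functions standing in for the real store and model calls (the function names are illustrative, not LangMem APIs):

```python
# Illustrative request cycle for one chat turn. fetch/optimize/generate are
# stand-ins for the real LangMem / OpenAI calls used in the article.
def fetch_past_conversation(db, user_id):
    # Step 1: load this user's stored messages, oldest first
    return sorted(db.get(user_id, []), key=lambda m: m["timestamp"])

def optimize_prompt(base_prompt, past_conversation):
    # Step 2: with no history there is nothing to learn from yet
    if not past_conversation:
        return base_prompt
    return base_prompt + " (tuned from feedback)"

def generate_reply(prompt, past_conversation, user_message):
    # Step 3: the real code calls GPT-4o here
    return f"[{prompt}] reply to: {user_message}"

def handle_turn(db, user_id, user_message, base_prompt="You are friendly."):
    history = fetch_past_conversation(db, user_id)
    prompt = optimize_prompt(base_prompt, history)
    reply = generate_reply(prompt, history, user_message)
    # Step 4: persist both sides of the exchange for the next turn
    turn = len(history)
    db.setdefault(user_id, []).extend([
        {"role": "user", "content": user_message, "timestamp": turn},
        {"role": "assistant", "content": reply, "timestamp": turn + 1},
    ])
    return reply
```

The key point is that the optimized prompt only exists for the duration of a turn; what persists between turns is the conversation history it was derived from.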

🧪 Aside: what is LangMem's Gradient mode?

  • Repeats reflection steps (up to 8 in this setup) to refine the prompt
  • Optimizes the system prompt based on the feedback attached to past conversations
  • Tuning the min/max reflection step counts gives you flexible control over the optimization
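Concretely, each trajectory handed to the optimizer is a (conversation, metadata) pair: a list of role/content messages plus a dict carrying the feedback string. A minimal sketch of building one (pure data construction, no API call; the message contents are made up for illustration):

```python
# Build the (conversation, metadata) trajectory shape passed to the optimizer.
past_conversation = [
    {"role": "user", "content": "Tell me about LangMem."},
    {"role": "assistant", "content": "LangMem is a memory library for LLM agents."},
]

feedback = {
    "feedback": "Make your tone casual and playful, like you're "
                "talking to a close friend."
}

# One trajectory = one (messages, metadata) tuple; the optimizer takes a list of them.
trajectories = [(past_conversation, feedback)]
```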

📌 Summary

Just by plugging in this prompt_optimizer,
the AI learns its own "way of speaking" from past conversations!

It's fun how a small tweak lets the character grow naturally over time 🧠💬

🔗 If you want to try it yourself!

Set up the environment by following this article:
https://qiita.com/dev-cat/items/59b01885416fae92c477

Then add the file below as a router and wire up the endpoint in main.py, and it should work!
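For reference, the main.py wiring could look like this (a minimal sketch assuming chat_optimize_gradient.py sits next to it; adjust the import path to your project layout):

```python
# main.py — minimal sketch wiring the router into a FastAPI app
from fastapi import FastAPI

import chat_optimize_gradient  # the router module shown below

app = FastAPI()
app.include_router(chat_optimize_gradient.router)

# Run with: uvicorn main:app --reload
```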

chat_optimize_gradient.py
from typing import Any, Dict, List
from openai import OpenAI
from langmem import (
    create_manage_memory_tool,
    create_search_memory_tool,
    create_prompt_optimizer,
)
from pydantic import BaseModel
from fastapi import APIRouter
from langgraph.store.postgres import AsyncPostgresStore
import asyncio
import os
import json
import uuid
import time

optimizer = create_prompt_optimizer(
    "openai:gpt-4o",
    kind="gradient",  # gradient mode: iterative prompt improvement
    config={
        "max_reflection_steps": 8,
        "min_reflection_steps": 5
    }
)
router = APIRouter()
store = None
store_cm = None

async def init_store():
    global store
    global store_cm

    POSTGRES_URL = os.getenv("POSTGRES_URL")
    max_attempts = 5
    delay_sec = 2

    for attempt in range(max_attempts):
        try:
            print("Initializing store...")
            print(f"Try {attempt + 1}/{max_attempts}")

            store_cm = AsyncPostgresStore.from_conn_string(POSTGRES_URL)
            store = await store_cm.__aenter__()
            await store.setup()

            print("Store initialized!")
            return
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {e}")
            if attempt < max_attempts - 1:
                print(f"Retrying in {delay_sec} seconds...")
                await asyncio.sleep(delay_sec)
            else:
                print("All attempts failed.")

@router.on_event("startup")
async def startup_event():
    await init_store()

class ChatRequest(BaseModel):
    user_id: str
    user_message: str

@router.post("/chat_optimize_gradient")
async def chat(request: ChatRequest):
  global store

  if store is None:
      return {"error": "Store is not initialized"}

  user_id = request.user_id
  user_message = request.user_message

  namespace = ("memories", user_id)

  # Create the memory tools
  memory_tools = [
      create_manage_memory_tool(namespace=namespace, store=store),
      create_search_memory_tool(namespace=namespace, store=store),
  ]

  # Fetch the past conversation
  try:
      user_memories = await store.asearch(namespace, limit=1000)
      print("Fetched user_memories:", user_memories)
  except Exception as e:
      print("Failed to fetch user_memories:", e)
      return {"error": f"Failed to fetch user_memories: {str(e)}"}

  if user_memories:
      print("Contents of user_memories[0]:", user_memories[0])
  else:
      print("user_memories is empty")

  past_conversation = []
  for memory in user_memories:
      try:
          role = memory.value.get("role")
          content = memory.value.get("content")
          timestamp = memory.value.get("timestamp", 0)

          print(f"memory.value: role={role}, content={repr(content)}, timestamp={timestamp}")

          if role in ["user", "assistant"] and content:
              past_conversation.append({
                  "role": role,
                  "content": content,
                  "timestamp": timestamp
              })
          else:
              print("Skipping memory (invalid role or empty content)")

      except Exception as e:
          print("Exception while processing memory.value:", e)

  # past_conversation = sorted(past_conversation, key=lambda x: x["timestamp"])[-3:]
  past_conversation = sorted(past_conversation, key=lambda x: x["timestamp"])

  # Chat response with prompt optimization
  result = await run_agent(
      tools=memory_tools,
      user_input=user_message,
      past_conversation=past_conversation,
  )

  # Record the user message and the reply
  await store.aput(
      namespace=("memories", user_id),
      key=str(uuid.uuid4()),  # or any unique key
      value={
          "role": "user",
          "content": user_message,
          "timestamp": time.time()
      }
  )

  await store.aput(
      namespace=("memories", user_id),
      key=str(uuid.uuid4()),
      value={
          "role": "assistant",
          "content": result,
          "timestamp": time.time()
      }
  )

  # Return the latest memories
  new_memories = await store.asearch(namespace)
  serialized_memories = [m.dict() for m in new_memories]

  return {
      "message": result,
      "memories": serialized_memories
  }


async def execute_tool(tools_by_name: Dict[str, Any], tool_call: Dict[str, Any]) -> str:
    tool_name = tool_call["function"]["name"]

    print(f"Trying to execute tool: {tool_name}")

    if tool_name not in tools_by_name:
        return f"Error: Tool {tool_name} not found"

    tool = tools_by_name[tool_name]
    try:
        # OpenAI returns tool arguments as a JSON string; decode before invoking
        args = json.loads(tool_call["function"]["arguments"])
        print(f"Tool args: {args}")

        if asyncio.iscoroutinefunction(tool.invoke):
            result = await tool.invoke(args)
        else:
            result = await asyncio.to_thread(tool.invoke, args)

        print(f"Tool result: {result}")
        return str(result)
    except Exception as e:
        print(f"Tool error: {e}")
        return f"Error executing {tool_name}: {str(e)}"

async def run_agent(
    tools: List[Any],
    user_input: str,
    past_conversation: List[Dict[str, Any]] | None = None,
    base_prompt: str = (
        "You're an energetic and friendly AI assistant who speaks casually like a "
        "close friend, responds in a helpful and engaging way to any topic, and can "
        "throw in jokes or fun expressions to make conversations more lively."
    ),
    max_steps: int = 5,
) -> str:
    # Avoid a mutable default argument
    past_conversation = past_conversation or []
    client = OpenAI()
    tools_by_name = {tool.name: tool for tool in tools}

    openai_tools = [
        {
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.tool_call_schema.model_json_schema(),
            },
        }
        for tool in tools
    ]

    # Prompt optimization
    try:
        if past_conversation:
            trajectories = [(
                past_conversation,
                {"feedback": "Make your tone casual and playful, like you're talking to a close friend. Avoid sounding too formal."},
            )]
            # The optimizer is a Runnable; use ainvoke from async code
            better_prompt = await optimizer.ainvoke({
                "trajectories": trajectories,
                "prompt": base_prompt
            })
        else:
            better_prompt = base_prompt
    except Exception as e:
        print(f"Prompt optimization failed: {e}")
        better_prompt = base_prompt


    messages = [
        {"role": "system", "content": better_prompt},
        # Strip extra keys (e.g. timestamp) that the chat API does not accept
        *[{"role": m["role"], "content": m["content"]} for m in past_conversation],
        {"role": "user", "content": user_input}
    ]

    for step in range(max_steps):
        # Only offer tools before the final step; the API rejects an empty tools list
        tool_kwargs = (
            {"tools": openai_tools, "tool_choice": "auto"}
            if step < max_steps - 1 and openai_tools else {}
        )
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=messages,
            temperature=0.8,
            top_p=0.95,
            **tool_kwargs,
        )
        message = response.choices[0].message
        tool_calls = message.tool_calls

        if not tool_calls:
            return message.content

        messages.append({
            "role": "assistant",
            "content": message.content,
            # Serialize tool calls back into plain dicts for the next request
            "tool_calls": [tc.model_dump() for tc in tool_calls],
        })

        for tool_call in tool_calls:
            tool_result = await execute_tool(tools_by_name, tool_call.model_dump())
            messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": tool_result,
            })

    return "Reached maximum number of steps"


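Once the app is up, the endpoint can be exercised with a quick curl (host and port assume the default uvicorn settings; the payload fields match the ChatRequest model above):

```shell
curl -X POST http://localhost:8000/chat_optimize_gradient \
  -H "Content-Type: application/json" \
  -d '{"user_id": "user-1", "user_message": "Hi! Remember that I like cats."}'
```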

🙋‍♂️ Comments welcome!

Questions, corrections, and ideas are all very welcome!
Likes and stocks would be hugely encouraging 🙌
