Trying out a SQL Agent with LangChain


LangChain v1 was available (as a pre-release, before the official release), so I tried out the SQL Agent scenario. I followed the tutorial below, with some modifications.

Build an agent with minimal code

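This version does not build a LangGraph graph by hand and keeps the implementation to a minimum. The full schema is embedded in the prompt sent to the LLM.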

0. Prerequisites

Implemented with Python 3.13.2 on Ubuntu 22.04.
The virtual environment is managed with Poetry, using the following packages:

pyproject.toml
jupyterlab = "^4.4.7"
langchain-community = "^0.3.29"
langchainhub = "^0.1.21"
langgraph = "^0.6.7"
python-dotenv = "^1.1.1"
langchain = {extras = ["openai"], version = "^1.0.0a9", allow-prereleases = true}

API keys and other settings are defined in .env.

.env
AZURE_OPENAI_API_KEY=
AZURE_OPENAI_ENDPOINT=https://<resource>.openai.azure.com/
AZURE_OPENAI_DEPLOYMENT_NAME=gpt-4.1-nano
AZURE_OPENAI_API_VERSION=2024-12-01-preview
LANGSMITH_API_KEY=
LANGSMITH_TRACING=true
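The initialization code reads AZURE_OPENAI_DEPLOYMENT_NAME and AZURE_OPENAI_API_VERSION with os.environ[...], so a missing key only shows up as a bare KeyError, and an empty entry such as AZURE_OPENAI_API_KEY= fails later inside the client. A minimal fail-fast sketch, assuming the keys listed in the .env above are the ones you need (REQUIRED_KEYS and missing_keys are hypothetical names, not part of LangChain):

```python
import os
from collections.abc import Mapping

# Hypothetical list: the Azure OpenAI settings from the .env file above.
REQUIRED_KEYS = (
    "AZURE_OPENAI_API_KEY",
    "AZURE_OPENAI_ENDPOINT",
    "AZURE_OPENAI_DEPLOYMENT_NAME",
    "AZURE_OPENAI_API_VERSION",
)


def missing_keys(env: Mapping[str, str], required=REQUIRED_KEYS) -> list[str]:
    """Return the required keys that are absent or blank in `env`."""
    return [key for key in required if not env.get(key, "").strip()]


# Usage (after load_dotenv(...)):
#   missing = missing_keys(os.environ)
#   if missing:
#       raise RuntimeError(f"Missing settings in .env: {', '.join(missing)}")
```

Blank values count as missing on purpose, because load_dotenv sets an entry like AZURE_OPENAI_API_KEY= to an empty string rather than leaving it unset.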

1. Initial setup

A small model is used for testing, and LangSmith tracing is enabled as well.

import os
import getpass

from langchain.chat_models import init_chat_model
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv(), override=True)

llm = init_chat_model(
    "azure_openai:gpt-4.1-nano",
    azure_deployment=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
    api_version=os.environ["AZURE_OPENAI_API_VERSION"],
)

if not os.environ.get("LANGSMITH_API_KEY"):
    os.environ["LANGSMITH_API_KEY"] = getpass.getpass()
    os.environ["LANGSMITH_TRACING"] = "true"

2. Configure the database

The usual SQLite setup: download the Chinook sample database unless it is already present.

import requests, pathlib

url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
local_path = pathlib.Path("Chinook.db")

if local_path.exists():
    print(f"{local_path} already exists, skipping download.")
else:
    response = requests.get(url)
    if response.status_code == 200:
        local_path.write_bytes(response.content)
        print(f"File downloaded and saved as {local_path}")
    else:
        print(f"Failed to download the file. Status code: {response.status_code}")
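The download branch saves whatever body comes back with status 200. One extra check, not in the original, is to confirm that the saved file starts with the 16-byte header every SQLite 3 database begins with (b"SQLite format 3\x00"). A standard-library sketch (looks_like_sqlite is a hypothetical helper):

```python
import pathlib
import sqlite3
import tempfile

SQLITE_HEADER = b"SQLite format 3\x00"  # first 16 bytes of every SQLite 3 database file


def looks_like_sqlite(path: pathlib.Path) -> bool:
    """Return True if `path` is a file that starts with the SQLite 3 header."""
    if not path.is_file():
        return False
    with path.open("rb") as f:
        return f.read(len(SQLITE_HEADER)) == SQLITE_HEADER


# Self-check: a real (throwaway) database versus an HTML page saved by mistake.
with tempfile.TemporaryDirectory() as tmp:
    good = pathlib.Path(tmp, "ok.db")
    conn = sqlite3.connect(good)
    conn.execute("CREATE TABLE t (x INTEGER)")
    conn.commit()
    conn.close()
    bad = pathlib.Path(tmp, "bad.db")
    bad.write_bytes(b"<html>Not Found</html>")
    checks = (looks_like_sqlite(good), looks_like_sqlite(bad))

print(checks)  # (True, False)
```

In the notebook this would be called as looks_like_sqlite(local_path) after the download.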

3. Add tools for database interactions

Retrieve the schema information.

from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
SCHEMA = db.get_table_info()
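For each table, get_table_info() returns the CREATE TABLE statement followed by a commented block of three sample rows, as the printed output shows. The sketch below approximates that format with the standard sqlite3 module to make clear what text ends up in the prompt. It is not LangChain's implementation (SQLDatabase goes through SQLAlchemy), and describe_tables is a hypothetical name:

```python
import sqlite3


def describe_tables(conn: sqlite3.Connection, sample_rows: int = 3) -> str:
    """Approximate get_table_info(): DDL plus a /* ... */ block of sample rows."""
    parts = []
    tables = conn.execute(
        "SELECT name, sql FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()
    for name, ddl in tables:
        # Table names come from sqlite_master, not from user input.
        cur = conn.execute(f'SELECT * FROM "{name}" LIMIT ?', (sample_rows,))
        header = "\t".join(col[0] for col in cur.description)
        rows = ["\t".join(str(v) for v in row) for row in cur.fetchall()]
        sample = "\n".join([f"{sample_rows} rows from {name} table:", header, *rows])
        parts.append(f"{ddl}\n\n/*\n{sample}\n*/")
    return "\n\n\n".join(parts)


conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "Genre" ("GenreId" INTEGER NOT NULL, "Name" NVARCHAR(120), PRIMARY KEY ("GenreId"))')
conn.executemany('INSERT INTO "Genre" VALUES (?, ?)', [(1, "Rock"), (2, "Jazz"), (3, "Metal"), (4, "Alternative & Punk")])
info = describe_tables(conn)
print(info)
```

The whole string is pasted into the system prompt, so the prompt grows with the number of tables and columns. That is fine for Chinook's 11 tables but worth keeping in mind for larger schemas.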

Let's look at its contents.

print(SCHEMA)
CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId	FirstName	LastName	Company	Address	City	State	Country	PostalCode	Phone	Fax	Email	SupportRepId
1	Luís	Gonçalves	Embraer - Empresa Brasileira de Aeronáutica S.A.	Av. Brigadeiro Faria Lima, 2170	São José dos Campos	SP	Brazil	12227-000	+55 (12) 3923-5555	+55 (12) 3923-5566	luisg@embraer.com.br	3
2	Leonie	Köhler	None	Theodor-Heuss-Straße 34	Stuttgart	None	Germany	70174	+49 0711 2842222	None	leonekohler@surfeu.de	5
3	François	Tremblay	None	1498 rue Bélanger	Montréal	QC	Canada	H2G 1A7	+1 (514) 721-4711	None	ftremblay@gmail.com	3
*/


CREATE TABLE "Employee" (
	"EmployeeId" INTEGER NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"FirstName" NVARCHAR(20) NOT NULL, 
	"Title" NVARCHAR(30), 
	"ReportsTo" INTEGER, 
	"BirthDate" DATETIME, 
	"HireDate" DATETIME, 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60), 
	PRIMARY KEY ("EmployeeId"), 
	FOREIGN KEY("ReportsTo") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Employee table:
EmployeeId	LastName	FirstName	Title	ReportsTo	BirthDate	HireDate	Address	City	State	Country	PostalCode	Phone	Fax	Email
1	Adams	Andrew	General Manager	None	1962-02-18 00:00:00	2002-08-14 00:00:00	11120 Jasper Ave NW	Edmonton	AB	Canada	T5K 2N1	+1 (780) 428-9482	+1 (780) 428-3457	andrew@chinookcorp.com
2	Edwards	Nancy	Sales Manager	1	1958-12-08 00:00:00	2002-05-01 00:00:00	825 8 Ave SW	Calgary	AB	Canada	T2P 2T3	+1 (403) 262-3443	+1 (403) 262-3322	nancy@chinookcorp.com
3	Peacock	Jane	Sales Support Agent	2	1973-08-29 00:00:00	2002-04-01 00:00:00	1111 6 Ave SW	Calgary	AB	Canada	T2P 5M5	+1 (403) 262-3443	+1 (403) 262-6712	jane@chinookcorp.com
*/


CREATE TABLE "Genre" (
	"GenreId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("GenreId")
)

/*
3 rows from Genre table:
GenreId	Name
1	Rock
2	Jazz
3	Metal
*/


CREATE TABLE "Invoice" (
	"InvoiceId" INTEGER NOT NULL, 
	"CustomerId" INTEGER NOT NULL, 
	"InvoiceDate" DATETIME NOT NULL, 
	"BillingAddress" NVARCHAR(70), 
	"BillingCity" NVARCHAR(40), 
	"BillingState" NVARCHAR(40), 
	"BillingCountry" NVARCHAR(40), 
	"BillingPostalCode" NVARCHAR(10), 
	"Total" NUMERIC(10, 2) NOT NULL, 
	PRIMARY KEY ("InvoiceId"), 
	FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)

/*
3 rows from Invoice table:
InvoiceId	CustomerId	InvoiceDate	BillingAddress	BillingCity	BillingState	BillingCountry	BillingPostalCode	Total
1	2	2021-01-01 00:00:00	Theodor-Heuss-Straße 34	Stuttgart	None	Germany	70174	1.98
2	4	2021-01-02 00:00:00	Ullevålsveien 14	Oslo	None	Norway	0171	3.96
3	8	2021-01-03 00:00:00	Grétrystraat 63	Brussels	None	Belgium	1000	5.94
*/


CREATE TABLE "InvoiceLine" (
	"InvoiceLineId" INTEGER NOT NULL, 
	"InvoiceId" INTEGER NOT NULL, 
	"TrackId" INTEGER NOT NULL, 
	"UnitPrice" NUMERIC(10, 2) NOT NULL, 
	"Quantity" INTEGER NOT NULL, 
	PRIMARY KEY ("InvoiceLineId"), 
	FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), 
	FOREIGN KEY("InvoiceId") REFERENCES "Invoice" ("InvoiceId")
)

/*
3 rows from InvoiceLine table:
InvoiceLineId	InvoiceId	TrackId	UnitPrice	Quantity
1	1	2	0.99	1
2	1	4	0.99	1
3	2	6	0.99	1
*/


CREATE TABLE "MediaType" (
	"MediaTypeId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("MediaTypeId")
)

/*
3 rows from MediaType table:
MediaTypeId	Name
1	MPEG audio file
2	Protected AAC audio file
3	Protected MPEG-4 video file
*/


CREATE TABLE "Playlist" (
	"PlaylistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("PlaylistId")
)

/*
3 rows from Playlist table:
PlaylistId	Name
1	Music
2	Movies
3	TV Shows
*/


CREATE TABLE "PlaylistTrack" (
	"PlaylistId" INTEGER NOT NULL, 
	"TrackId" INTEGER NOT NULL, 
	PRIMARY KEY ("PlaylistId", "TrackId"), 
	FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), 
	FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)

/*
3 rows from PlaylistTrack table:
PlaylistId	TrackId
1	3402
1	3389
1	3390
*/


CREATE TABLE "Track" (
	"TrackId" INTEGER NOT NULL, 
	"Name" NVARCHAR(200) NOT NULL, 
	"AlbumId" INTEGER, 
	"MediaTypeId" INTEGER NOT NULL, 
	"GenreId" INTEGER, 
	"Composer" NVARCHAR(220), 
	"Milliseconds" INTEGER NOT NULL, 
	"Bytes" INTEGER, 
	"UnitPrice" NUMERIC(10, 2) NOT NULL, 
	PRIMARY KEY ("TrackId"), 
	FOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("MediaTypeId"), 
	FOREIGN KEY("GenreId") REFERENCES "Genre" ("GenreId"), 
	FOREIGN KEY("AlbumId") REFERENCES "Album" ("AlbumId")
)

/*
3 rows from Track table:
TrackId	Name	AlbumId	MediaTypeId	GenreId	Composer	Milliseconds	Bytes	UnitPrice
1	For Those About To Rock (We Salute You)	1	1	1	Angus Young, Malcolm Young, Brian Johnson	343719	11170334	0.99
2	Balls to the Wall	2	2	1	U. Dirkschneider, W. Hoffmann, H. Frank, P. Baltes, S. Kaufmann, G. Hoffmann	342562	5510424	0.99
3	Fast As a Shark	3	2	1	F. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman	230619	3990994	0.99
*/

4. Execute SQL queries

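First, define a function that rejects any SQL that modifies data. This is essential in practice, although the agent would technically run without it.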

import re
from langchain_core.tools import tool
DENY_RE = re.compile(r"\b(INSERT|UPDATE|DELETE|ALTER|DROP|CREATE|REPLACE|TRUNCATE)\b", re.I)
HAS_LIMIT_TAIL_RE = re.compile(r"(?is)\blimit\b\s+\d+(\s*,\s*\d+)?\s*;?\s*$")

def _safe_sql(q: str) -> str:
    # normalize
    q = q.strip()
    # block multiple statements (allow one optional trailing ;)
    body = q[:-1] if q.endswith(";") else q
    if ";" in body:  # also catches "SELECT 1; SELECT 2", which has no trailing ;
        return "Error: multiple statements are not allowed."
    q = body.strip()

    # read-only gate
    if not q.lower().startswith("select"):
        return "Error: only SELECT statements are allowed."
    if DENY_RE.search(q):
        return "Error: DML/DDL detected. Only read-only queries are permitted."

    # append LIMIT only if not already present at the end (robust to whitespace/newlines)
    if not HAS_LIMIT_TAIL_RE.search(q):
        q += " LIMIT 5"
    return q

Define the tool.

@tool
def execute_sql(query: str) -> str:
    """Execute a READ-ONLY SQLite SELECT query and return results."""
    q = _safe_sql(query)
    if q.startswith("Error:"):
        return q
    try:
        return db.run(q)
    except Exception as e:
        return f"Error: {e}"

5. Use create_agent

SYSTEM = f"""You are a careful SQLite analyst.

Authoritative schema (do not invent columns/tables):
{SCHEMA}

Rules:
- Think step-by-step.
- When you need data, call the tool `execute_sql` with ONE SELECT query.
- Read-only only; no INSERT/UPDATE/DELETE/ALTER/DROP/CREATE/REPLACE/TRUNCATE.
- Limit to 5 rows unless user explicitly asks otherwise.
- If the tool returns 'Error:', revise the SQL and try again.
- Limit the number of attempts to 5.
- If you are not successful after 5 attempts, return a note to the user.
- Prefer explicit column lists; avoid SELECT *.
"""

Create the agent.

from langchain.agents import create_agent
from langchain_core.messages import SystemMessage
agent = create_agent(
    model=llm,
    tools=[execute_sql],
    prompt=SystemMessage(content=SYSTEM),
)
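The retry cap in the SYSTEM rules (at most 5 attempts) is only an instruction to the model, and no code enforces it. LangGraph's own safeguard is the recursion_limit value in the run config, which raises GraphRecursionError when the step count is exceeded. The sketch below shows the same idea as an explicit plain-Python loop. The function name is hypothetical, and propose_sql stands in for the LLM:

```python
from typing import Callable, Optional


def run_with_retries(
    propose_sql: Callable[[Optional[str]], str],  # hypothetical LLM stand-in: last error -> SQL
    execute: Callable[[str], str],                # returns rows as text or "Error: ..."
    max_attempts: int = 5,
) -> tuple[str, int]:
    """Propose SQL, run it, feed any error back, and stop after max_attempts."""
    last_error: Optional[str] = None
    for attempt in range(1, max_attempts + 1):
        result = execute(propose_sql(last_error))
        if not result.startswith("Error:"):
            return result, attempt
        last_error = result
    return f"Gave up after {max_attempts} attempts. Last error: {last_error}", max_attempts


# Toy stand-ins: the first proposal is rejected, the revised one succeeds.
def fake_propose(last_error: Optional[str]) -> str:
    return "SELECT Name FROM Genre" if last_error else "UPDATE Genre SET Name = 'x'"


def fake_execute(sql: str) -> str:
    if sql.startswith("SELECT"):
        return "[('Rock',)]"
    return "Error: only SELECT statements are allowed."


print(run_with_retries(fake_propose, fake_execute))  # ("[('Rock',)]", 2)
```

In the agent, the model decides whether to retry. A code-level limit like this (or recursion_limit) is what actually guarantees the loop ends.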

6. Run the agent

question = "Which genre has the longest tracks on average?"

for step in agent.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()
Result
================================ Human Message =================================

Which genre has the longest tracks on average?
================================== Ai Message ==================================
Tool Calls:
  execute_sql (call_HrX9KM9tI7jIPckNpTudbvKU)
 Call ID: call_HrX9KM9tI7jIPckNpTudbvKU
  Args:
    query: SELECT Genre.Name, AVG(Track.Milliseconds) AS AvgMilliseconds FROM Track JOIN Genre ON Track.GenreId = Genre.GenreId GROUP BY Genre.Name ORDER BY AvgMilliseconds DESC LIMIT 1
================================= Tool Message =================================
Name: execute_sql

[('Sci Fi & Fantasy', 2911783.0384615385)]
================================== Ai Message ==================================

The genre with the longest tracks is "Sci Fi & Fantasy".
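The tool returns an average in milliseconds (Track.Milliseconds), and the final answer leaves the value out entirely. As a quick sanity check on the number, a small helper (format_ms is a hypothetical name) converts it into minutes:

```python
def format_ms(ms: float) -> str:
    """Format a duration in milliseconds as M:SS, truncating partial seconds."""
    minutes, seconds = divmod(int(ms // 1000), 60)
    return f"{minutes}:{seconds:02d}"


avg_ms = 2911783.0384615385  # AVG(Track.Milliseconds) returned by the tool above
print(format_ms(avg_ms), round(avg_ms / 60_000, 1))  # 48:31 48.5
```

An average of about 48.5 minutes is far longer than a typical song. Checks like this are a cheap way to spot unit mix-ups in Text2SQL answers.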

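Let's check the run in LangSmith.
The agent first turns the question into SQL (Text2SQL), the tool executes it, and the agent then produces the final answer from the tool result.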

(Screenshot: LangSmith trace of this run)

Build a customized workflow

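The prerequisites are the same as before. This runs in a separate Jupyter notebook, so some of the setup code is repeated.
Unlike the previous version, this one uses LangGraph to spell out each step explicitly.
It is built for a specific use case: for example, it starts with a step that looks up the customer.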

1. Initialize the model and database

Initial setup

import os
import getpass

from langchain.chat_models import init_chat_model
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv(), override=True)

llm = init_chat_model(
    "azure_openai:gpt-4.1-nano",
    azure_deployment=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
    api_version=os.environ["AZURE_OPENAI_API_VERSION"],
)

if not os.environ.get("LANGSMITH_API_KEY"):
    os.environ["LANGSMITH_API_KEY"] = getpass.getpass()
    os.environ["LANGSMITH_TRACING"] = "true"

import pathlib
import requests

# Initialize the database

url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
local_path = pathlib.Path("Chinook.db")

if local_path.exists():
    print(f"{local_path} already exists, skipping download.")
else:
    response = requests.get(url)
    if response.status_code == 200:
        local_path.write_bytes(response.content)
        print(f"File downloaded and saved as {local_path}")
    else:
        print(f"Failed to download the file. Status code: {response.status_code}")

from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
SCHEMA = db.get_table_info()

2. Define the state

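Define the graph state. On top of the message history from MessagesState, it holds the user's first and last name, whether the user was identified as a customer, and the customer ID.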

from typing import Optional

from langgraph.graph import MessagesState

# Graph State
class GraphState(MessagesState):
    first_name: Optional[str]
    last_name: Optional[str]
    customer: bool
    customer_id: Optional[int]

3. Define the tool

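This function has the same role as the one defined earlier (validating the SQL), but it performs more checks: a table allowlist, a ban on CTEs, subqueries, and UNION/INTERSECT/EXCEPT, and a mandatory CustomerId = :customer_id filter for queries that touch Invoice or InvoiceLine.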

import re
# --- Policy configuration ------------------------------------------------------

# Tables a customer is allowed to read
CUSTOMER_ALLOWLIST = {
    "invoice",
    "invoiceline",
    "track",
    "album",
    "artist",
    "genre",
    "mediatype",
    "playlist",
    "playlisttrack",
}

# Tables that are customer-scoped (must include CustomerId = :customer_id)
CUSTOMER_SCOPED = {"invoice", "invoiceline"}

# --- Safety regexes ------------------------------------------------------------
DENY_RE = re.compile(r"\b(INSERT|UPDATE|DELETE|ALTER|DROP|CREATE|REPLACE|TRUNCATE)\b", re.I)
HAS_LIMIT_TAIL_RE = re.compile(r"(?is)\blimit\b\s+\d+(\s*,\s*\d+)?\s*;?\s*$")

# Disallow non-plain-select constructs to keep verification simple
NON_PLAIN_SQL_RE = re.compile(r"\b(with|union|intersect|except)\b|\(\s*select\b", re.I)

# Extract FROM/JOIN tables & aliases (very lightweight parsing)
FROM_RE = re.compile(r"\bfrom\s+([\"`\[]?\w+[\"`\]]?)(?:\s+as\s+(\w+)|\s+(\w+))?", re.I)
JOIN_RE = re.compile(r"\bjoin\s+([\"`\[]?\w+[\"`\]]?)(?:\s+as\s+(\w+)|\s+(\w+))?", re.I)

# Simple checks around CustomerId usage
CUSTID_PLACEHOLDER_EQ_RE = re.compile(r"\b(?:\w+\.)?customerid\s*=\s*:customer_id\b", re.I)
CUSTID_NUMERIC_EQ_RE     = re.compile(r"\b(?:\w+\.)?customerid\s*=\s*\d+\b", re.I)


def _normalize_ident(name: str) -> str:
    # strip quotes/backticks/brackets and lower-case
    return re.sub(r'^[\"`\[]|[\"`\]]$', '', name).lower()


def _extract_tables_and_aliases(q: str):
    tables = set()
    alias_map = {}  # alias -> base table (lower-cased)
    for m in FROM_RE.finditer(q):
        base = _normalize_ident(m.group(1))
        alias = (m.group(2) or m.group(3) or "").lower()
        tables.add(base)
        if alias:
            alias_map[alias] = base
    for m in JOIN_RE.finditer(q):
        base = _normalize_ident(m.group(1))
        alias = (m.group(2) or m.group(3) or "").lower()
        tables.add(base)
        if alias:
            alias_map[alias] = base
    return tables, alias_map


def _safe_sql(q: str, customer_id: int) -> str:
    # normalize
    q = q.strip()
    # block multiple statements (allow one optional trailing ;)
    body = q[:-1] if q.endswith(";") else q
    if ";" in body:  # also catches "SELECT 1; SELECT 2", which has no trailing ;
        return "Error: multiple statements are not allowed."
    q = body.strip()

    # read-only gate
    if not q.lower().startswith("select"):
        return "Error: only SELECT statements are allowed."
    if DENY_RE.search(q):
        return "Error: DML/DDL detected. Only read-only queries are permitted."

    # plain-select only (no CTEs, subqueries, UNION/INTERSECT/EXCEPT)
    if NON_PLAIN_SQL_RE.search(q):
        return "Error: only plain SELECTs (no CTEs/subqueries/UNION/INTERSECT/EXCEPT) are allowed."

    # gather referenced tables & aliases
    tables, alias_map = _extract_tables_and_aliases(q)
    if not tables:
        return "Error: could not determine referenced tables."

    # allowlist enforcement
    disallowed = {t for t in tables if t not in CUSTOMER_ALLOWLIST}
    if disallowed:
        bad = ", ".join(sorted(disallowed))
        return f"Error: access to tables [{bad}] is not permitted."

    # customer-scoped enforcement
    needs_customer_filter = bool(CUSTOMER_SCOPED & tables)
    if needs_customer_filter:
        # forbid numeric literals for CustomerId
        if CUSTID_NUMERIC_EQ_RE.search(q):
            return "Error: use the :customer_id placeholder (no numeric literals) for CustomerId."

        # require a CustomerId = :customer_id predicate in the query text
        if not CUSTID_PLACEHOLDER_EQ_RE.search(q):
            return "Error: queries touching Invoice/InvoiceLine must include CustomerId = :customer_id."

        # Special rule for InvoiceLine: must also reference Invoice (joined)
        if "invoiceline" in tables and "invoice" not in tables:
            return "Error: queries referencing InvoiceLine must also join Invoice and filter by CustomerId = :customer_id."

    # append LIMIT if not present at the end (robust to whitespace/newlines)
    if not HAS_LIMIT_TAIL_RE.search(q):
        q += " LIMIT 5"
    return q
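The table extraction above is a lightweight regex, so it can miss tables. With a comma-separated FROM list such as FROM Invoice, Customer, FROM_RE records only invoice, and the non-allowlisted Customer table slips past the allowlist check. A sturdier option, not used in the original, is to have SQLite enforce the allowlist itself through an authorizer callback (sqlite3.Connection.set_authorizer in the standard library). This is a minimal sketch. Wiring it into SQLDatabase, which goes through SQLAlchemy, is not shown and would need extra work:

```python
import sqlite3

# Actions a read-only analyst query needs; every other action is denied.
ALLOWED_ACTIONS = {sqlite3.SQLITE_SELECT, sqlite3.SQLITE_READ, sqlite3.SQLITE_FUNCTION}


def make_authorizer(allowed_tables: set[str]):
    """Build a callback that only allows reads from tables in `allowed_tables`."""
    allowed = {t.lower() for t in allowed_tables}

    def authorizer(action, arg1, arg2, db_name, trigger_name):
        if action not in ALLOWED_ACTIONS:
            return sqlite3.SQLITE_DENY
        # For SQLITE_READ, arg1 is the table name and arg2 the column name.
        if action == sqlite3.SQLITE_READ and (arg1 or "").lower() not in allowed:
            return sqlite3.SQLITE_DENY
        return sqlite3.SQLITE_OK

    return authorizer


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Invoice (InvoiceId INTEGER, CustomerId INTEGER)")
conn.execute("CREATE TABLE Customer (CustomerId INTEGER, Email TEXT)")
conn.execute("INSERT INTO Invoice VALUES (374, 16)")
conn.execute("INSERT INTO Customer VALUES (16, 'someone@example.com')")
conn.commit()
conn.set_authorizer(make_authorizer({"invoice", "invoiceline"}))


def try_query(sql: str, params=()):
    try:
        return conn.execute(sql, params).fetchall()
    except sqlite3.DatabaseError as exc:  # a denied action surfaces as an error
        return f"Error: {exc}"


allowed = try_query("SELECT InvoiceId FROM Invoice WHERE CustomerId = ?", (16,))
comma_join = try_query("SELECT Invoice.InvoiceId, Customer.Email FROM Invoice, Customer")
write = try_query("DELETE FROM Invoice")
print(allowed)     # [(374,)]
print(comma_join)  # an "Error: ..." string
```

The authorizer sees only tables, columns, and statement types. The CustomerId = :customer_id scoping rule therefore still needs its own check.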

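Now the key part, the tool definition. Compared with the earlier tool, it is extended to take the graph state as an argument and read customer_id from it.
Adding @tool(parse_docstring=True) makes LangChain extract the tool description and each argument's description from a Google-style docstring and reflect them in args_schema (a Pydantic schema).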

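from typing import Annotated
from langchain_core.tools import tool  # needed here as well: this notebook is separate from the first one
from langgraph.prebuilt import InjectedState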

@tool(parse_docstring=True)
def execute_sql(
    query: str,
    state: Annotated[GraphState, InjectedState],  # provides access to customer_id
) -> str:
    """Execute a READ-ONLY SQLite SELECT query (customer-scoped) and return results.

    Args:
        query: a string containing a valid SQL query

    Returns:
        A string with the response to the query or an error
    """
    customer_id = int(state["customer_id"])
    safe_q = _safe_sql(query, customer_id)
    if safe_q.startswith("Error:"):
        return safe_q
    try:
        # Bind the named parameter expected by the query (:customer_id)
        return db.run(safe_q, parameters={"customer_id": customer_id})
    except Exception as e:
        return f"Error: {e}"
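The tool above never lets the model supply the customer ID. The query keeps the :customer_id placeholder, and db.run(..., parameters={"customer_id": customer_id}) binds the real value at execution time. The standard sqlite3 module supports the same named-placeholder style. A minimal sketch with a toy Invoice table:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Invoice (InvoiceId INTEGER, CustomerId INTEGER, Total NUMERIC(10, 2))")
conn.executemany(
    "INSERT INTO Invoice VALUES (?, ?, ?)",
    [(374, 16, 5.94), (352, 16, 3.96), (1, 2, 1.98)],
)

# SQL as the model would write it: the placeholder, never a literal ID.
model_sql = (
    "SELECT InvoiceId, Total FROM Invoice "
    "WHERE CustomerId = :customer_id ORDER BY InvoiceId DESC"
)

# The application, not the model, supplies the value (e.g. from graph state).
rows = conn.execute(model_sql, {"customer_id": 16}).fetchall()
print(rows)  # [(374, 5.94), (352, 3.96)]

# A bound value is compared as data, so injected SQL text matches nothing.
print(conn.execute(model_sql, {"customer_id": "16 OR 1=1"}).fetchall())  # []
```

This is also why the SYSTEM prompt below tells the model to always use the placeholder: the model sees only the SQL shape, while the ID comes from trusted state.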
SYSTEM = """You are a careful SQLite analyst.

Authoritative schema (do not invent columns/tables):
{SCHEMA}

Always use the `:customer_id` placeholder; never hardcode IDs or use names.
The system binds the actual value at execution.

Rules:
- Think step-by-step.
- When you need data, call the tool `execute_sql` with ONE SELECT query.
- Read-only only; no INSERT/UPDATE/DELETE/ALTER/DROP/CREATE/REPLACE/TRUNCATE.
- Limit to 5 rows unless the user explicitly asks otherwise.
- If the tool returns 'Error:', revise the SQL and try again.
- Limit the number of attempts to 5.
- If you are not successful after 5 attempts, return a note to the user.
- Prefer explicit column lists; avoid SELECT *.
"""

4. Add nodes and edges

Add the nodes (customer lookup, unknown-user reply, LLM, and tool execution) and the edges that connect them.

import re

_ID_RE = re.compile(r"\b\d+\b")  # first integer in the run() string

def identify_node(state: GraphState) -> GraphState:
    first = (state.get("first_name") or "").strip()
    last  = (state.get("last_name") or "").strip()

    if not (first and last):
        return {}  # nothing to change

    try:
        # bind the names as query parameters instead of splicing them into the SQL text
        cust_raw = db.run(
            "SELECT CustomerId FROM Customer "
            "WHERE FirstName = :first AND LastName = :last "
            "LIMIT 1",
            parameters={"first": first, "last": last},
        )
        if not cust_raw:
            return {}  # no change

        m = _ID_RE.search(cust_raw)
        if not m:
            # couldn't parse an ID; don't crash—just no update
            return {}

        customer_id = int(m.group(0))
        return {
            "customer": True,
            "customer_id": customer_id,
        }

    except Exception as e:
        print(f"Customer lookup failed: {e}")
        return {}

# conditional edge
def route_from_identify(state: GraphState):
    # Continue to the LLM only if a customer ID was found; otherwise report an unknown user
    if state.get("customer_id"):
        return "llm"
    return "unknown_user"

# Node Return Unknown User Message
def unknown_user_node(state: GraphState):
    return {
        "messages": [
            AIMessage(
                f"The user, first_name:{state.get('first_name','missing')}, "
                f"last_name:{state.get('last_name','missing')} is not in the database"
            )
        ]
    }
# Node LLM ReAct step
from langgraph.prebuilt import ToolNode
from langchain_core.messages import AIMessage, SystemMessage
from langgraph.graph import StateGraph, END

model_with_tools = llm.bind_tools([execute_sql])

def llm_node(state: GraphState) -> GraphState:
    msgs = [SystemMessage(content=SYSTEM.format(SCHEMA=SCHEMA))] + state["messages"]
    ai: AIMessage = model_with_tools.invoke(msgs)
    return { "messages": [ai]}

def route_from_llm(state: GraphState):
    last = state["messages"][-1]
    if isinstance(last, AIMessage) and getattr(last, "tool_calls", None):
        return "tools"
    return END


# Node : Tool execution
tool_node = ToolNode([execute_sql])
# Build Graph
builder = StateGraph(GraphState)

builder.add_node("identify", identify_node)
builder.add_node("unknown_user", unknown_user_node)
builder.add_node("llm", llm_node)
builder.add_node("tools", tool_node)

builder.set_entry_point("identify")
builder.add_conditional_edges("identify", route_from_identify, {"llm": "llm", "unknown_user": "unknown_user"})
builder.add_conditional_edges("llm", route_from_llm, {"tools": "tools", END: END})
builder.add_edge("tools", "llm")
builder.add_edge("unknown_user", END)  # finish the run after reporting an unknown user

graph = builder.compile()
from IPython.display import Image, display

display(Image(graph.get_graph().draw_mermaid_png()))
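To make the control flow explicit, here is a plain-Python imitation of how the compiled graph moves between nodes. It runs a node, merges that node's partial update into the state, then asks the node's router where to go next. This is a hypothetical sketch, not LangGraph's implementation. The scripted model replies and the stubbed tool check are invented so that the walk reproduces the six-step flow seen in the LangSmith trace of the run:

```python
from typing import Callable

END = "__end__"  # stand-in for langgraph.graph.END in this sketch

Node = Callable[[dict], dict]
Router = Callable[[dict], str]


def run_graph(state: dict, nodes: dict[str, Node], routes: dict[str, Router],
              entry: str, max_steps: int = 25) -> list[str]:
    """Run nodes from `entry` until a router returns END; return the visit order."""
    path: list[str] = []
    current = entry
    while current != END:
        if len(path) >= max_steps:  # crude analogue of LangGraph's recursion_limit
            raise RuntimeError("step limit reached")
        path.append(current)
        for key, value in nodes[current](state).items():
            if key == "messages":
                state["messages"] = state.get("messages", []) + value  # append, like add_messages
            else:
                state[key] = value  # plain keys are overwritten
        current = routes[current](state)
    return path


def build_demo() -> tuple[dict[str, Node], dict[str, Router]]:
    """Stub nodes for identify -> llm <-> tools, with scripted model replies."""
    replies = [
        {"tool_call": "SELECT InvoiceId FROM Invoice ORDER BY InvoiceDate DESC LIMIT 3"},
        {"tool_call": "SELECT InvoiceId FROM Invoice WHERE CustomerId = :customer_id LIMIT 3"},
        {"content": "Here are the three most recent invoices."},
    ]

    def tools(state: dict) -> dict:
        # Stub for one _safe_sql rule: Invoice queries must use the :customer_id filter.
        sql = state["messages"][-1]["tool_call"]
        ok = ":customer_id" in sql
        result = "[(374,), (352,), (329,)]" if ok else "Error: missing CustomerId filter"
        return {"messages": [{"tool_result": result}]}

    nodes: dict[str, Node] = {
        "identify": lambda s: {"customer_id": 16} if s.get("last_name") == "Harris" else {},
        "unknown_user": lambda s: {"messages": [{"content": "unknown user"}]},
        "llm": lambda s: {"messages": [replies.pop(0)]},
        "tools": tools,
    }
    routes: dict[str, Router] = {
        "identify": lambda s: "llm" if s.get("customer_id") else "unknown_user",
        "unknown_user": lambda s: END,
        "llm": lambda s: "tools" if "tool_call" in s["messages"][-1] else END,
        "tools": lambda s: "llm",
    }
    return nodes, routes


nodes, routes = build_demo()
path = run_graph({"messages": [], "first_name": "Frank", "last_name": "Harris"}, nodes, routes, "identify")
print(path)  # ['identify', 'llm', 'tools', 'llm', 'tools', 'llm']
```

The llm/tools pair forms the only cycle. The loop ends when the model replies without a tool call or when the step limit is hit.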

The resulting graph looks like this:
(Image: rendered graph diagram)

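5. Run

Run the graph. The Human Message is printed twice because stream_mode="values" prints the state after every step, and the identify step updates only the customer fields, so the last message is still the user's question.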

question = "Show me the three most recent invoices."
for step in graph.stream(
    {
        "messages": [{"role": "user", "content": question}],
        "first_name": "Frank",
        "last_name": "Harris",
    },
    stream_mode="values",
):
    step["messages"][-1].pretty_print()
================================ Human Message =================================

Show me the three most recent invoices.
================================ Human Message =================================

Show me the three most recent invoices.
================================== Ai Message ==================================
Tool Calls:
  execute_sql (call_AWiK5Vck4Gp43dvwdbfsl4Yz)
 Call ID: call_AWiK5Vck4Gp43dvwdbfsl4Yz
  Args:
    query: SELECT InvoiceId, CustomerId, InvoiceDate, Total FROM Invoice ORDER BY InvoiceDate DESC LIMIT 3;
================================= Tool Message =================================
Name: execute_sql

Error: queries touching Invoice/InvoiceLine must include CustomerId = :customer_id.
================================== Ai Message ==================================
Tool Calls:
  execute_sql (call_KTF00DD5UAsVG4ZsKEnAwB9t)
 Call ID: call_KTF00DD5UAsVG4ZsKEnAwB9t
  Args:
    query: SELECT InvoiceId, CustomerId, InvoiceDate, Total FROM Invoice WHERE CustomerId = :customer_id ORDER BY InvoiceDate DESC LIMIT 3;
================================= Tool Message =================================
Name: execute_sql

[(374, 16, '2025-07-04 00:00:00', 5.94), (352, 16, '2025-04-01 00:00:00', 3.96), (329, 16, '2024-12-28 00:00:00', 1.98)]
================================== Ai Message ==================================

The three most recent invoices are:

1. Invoice ID: 374, Customer ID: 16, Invoice date: 2025-07-04, Total: 5.94
2. Invoice ID: 352, Customer ID: 16, Invoice date: 2025-04-01, Total: 3.96
3. Invoice ID: 329, Customer ID: 16, Invoice date: 2024-12-28, Total: 1.98

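Here is the LangSmith log. The processing flow is as follows:

  1. identify: look up the customer
  2. llm: Text2SQL
  3. tools: tried to run the SQL but failed the check, because the query did not meet the rule implemented in _safe_sql that it must filter on CustomerId = :customer_id
  4. llm: Text2SQL again
  5. tools: run the SQL
  6. llm: final answer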

(Screenshots: LangSmith trace of this run)
