0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

とりあえずこれ最近面白いやつby日本ChatGPT研究所(さくらりん)Advent Calendar 2024

Day 20

【連載】pydantic-ai徹底解説 (3) 構造化レスポンスとバリデーション(GoogleColab付)

Posted at

前回はツール呼び出しや依存性注入を学びました。今回は pydantic-ai の特徴的機能である「構造化レスポンス」について詳しく見ていきます。

構造化レスポンスとは

通常のLLM応答は生テキストですが、pydantic-aiではPydanticモデルで結果のスキーマを定義できます。
Agentにresult_typeを指定すると、LLMは最終結果を定義したスキーマに合うようJSONで回答するよう促されます。

これにより、

  • データが期待する構造を満たさない場合、Pydanticバリデーションによりエラーとなり、LLMに再試行させることが可能です(reflection & retry)
  • 開発者は純粋なPythonオブジェクトとして結果を扱えます

インストール方法

pydantic-aiとその例を実行するには、以下の手順でインストールを行います:

基本的なインストール


# pipを使用する場合
!pip install 'pydantic-ai[examples]' loguru

import os
from google.colab import userdata

os.environ['GEMINI_API_KEY'] = userdata.get('GEMINI_API_KEY')
os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')

import nest_asyncio
nest_asyncio.apply()

簡易例

以下は非常にシンプルな例です。

from pydantic_ai import Agent
from pydantic import BaseModel
import pprint
class CityLocation(BaseModel):
    city: str
    country: str

agent = Agent('gemini-1.5-flash', result_type=CityLocation)
result = agent.run_sync('Where were the Olympics held in 2012?')
# print(result.data)
pprint.pprint(result)
# 期待出力: CityLocation(city='London', country='United Kingdom')

LLMはCityLocationというPydanticモデルに適合するよう出力します。
もし間違った形式で返せば、pydantic-aiはLLMに再度試行を促します(リトライ機構)。

リトライ機構

ツール呼び出しや結果バリデーション時にエラーがあった場合、ModelRetry例外を発生させることでLLMに再チャレンジを要求できます。
これにより、LLMが初回で正しくフォーマットできなくても、再回答のチャンスが与えられ、最終的に有効な構造化データを得やすくなります。

ストリーミング応答

pydantic-aiはストリーミングにも対応しています。
LLMが段階的にJSONデータを生成する場合、部分的なJSONを段階的にパース・バリデーションして、進行中の状態をリアルタイムで表示・利用することができます。

例えば、あるTypedDictの構造体をresult_typeに指定してrun_stream()で実行すると、stream()メソッドで部分的なJSONが流れてくるごとにバリデーションしつつ処理可能です。

例:ストリーミングで構造化データ取得

from typing_extensions import TypedDict
from pydantic_ai import Agent

class UserProfile(TypedDict, total=False):
    name: str
    dob: str
    bio: str

agent = Agent('openai:gpt-4o', result_type=UserProfile)

async with agent.run_stream('Describe the user John with partial streaming.') as res:
    async for partial_data in res.stream():
        print("Partial validated data:", partial_data)

このように、LLMがデータを徐々に出力している場合でも、整合性チェックが行われ、常に型安全な状態で処理できます。

高度な例1:複雑なネスト構造

以下は、より複雑なスキーマを扱う例です。
複数のユーザープロフィールを格納するUserDirectoryモデルを定義し、UserProfileUserAddressといった階層的なデータ構造をバリデーション付きで受け取ります。
出力はloguruでログを残し、プロンプトは日本語で与えています。

from pydantic_ai import Agent
from pydantic import BaseModel, Field
from typing import List
from loguru import logger
import pprint
import sys


class UserAddress(BaseModel):
    country: str
    city: str
    postal_code: str

class UserProfile(BaseModel):
    name: str = Field(..., description="ユーザー名")
    dob: str = Field(..., description="生年月日。例:1990-01-01")
    bio: str = Field(..., description="ユーザー略歴")
    addresses: List[UserAddress] = Field(..., description="住所一覧")

class UserDirectory(BaseModel):
    users: List[UserProfile] = Field(..., description="ユーザープロフィール集")

agent = Agent('gemini-1.5-flash', result_type=UserDirectory)

prompt = """
以下の条件で、有名な架空のキャラクターを3人挙げ、それぞれのプロフィールをJSONで出力してください:
- 各ユーザーには 'name', 'dob', 'bio', 'addresses' が必要
- 'addresses' は必須で、 'country', 'city', 'postal_code' を含める
- JSON全体を { "users": [ ... ] } の形で出力
- 不正な形式ならリトライさせられます
DOBは架空の日付でOKです。
"""

logger.info("🤖 LLMへプロンプトを送信開始...")
result = agent.run_sync(prompt)
logger.success("✅ LLMからレスポンスを受信完了")

# 各ユーザープロフィールの詳細をログ出力
logger.info("📊 生成されたプロフィール詳細:")
for i, user in enumerate(result.data.users, 1):  # この行を修正
    logger.info("=" * 50)
    logger.info(f"👤 ユーザー {i}: {user.name}")
    logger.info(f"🎂 生年月日: {user.dob}")
    logger.info(f"📝 略歴: {user.bio}")
    logger.info("📍 住所:")
    for j, addr in enumerate(user.addresses, 1):
        logger.info(f"  住所 {j}:")
        logger.info(f"    国: {addr.country}")
        logger.info(f"    市: {addr.city}")
        logger.info(f"    郵便番号: {addr.postal_code}")

# デバッグ情報の出力
logger.debug("🔍 生データ:")
logger.debug(pprint.pformat(result.data.dict(), indent=2))  # この行も修正

# コスト情報の出力を追加
logger.info("💰 API使用コスト:")
logger.info(f"  リクエストトークン数: {result._cost.request_tokens}")
logger.info(f"  レスポンストークン数: {result._cost.response_tokens}")
logger.info(f"  合計トークン数: {result._cost.total_tokens}")
pprint.pprint(result)

この例では、3人の架空キャラクター(例:「シャーロック・ホームズ」「ドラクエの主人公」「宮本武蔵」など)のプロフィールをJSONで取得できます。バリデーションに失敗すると内部で再試行が行われ、最終的にUserDirectoryに適合する結果が得られます。

高度な例2:EnumとUnion型を含む複雑スキーマ

次に、enumUnion型を使い、より厳密な制約と柔軟なデータ構造を要求する例を示します。
ここでは架空の「商品カタログ」を扱います。

  • Producttypeフィールドをもち、"physical""digital"でenum制約を行います。
  • physical型ではweightdimensionsが必須、digital型ではfile_sizeが必須など、Unionを使って異なるスキーマを要求します。
  • このスキーマに合ったJSONをLLMに要求し、間違ったらリトライさせます。
from pydantic_ai import Agent
from pydantic import BaseModel, Field, ValidationError
from typing import List, Union
from enum import Enum
from loguru import logger
import sys
import pprint


class ProductType(str, Enum):
    PHYSICAL = "physical"
    DIGITAL = "digital"

class PhysicalProduct(BaseModel):
    type: ProductType = Field(ProductType.PHYSICAL)
    name: str
    weight: float = Field(..., description="重量(kg)")
    dimensions: str = Field(..., description="サイズ表記。例:10x20x30cm")

class DigitalProduct(BaseModel):
    type: ProductType = Field(ProductType.DIGITAL)
    name: str
    file_size: float = Field(..., description="ファイルサイズ(MB)")
    download_url: str = Field(..., description="ダウンロードURL")

Product = Union[PhysicalProduct, DigitalProduct]

class ProductCatalog(BaseModel):
    products: List[Product]

agent = Agent('openai:gpt-4o', result_type=ProductCatalog)

prompt = """
猫用品専門店「猫猫カンパニー」の商品カタログをJSONで出力してください:
- 全体は { "products": [ ... ] } という構造
- 各productは以下のいずれかのタイプを含むこと:

1. 物理製品 (type="physical")
   - キャットタワー、おもちゃ、食器などの実物商品
   - 必須フィールド: "name", "weight", "dimensions"

2. デジタル製品 (type="digital")
   - 猫のしつけ動画講座、電子書籍など
   - 必須フィールド: "name", "file_size", "download_url"

合計5つ以上の商品を含めてください。
不正な形式の場合は自動的に再試行します。
"""

logger.info("🐱 猫猫カンパニーの商品カタログ生成を開始します...")
result = agent.run_sync(prompt)
logger.success("✨ カタログデータの受信が完了しました")

# 商品カテゴリごとの集計
physical_products = [p for p in result.data.products if p.type == ProductType.PHYSICAL]
digital_products = [p for p in result.data.products if p.type == ProductType.DIGITAL]

# カタログ概要の出力
logger.info("📊 カタログ概要")
logger.info("=" * 50)
logger.info(f"総商品数: {len(result.data.products)}")
logger.info(f"物理商品: {len(physical_products)}")
logger.info(f"デジタル商品: {len(digital_products)}")
logger.info("=" * 50)

# 物理商品の詳細出力
logger.info("\n🛍️ 物理商品一覧:")
for i, product in enumerate(physical_products, 1):
    logger.info("-" * 40)
    logger.info(f"商品{i}: {product.name}")
    logger.info(f"  重量: {product.weight}kg")
    logger.info(f"  サイズ: {product.dimensions}")

# デジタル商品の詳細出力
logger.info("\n💻 デジタル商品一覧:")
for i, product in enumerate(digital_products, 1):
    logger.info("-" * 40)
    logger.info(f"商品{i}: {product.name}")
    logger.info(f"  サイズ: {product.file_size}MB")
    logger.info(f"  URL: {product.download_url}")

# API使用状況の出力
logger.info("\n📈 API使用状況")
logger.info("=" * 50)
logger.info(f"リクエストトークン: {result._cost.request_tokens}")
logger.info(f"レスポンストークン: {result._cost.response_tokens}")
logger.info(f"合計トークン: {result._cost.total_tokens}")

# デバッグ情報
logger.debug("\n🔍 生データ:")
logger.debug(pprint.pformat(result.data.dict(), indent=2))

この例では、LLMはProductCatalogというスキーマに適合するJSONを生成する必要があります。
physical製品やdigital製品が混在し、フィールド要件が異なるため、LLMが不正な出力をした場合はpydantic-aiがエラーを投げて再試行が行われます。
最終的にはproducts: List[Union[PhysicalProduct, DigitalProduct]]という型に合ったPythonオブジェクトが取得できます。

高度な例3:ツール呼び出しと構造化データの組み合わせ

最後に、前回学んだツール呼び出し機能と構造化レスポンスを組み合わせた例を示します。

猫猫カンパニーの店舗天気情報取得システム

このスクリプトは各店舗の天気情報を取得し、お客様への情報提供や
店舗運営の参考にするためのものです。

  • 天気情報の取得と構造化
  • loguruを使用した詳細なログ記録
  • 店舗ごとの状況可視化
"""
猫猫カンパニーの店舗天気情報取得システム

このスクリプトは各店舗の天気情報を取得し、お客様への情報提供や
店舗運営の参考にするためのものです。
- 天気情報の取得と構造化
- loguruを使用した詳細なログ記録
- 店舗ごとの状況可視化
"""

from pydantic_ai import Agent, RunContext
from pydantic import BaseModel, Field
from loguru import logger
import sys
from datetime import datetime
from dataclasses import dataclass
from typing import Any

# ロガーの設定をカスタマイズ
logger.remove()
logger.add(
    sys.stdout,
    format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{message}</cyan>",
    colorize=True,
    level="INFO"
)
logger.add("cat_shop_weather.log", rotation="500 MB", level="DEBUG")

class WeatherInfo(BaseModel):
    city: str = Field(..., description="店舗所在地")
    temperature: float = Field(..., description="気温(℃)")
    condition: str = Field(..., description="天気状態")

@dataclass
class WeatherDeps:
    dummy: bool = True  # 実際のAPIを使用しない場合はTrue

# エージェントの設定
weather_agent = Agent(
    'gemini-1.5-flash',
    system_prompt='猫猫カンパニーの店舗天気情報アシスタントとして応答します。簡潔に回答してください。',
    deps_type=WeatherDeps,
    result_type=WeatherInfo
)

# ツールの定義
@weather_agent.tool
async def get_weather(ctx: RunContext[WeatherDeps], city: str) -> dict[str, Any]:
    """指定された都市の天気情報を取得する

    Args:
        ctx: コンテキスト
        city: 店舗所在地
    """
    # ダミーデータを返す
    weather_data = {
        "東京": {"temperature": 22.5, "condition": "晴れ"},
        "大阪": {"temperature": 24.0, "condition": "曇り"},
        "名古屋": {"temperature": 23.0, "condition": "晴れ時々曇り"},
        "福岡": {"temperature": 25.0, "condition": "晴れ"},
        "札幌": {"temperature": 18.0, "condition": "小雨"}
    }

    if city in weather_data:
        return {
            "city": city,
            "temperature": weather_data[city]["temperature"],
            "condition": weather_data[city]["condition"]
        }
    return {
        "city": city,
        "temperature": 20.0,
        "condition": "不明"
    }

async def main():
    # 処理開始のログ
    logger.info("🏪 猫猫カンパニー店舗天気情報システム")
    logger.info("=" * 50)
    logger.info(f"📅 取得日時: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    logger.info("=" * 50)

    deps = WeatherDeps()
    prompt = """
    猫猫カンパニー東京本店(東京)の現在の天気情報を取得してください。
    この情報は店舗の環境管理(空調設定など)に使用されます。
    """

    try:
        logger.info("🌤️ 天気情報の取得を開始します...")
        result = await weather_agent.run(prompt, deps=deps)
        logger.success("✨ 天気情報の取得が完了しました")

        # 取得した天気情報の表示
        logger.info("\n📍 店舗天気情報")
        logger.info("=" * 50)
        logger.info(f"店舗所在地: {result.data.city}")
        logger.info(f"現在の気温: {result.data.temperature}")
        logger.info(f"天気状態 : {result.data.condition}")

        # 気温に基づくアドバイス
        logger.info("\n💡 運営アドバイス")
        logger.info("-" * 40)
        if result.data.temperature > 25:
            logger.warning("🌡️ 気温が高めです。空調の設定を確認してください")
        elif result.data.temperature < 18:
            logger.warning("🌡️ 気温が低めです。暖房の設定を確認してください")
        else:
            logger.info("🌡️ 快適な温度範囲です")

        if result.data.condition in ["", "小雨"]:
            logger.info("☔ 雨天のため、入口付近の床の清掃頻度を上げることをお勧めします")

        # API使用状況
        logger.info("\n📊 API使用状況")
        logger.info("-" * 40)
        logger.info(f"リクエストトークン: {result._cost.request_tokens}")
        logger.info(f"レスポンストークン: {result._cost.response_tokens}")
        logger.info(f"合計トークン: {result._cost.total_tokens}")

    except Exception as e:
        logger.error(f"❌ エラーが発生しました: {str(e)}")
        raise

    finally:
        logger.info("\n🏁 処理を終了します")
        logger.info("=" * 50)

if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

上記例では、LLMは「東京の天気」を取得するためにget_weatherツールを呼び出し、その戻り値をWeatherInfoとして返すよう促されます。
もし不正な形式のJSONを返せば内部的にリトライが行われます。

まとめ

今回は構造化レスポンスとバリデーション、リトライ機構、ストリーミング対応、そして複雑なスキーマやenum・Union型、ツール呼び出しとの組み合わせ例を示しました。
pydantic-aiは、LLMとのやりとりを「ただの文字列交換」から「型安全でエラーハンドリングしやすいデータ交換」へと発展させます。

次回はユニットテストや評価(Evals)の方法、TestModelFunctionModelによるモック化、Logfireとの統合など、開発・運用フェーズで役立つ機能を紹介していきます。

📒ノートブック

参考サイト

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?