概要
OpenAI の Chat Completions API に Function Calling と呼ばれる新しい機能が追加されました。これを使うと、JSON Schema で指定した型に従う JSON 文字列を GPT が返してくれます。(ただし 100% valid な JSON が返る保証は無い)
特にライブラリに頼らずに Function Calling を利用しようとすると、
- JSON Schema を書く
- OpenAI API にリクエスト
- 返ってきた JSON 文字列をパースして検証する
というステップが必要になって面倒ですが、Zod のスキーマさえ書けば 1, 3 の手間を省きつつ型安全に実装することができます。
目指す姿
import { z } from "zod";
// 手動でコードを書くのは関数の名前、説明文、Zod スキーマの部分だけ
const functions = {
    add: {
      description: "Add two numbers together",
      parameterSchema: z.object({
        x: z.number(),
        y: z.number(),
      }),
    },
    getWeather: {
      description: "Get the weather for a given location",
      parameterSchema: z.object({
        location: z.string(),
      }),
    },
};
// OpenAI の API にリクエストして、呼び出された関数の名前と引数が以下の型で返る
// {name: "add", x: "number", y: "number} | {name: "getWeather", location: string}
const res = await functionCalling(prompt, functions)
// タグ付きユニオン型なので switch で絞り込める
switch (res.name) {
    case "add":
        console.log(res.x, res.y) // OK
        console.log(res.location) // 型エラー
        break
    case "getWeather":
        console.log(res.location) // OK
        console.log(res.x, res.y) // 型エラー
        break
    default:
        // 全てのケースが網羅されているかチェック
        res satisfies never
}
実装例
import { CreateChatCompletionRequest, OpenAIApi } from "openai"
import { ZodRawShape, z } from "zod"
import { zodToJsonSchema } from "zod-to-json-schema"
// 関数本体
export async function functionCalling<T extends Functions>(
  api: OpenAIApi,
  options: Options<T>
): Promise<Response<T>[]> {
  const functions = Object.entries(options.functions).map(([key, func]) => {
    // Zod スキーマから JSON Schema を自動生成する
    const jsonSchema = zodToJsonSchema(func.parameterSchema, "schema")
      .definitions?.schema
    return {
      name: key,
      description: func.description,
      parameters: jsonSchema,
    }
  })
  const response = await api.createChatCompletion({ ...options, functions })
  return response.data.choices.map((choice) => {
    const message = choice.message
    if (
      options.functions !== undefined &&
      message?.function_call?.name != undefined &&
      message.function_call.arguments != undefined
    ) {
      const { name, arguments: args } = message.function_call
      const func = options.functions[name]
      // JSON 文字列のパースとバリデーションも Zod にお任せ
      const parsedArgs = func.parameterSchema.parse(JSON.parse(args))
      return { type: "function_call", name: name, arguments: parsedArgs }
    } else {
      return { type: "message", content: message?.content ?? "" }
    }
  })
}
// これより下はただの型パズル
// 関数の名前、説明文、スキーマを表す型
// {
//    add: {
//      description: "Add two numbers together",
//      parameterSchema: z.object({ x: z.number(), y: z.number() })
//    },
// };
export type Functions = Record<string, FunctionSchema<any>>
interface Options<T extends Functions>
  extends Omit<CreateChatCompletionRequest, "functions"> {
  functions: T
}
type Response<T extends Functions> =
  | PlainResponse
  | ({ type: "function_call" } & FunctionCall<T>)
interface PlainResponse {
  type: "message"
  content: string
}
// Function Calling の結果の関数名と引数を表す型
// {name: "add", arguments: {x: "number", y: "number}}
type FunctionCall<T extends Functions> = {
  [K in keyof T]: {
    name: K
    arguments: FunctionArguments<T[K]>
  }
}[keyof T]
type FunctionSchema<T extends ZodRawShape> = {
  description: string
  parameterSchema: z.ZodObject<T>
}
type FunctionArguments<T extends FunctionSchema<any>> = z.infer<
  T["parameterSchema"]
>