説明
このスクリプトは、OpenAIのAPIキーを環境変数からロードし、OpenAI APIに設定します。次に、ユーザーメッセージと一連の関数(この場合はSQLクエリを生成する関数)を引数としてチャットコンプリーションを作成します。
チャットの結果から、メッセージ内の関数呼び出しを確認します。その関数名が"sql_query"である場合、その引数を抽出します。すべての必要な引数が存在することを確認します。いずれかの引数が欠けていたり、関数呼び出しが存在しなかった場合、エラーが発生します。
必要な引数がすべて存在し、関数呼び出しが正しい場合、関数の引数(SQLコマンドとその種類)を取得し、それを出力します。
使用ライブラリ
openai: OpenAIのGPTモデルを使用するためのライブラリ。
os: 環境変数を読み込むためのライブラリ。
dotenv: .envファイルから環境変数を読み込むためのライブラリ。
json: JSONを解析するためのライブラリ。
import openai
import os
from dotenv import load_dotenv
import json
def load_api_key():
"""
環境変数からOpen AIのAPIキーを読み込み、Open AIのAPIに設定します。
"""
load_dotenv()
openai.api_key = os.getenv("OPEN_AI_API")
def create_chat(user_msg, functions):
"""
チャット機能のリストを作成します。
:param function_list: 関数のリスト
:type function_list: list
:return: 関数のリスト
:rtype: list
"""
return openai.ChatCompletion.create(
model="gpt-3.5-turbo-0613",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": user_msg},
],
functions=functions,
function_call="auto",
)
def check_function_call(message):
"""
メッセージから関数呼び出しをチェックします。
:param message: メッセージ
:type message: str
:return: 関数呼び出しの結果が含まれる場合はSQL、含まれない場合はエラー
:rtype: dict | Exception
"""
if message.get("function_call"):
function_name = message["function_call"]["name"]
if function_name == "sql_query":
stringified_args = message["function_call"].get("arguments")
if stringified_args: # Check if stringified_args isn't None
function_args = json.loads(stringified_args) # Convert the stringified JSON to a Python dictionary
# Get the arguments from the function_args dictionary
command = function_args.get("command")
kind = function_args.get("kind")
# If any of the necessary arguments is None we raise an exception
if None in [command, kind]:
raise Exception("Function arguments are missing")
return function_args
else:
raise Exception("Function arguments are missing")
else:
raise Exception("Function call is missing")
def sql_query():
"""
SQLのDICTを作成します。
:return: SQLのDICT
:rtype: dict
"""
return {
"name": "sql_query",
"description": "Create a SQL query",
"parameters": {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The full command of SQL",
},
"kind": {
"type": "string",
"description": "The name of the user",
"enum": ["SELECT", "INSERT", "UPDATE", "DELETE"]
},
},
"required": ["command", "kind"],
}
}
def main():
"""
APIキーの読み込み、チャット機能の作成、チャットコンプリーションの作成、関数呼び出しのチェックという一連の処理を実行します。
functions機能を理解しやすくするためにこのコードではユーザ情報の辞書を取得するのが目的となっています。
"""
load_api_key()
user_msg = "Could you generate a SQL query to retrieve data from the user table? The query should return a maximum of 10 records and only include the youngest users."
function_list = [sql_query()] # 初期の関数リスト
completion = create_chat(user_msg, function_list)
message = completion.choices[0]["message"]
# SQLの辞書を取得する
print(check_function_call(message))
# SQLコマンド
print(check_function_call(message)["command"])
# SQLの種類
print(check_function_call(message)["kind"])
if __name__ == "__main__":
main()