LLMの制約事項
- 最新の情報にアクセスできない(モデルを生成した際の学習データに依存する)
- 外部のシステムにアクセスできない(ウェブアクセス)
- 計算が苦手
今回は3つめの計算が苦手な点にフォーカスして調査した結果を共有する。
計算が苦手
- OpenAI社のGPT3.5~4で掛け算をさせた結果を以下に示す
- xとyが掛け合わせる数字を表しており、結果が正しい場合は青、間違っている場合は赤の点でプロットしている
gpt-3.5-turbo-1106
- 正解率:70.1%
gpt-4o-mini
- 正解率:74.8%
gpt-4
- 正解率: 77.5%
gpt-4-turbo
- 正解率: 88.2%
gpt-4o
- 正解率:88.5%
まとめ
モデル | 正解率 |
---|---|
gpt-3.5-turbo-1106 | 70.1% |
gpt-4o-mini | 74.8% |
gpt-4 | 77.5% |
gpt-4-turbo | 88.2% |
gpt-4o | 88.5% |
- 2桁以下の掛け算はどのモデルも結果は正確だが、3桁同士からは不正確になる
-
gpt-4o
は計算の過程があり、CoTが用いられているため、優秀だが4桁同士は正解率が悪くなる
テスト手法
- ランダムなxとyを生成する
- システムプロンプトに
x × y =
を入れてOpenAIのAPIに投げる - OpenAIのAPIの文字列からすべての数字を列挙し、正解に最も近い数字を抽出する
- 抽出した数字が正しいかどうかを判定する
- 末尾のソースコード参照
計算ミスを克服する方法
ChatGPTを含む計算が必要とされるLLMアプリでは以下の方法を用いている
- LLMに計算をさせるためのPythonをコードを吐かせる
- Pythonを実行し、結果を取り出す
ChatGPTの例
- 分析と書いているのはPythonのコードを生成して実行していることを意味する
脆弱性
このようなPythonコードを生成して実行するシステムは、プロンプトインジェクションにより、任意のPythonコードが実行でき、RCE(任意のコード実行)の脆弱性に繋がる可能性が高い。
ChatGPTの例
- なお、ChatGPTはPythonの実行環境をサンドボックス化しており、任意のコードを実行しても影響はない
対策
ChatGPT同様にPython実行環境をサンドボックス化したり、実行可能な関数を制限することが必要となる。
付録:テスト用コード
- 利用する場合は
OPENAI_API_KEY = "sk-proj-xxxxxx"
を変更すること
import re
import csv
import random
import requests
import math
import pandas as pd
import matplotlib.pyplot as plt
OPENAI_API_KEY = "sk-proj-xxxxxx"
models=["gpt-3.5-turbo-1106", "gpt-4", "gpt-4o-mini", "gpt-4-turbo", "gpt-4o"]
def query_openai(messages, model):
url="https://api.openai.com/v1/chat/completions"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {OPENAI_API_KEY}"
}
data = {
"model": model,
"messages": messages,
"temperature": 1,
"n": 1,
}
try:
response = requests.post("https://api.openai.com/v1/chat/completions", json=data, headers=headers)
response.raise_for_status() # Raises an exception for 4XX or 5XX errors
return {"success": True, "result": response.json()} # Return the response data as JSON
except requests.exceptions.HTTPError as http_err:
return {"success": False, "result": f"HTTP error occurred: {http_err}"}
except requests.exceptions.ConnectionError as conn_err:
return {"success": False, "result": f"Connection error occurred: {conn_err}"}
except requests.exceptions.Timeout as timeout_err:
return {"success": False, "result": f"Timeout error occurred: {timeout_err}"}
except requests.exceptions.RequestException as req_err:
return {"success": False, "result": f"An error occurred: {req_err}"}
def check_multi(x, y, answer):
correct_result = x * y
extracted_numbers = re.findall(r'\d+', answer.replace(',', ''))
if extracted_numbers:
extracted_results = [int(num) for num in extracted_numbers]
closest_result = min(extracted_results, key=lambda num: abs(correct_result - num))
correct_flag = closest_result == correct_result
difference = abs(correct_result - closest_result)
else:
correct_flag = False
difference = correct_result
return correct_flag, closest_result, difference
def generate_log_scale_random(min_value, max_value):
log_min = math.log(min_value)
log_max = math.log(max_value)
return int(math.exp(random.uniform(log_min, log_max)))
def create_system_prompt(x, y):
return f"{x} × {y} = "
def plot_log_scatter_v2(csv_filename, output_filename):
df = pd.read_csv(csv_filename)
df = df[['x', 'y', 'correct']]
plt.figure(figsize=(10, 6))
for _, row in df.iterrows():
color = 'blue' if row['correct'] else 'red'
plt.scatter(row['x'], row['y'], color=color)
plt.xscale('log')
plt.yscale('log')
plt.xlabel('x (log scale)')
plt.ylabel('y (log scale)')
plt.title(f'Multiplication Test (Blue: Correct, Red: Incorrect)')
plt.grid(True, which="both", ls="--")
plt.savefig(output_filename)
plt.close()
if __name__ == "__main__":
for model in models:
with open(f'multiplication_{model}.csv', mode='w', newline='') as file:
columns = ['x', 'y', 'x times y', 'gpt_answer', 'correct', 'diff']
writer = csv.DictWriter(file, fieldnames=columns)
writer.writeheader()
for i in range(1000):
x = generate_log_scale_random(1, 9999)
y = generate_log_scale_random(1, 9999)
messages = [{
"role": "system",
"content": create_system_prompt(x, y)
}]
openai_api_result = query_openai(messages, model)
if not openai_api_result["success"]:
print(f"[!] Error {openai_api_result['result']}")
exit()
else:
answer = openai_api_result["result"]["choices"][0]["message"]["content"]
correct, gpt_answer, diff = check_multi(x, y, answer)
print(f"{model}({i}/1000) {x} x {y} = {answer} -> GPT Answer: {gpt_answer}(diff:{diff})")
row = {
'x': x,
'y': y,
'x times y': x * y,
'gpt_answer': gpt_answer,
'correct': correct,
'diff': diff
}
writer.writerow(row)
plot_log_scatter_v2(f'multiplication_{model}.csv', f"multiplication_{model}.png")