0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

LLM経由でなぜPythonを実行するのか(結論:計算が苦手)

Last updated at Posted at 2024-08-03

LLMの制約事項

  1. 最新の情報にアクセスできない(モデルを生成した際の学習データに依存する)
  2. 外部のシステムにアクセスできない(ウェブアクセス)
  3. 計算が苦手

今回は3つめの計算が苦手な点にフォーカスして調査した結果を共有する。

計算が苦手

  • OpenAI社のGPT3.5~4で掛け算をさせた結果を以下に示す
  • xとyが掛け合わせる数字を表しており、結果が正しい場合は青、間違っている場合は赤の点でプロットしている

gpt-3.5-turbo-1106

image.png

  • 正解率:70.1%

gpt-4o-mini

image.png

  • 正解率:74.8%

gpt-4

image.png

  • 正解率: 77.5%

gpt-4-turbo

image.png

  • 正解率: 88.2%

gpt-4o

image.png

  • 正解率:88.5%

まとめ

モデル 正解率
gpt-3.5-turbo-1106 70.1%
gpt-4o-mini 74.8%
gpt-4 77.5%
gpt-4-turbo 88.2%
gpt-4o 88.5%
  • 2桁以下の掛け算はどのモデルも結果は正確だが、3桁同士からは不正確になる
  • gpt-4oは計算の過程があり、CoTが用いられているため、優秀だが4桁同士は正解率が悪くなる

テスト手法

  1. ランダムなxとyを生成する
  2. システムプロンプトにx × y = を入れてOpenAIのAPIに投げる
  3. OpenAIのAPIの文字列からすべての数字を列挙し、正解に最も近い数字を抽出する
  4. 抽出した数字が正しいかどうかを判定する
  • 末尾のソースコード参照

計算ミスを克服する方法

ChatGPTを含む計算が必要とされるLLMアプリでは以下の方法を用いている

  1. LLMに計算をさせるためのPythonをコードを吐かせる
  2. Pythonを実行し、結果を取り出す

ChatGPTの例

image.png

  • 分析と書いているのはPythonのコードを生成して実行していることを意味する

脆弱性

このようなPythonコードを生成して実行するシステムは、プロンプトインジェクションにより、任意のPythonコードが実行でき、RCE(任意のコード実行)の脆弱性に繋がる可能性が高い。

ChatGPTの例

image.png

  • なお、ChatGPTはPythonの実行環境をサンドボックス化しており、任意のコードを実行しても影響はない

対策

ChatGPT同様にPython実行環境をサンドボックス化したり、実行可能な関数を制限することが必要となる。

付録:テスト用コード

  • 利用する場合はOPENAI_API_KEY = "sk-proj-xxxxxx"を変更すること
import re
import csv
import random
import requests
import math
import pandas as pd
import matplotlib.pyplot as plt


OPENAI_API_KEY = "sk-proj-xxxxxx"
models=["gpt-3.5-turbo-1106", "gpt-4", "gpt-4o-mini", "gpt-4-turbo", "gpt-4o"]

def query_openai(messages, model):
  url="https://api.openai.com/v1/chat/completions"
  headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {OPENAI_API_KEY}"
  }
  data = {
    "model": model,
    "messages": messages,
    "temperature": 1,
    "n": 1,
  }
  try:
    response = requests.post("https://api.openai.com/v1/chat/completions", json=data, headers=headers)
    response.raise_for_status()  # Raises an exception for 4XX or 5XX errors
    return {"success": True, "result": response.json()}  # Return the response data as JSON
  except requests.exceptions.HTTPError as http_err:
    return {"success": False, "result": f"HTTP error occurred: {http_err}"}
  except requests.exceptions.ConnectionError as conn_err:
    return {"success": False, "result": f"Connection error occurred: {conn_err}"}
  except requests.exceptions.Timeout as timeout_err:
    return {"success": False, "result": f"Timeout error occurred: {timeout_err}"}
  except requests.exceptions.RequestException as req_err:
    return {"success": False, "result": f"An error occurred: {req_err}"}


def check_multi(x, y, answer):
  correct_result = x * y
  extracted_numbers = re.findall(r'\d+', answer.replace(',', ''))
  if extracted_numbers:
    extracted_results = [int(num) for num in extracted_numbers]
    closest_result = min(extracted_results, key=lambda num: abs(correct_result - num))
    correct_flag = closest_result == correct_result
    difference = abs(correct_result - closest_result)
  else:
    correct_flag = False
    difference = correct_result
  return correct_flag, closest_result, difference


def generate_log_scale_random(min_value, max_value):
    log_min = math.log(min_value)
    log_max = math.log(max_value)
    return int(math.exp(random.uniform(log_min, log_max)))

def create_system_prompt(x, y):
    return f"{x} × {y} = "


def plot_log_scatter_v2(csv_filename, output_filename):
    df = pd.read_csv(csv_filename)
    df = df[['x', 'y', 'correct']]
    plt.figure(figsize=(10, 6))
    for _, row in df.iterrows():
        color = 'blue' if row['correct'] else 'red'
        plt.scatter(row['x'], row['y'], color=color)
    
    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel('x (log scale)')
    plt.ylabel('y (log scale)')
    plt.title(f'Multiplication Test (Blue: Correct, Red: Incorrect)')
    plt.grid(True, which="both", ls="--")
    plt.savefig(output_filename)
    plt.close()



if __name__ == "__main__":
    for model in models:
      with open(f'multiplication_{model}.csv', mode='w', newline='') as file:
        columns = ['x', 'y', 'x times y', 'gpt_answer', 'correct', 'diff']
        writer = csv.DictWriter(file, fieldnames=columns)
        writer.writeheader()
        for i in range(1000):
          x = generate_log_scale_random(1, 9999)
          y = generate_log_scale_random(1, 9999)
          messages = [{
              "role": "system",
              "content": create_system_prompt(x, y)
          }]
          openai_api_result = query_openai(messages, model)
          if not openai_api_result["success"]:
              print(f"[!] Error {openai_api_result['result']}")
              exit()
          else:
              answer = openai_api_result["result"]["choices"][0]["message"]["content"]
              correct, gpt_answer, diff = check_multi(x, y, answer)
              print(f"{model}({i}/1000) {x} x {y} = {answer} -> GPT Answer: {gpt_answer}(diff:{diff})")
              row = {
                  'x': x,
                  'y': y,
                  'x times y': x * y,
                  'gpt_answer': gpt_answer,
                  'correct': correct,
                  'diff': diff
              }
              writer.writerow(row)
      plot_log_scatter_v2(f'multiplication_{model}.csv', f"multiplication_{model}.png")
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?