LoginSignup
11
4

More than 1 year has passed since last update.

OpenAIのEmbeddings APIをコードに対して使用した時の調査

Last updated at Posted at 2023-03-13

OpenAIのEmbeddings APIをコードに対して使用した時に、COS類似度がどのように変化するのか調査しました。

Embeddings APIは「検索」「クラスタリング」「レコメンデーション」「異常検出」「多様性測定」「分類」を行うために、テキストから潜在情報である1536次元のベクトルデータへ変換するAPIです。二つのテキストの内容が近しい場合、変換後のベクトルデータの類似度も近しくなります。また変換後のベクトルデータを加減算することで、二つ以上のテキストを組み合わせた検索に応用することもできます。

GPTを外部データソースと組み合わせて使う際に便利なライブラリであるLlamaIndexEmbeddings APIによって得られたベクトルを用いて検索用のインデックスを構成しており、個人的には注目度が高いです。

今回はこのEmbeddings APIを使ってPythonコードをベクトルに変換し、COS類似度を計算するとどのようになるか簡単に調査しました。

調査用プログラム

調査用のPythonプログラムはこちらです。
テキストファイルを読みだしてベクトルに変換後、いろいろな組み合わせでCOS類似度を計算しています。

import os
from pathlib import Path

import openai
import numpy as np


def embedding(file_path: str) -> list[float]:
    text = Path(file_path).read_text(encoding="utf-8")
    result = openai.Embedding.create(input=text, model="text-embedding-ada-002")
    if isinstance(result, dict):
        embedding = result["data"][0]["embedding"]
        return embedding
    return []


def cos_sim(a, b) -> float:
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


def main():
    openai.api_key = os.getenv("OPENAI_API_KEY", "")

    a = np.array(embedding("data/a.py"))
    b = np.array(embedding("data/b.py"))
    c = np.array(embedding("data/c.py"))
    d = np.array(embedding("data/d.py"))

    ab = np.array(embedding("data/ab.py"))
    ba = np.array(embedding("data/ba.py"))
    cd = np.array(embedding("data/cd.py"))
    dc = np.array(embedding("data/dc.py"))

    print("a", "a", cos_sim(a, a))
    print("a", "b", cos_sim(a, b))
    print("a", "c", cos_sim(a, c))
    print("a", "d", cos_sim(a, d))
    print()

    print("a", "ab", cos_sim(a, ab))
    print("a", "ba", cos_sim(a, ba))
    print("a", "cd", cos_sim(a, cd))
    print("a", "dc", cos_sim(a, dc))
    print()

    print("(a + b) / 2", "ab", cos_sim((a + b) / 2, ab))
    print("(a + b) / 2", "ba", cos_sim((a + b) / 2, ba))
    print("(a + b) / 2", "cd", cos_sim((a + b) / 2, cd))
    print("(a + b) / 2", "dc", cos_sim((a + b) / 2, dc))
    print()

    print("ab - (b / 2)", "a", cos_sim(ab - (b / 2), a))
    print("ab - (b / 2)", "b", cos_sim(ab - (b / 2), b))
    print("ab - (b / 2)", "c", cos_sim(ab - (b / 2), c))
    print("ab - (b / 2)", "d", cos_sim(ab - (b / 2), d))
    print()


if __name__ == "__main__":
    main()

変換対象のプログラム

以下、ベクトルに変換する対象のプログラム群です。これらのコードはOpenAIのGPT Playgroundにて生成しました。
関数名を見れば何のアルゴリズムのコードか分かるかと思いますので内容の解説は省略します。

a.py
def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quick_sort(left) + middle + quick_sort(right)
b.py
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(n - i - 1):
            if arr[j] > arr[j + 1]:
                arr[j], arr[j + 1] = arr[j + 1], arr[j]
    return arr
c.py
def binary_search(list, item):
    low = 0
    high = len(list) - 1

    while low <= high:
        mid = (low + high) // 2
        guess = list[mid]
        if guess == item:
            return mid
        if guess > item:
            high = mid - 1
        else:
            low = mid + 1
    return None
d.py
def linear_search(list, target):
    for i in range(len(list)):
        if list[i] == target:
            return i
    return -1

ここから先のコードは二つの関数を組み合わせたコードです。部分的な検索を行うときに類似度が上がるか否か検証するために作成しました。

ab.py
def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quick_sort(left) + middle + quick_sort(right)


def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(n - i - 1):
            if arr[j] > arr[j + 1]:
                arr[j], arr[j + 1] = arr[j + 1], arr[j]
    return arr
ba.py
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(n - i - 1):
            if arr[j] > arr[j + 1]:
                arr[j], arr[j + 1] = arr[j + 1], arr[j]
    return arr


def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quick_sort(left) + middle + quick_sort(right)
cd.py
def binary_search(list, item):
    low = 0
    high = len(list) - 1

    while low <= high:
        mid = (low + high) // 2
        guess = list[mid]
        if guess == item:
            return mid
        if guess > item:
            high = mid - 1
        else:
            low = mid + 1
    return None


def linear_search(list, target):
    for i in range(len(list)):
        if list[i] == target:
            return i
    return -1
dc.py
def linear_search(list, target):
    for i in range(len(list)):
        if list[i] == target:
            return i
    return -1


def binary_search(list, item):
    low = 0
    high = len(list) - 1

    while low <= high:
        mid = (low + high) // 2
        guess = list[mid]
        if guess == item:
            return mid
        if guess > item:
            high = mid - 1
        else:
            low = mid + 1
    return None

実行結果

調査用プログラムの実行結果は以下の通りです。

a a 0.9999999999999999
a b 0.8893355008017317
a c 0.7410371898892079
a d 0.6967513453149546

a ab 0.9723891833395283
a ba 0.952633415197992
a cd 0.7286053077164385
a dc 0.7153595162242117

(a + b) / 2 ab 0.9818783009545884
(a + b) / 2 ba 0.9879698306320613
(a + b) / 2 cd 0.7372693406962301
(a + b) / 2 dc 0.7277244913020819

ab - (b / 2) a 0.9421566329777432
ab - (b / 2) b 0.7788764126830712
ab - (b / 2) c 0.6641259303997035
ab - (b / 2) d 0.6146168015547571

考察

当初想定した通りのCOS類似度が返ってきました。想定ポイントは以下の通りです。

  • 同じ「ソート」アルゴリズムのコードの場合は類似度が高い (a b 0.8893004177507372)
  • 部分的に同じコードが含まれている場合は類似度が高い (a ab 0.9723046370699587)
    • なおかつ、コード内の順序は関係ない
  • 二つのベクトルの平均を取った場合、同じ組み合わせのコードと類似度が高い ((a + b) / 2 ab 0.9818783009545884)
  • ベクトルの差を取った場合、残ったコードと類似度が高い (ab - (b / 2) a 0.9421566329777432)

同じ「ソート」アルゴリズムのコードの場合は類似度が高い

クイックソートとバブルソートのEmbeddingを比較すると相対的に見て類似度が高くなるという結果になりました。この結果は以下のワークフローに応用することを想定しています。

  • コードの概要を書く
  • 概要を基にGPTでコードを生成する
  • 生成したコードをベクトルに変換する
  • ベクトルを基にリポジトリ内の類似コードを検索する

なお、この結果は想定していたものではあるものの精度に関しては少々不安でした。そのため、実利用の際は様々なコードで試してみる必要があると考えています。

部分的に同じコードが含まれている場合は類似度が高い

a 関数を含むコードを、a 関数と b 関数両方含むコードに対して類似度を計算したところ高くなりました。これは部分的なコード検索に使用することを想定しています。但し a 関数が微妙に異なるパターンや、検索対象のコードがかなり大きな場合の対応については調査、考察が必要です。

二つのベクトルの平均を取った場合、同じ組み合わせのコードと類似度が高い

ベクトルの平均演算が適切に機能するか確認したところ、想定通り類似度が高くなりました。これにより複数のコードの断片であっても部分検索に応用できそうだということが分かりました。

ベクトルの差を取った場合、残ったコードと類似度が高い

念のためベクトルの減算が上手くいくかどうか試したところ、想定通り類似度が高くなりました。
ab ベクトルから b ベクトルを引くと a ベクトルのみが残るはずだ、という想定の元の実験をしたところ、a.py の類似度が最も高くなりました。なおEmbeddings APIはベクトルの大きさが1になるよう正規化するため、減算する時はbを2で割ることでスケールを合わせています。

まとめ

ということで、今のところは全て想定通りの結果になりました。
Embeddings APIをソースコードに対して使うのは初めてで、どのような特性なのか、類似度はどの程度の精度なのか気になったため調査を行いました。結果としては、ごく小さなコードであれば自然文と同じような感覚で使えそうだということが分かりました。

11
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
4