1
2

OpenAIのEmbeddingsを使ってみた

Posted at

背景

OpenAIが最近いろいろリリースされたけど、新しいEmbeddingsモデルについてはあまり知られていないようなので、ちょっとドキュメントに沿って、モデルを使って・クラスタリング・可視化までやってみました。

そもそもEmbeddingsとは

OpenAIの説明により、テキストを数字のベクトルに変換して、さまざまな統計や可視化に使うことができます。

Screenshot from 2024-02-23 19-24-18.png

実際に使ってみる

参考はこちらのドキュメントです。

Embeddingsモデルを使う

get_embedding の関数を定義

import pandas as pd
from openai import OpenAI
from typing import List, Optional

client = OpenAI()

def get_embedding(text: str, model="text-embedding-3-small", **kwargs) -> List[float]:
  text = text.replace("\n", " ")
  response = client.embeddings.create(input=[text], model=model, **kwargs)
  return response.data[0].embedding

今回はネットで金融ニュースを使います。

i_list = []
for i in text.split("\n"):
  i = i[11:].replace(" ","")
  i_list.append(i[:(i.find('['))])
print(len(i_list))
i_dict = {"text": i_list}
df = pd.DataFrame(i_dict)
print(df)

内容の一部はこのようになります。

Screenshot from 2024-02-23 19-30-19.png

次はベクトルに変換します。

embedding_model = "text-embedding-3-large"
df["embedding"] = df["text"].apply(lambda x: get_embedding(x, model=embedding_model))
print(df)

Screenshot from 2024-02-23 19-32-22.png

クラスタリング・可視化

import numpy as np

matrix = np.vstack(df.embedding.values)

from sklearn.cluster import KMeans

n_clusters = 3

kmeans = KMeans(n_clusters=n_clusters, init="k-means++", random_state=42)
kmeans.fit(matrix)
labels = kmeans.labels_
df["Cluster"] = labels

from sklearn.manifold import TSNE
import matplotlib
import matplotlib.pyplot as plt

tsne = TSNE(n_components=2, perplexity=15, random_state=42, init="random", learning_rate=100)
vis_dims2 = tsne.fit_transform(matrix)

x = [x for x, y in vis_dims2]
y = [y for x, y in vis_dims2]

for category, color in enumerate(["purple", "green", "red"]):
    xs = np.array(x)[df.Cluster == category]
    ys = np.array(y)[df.Cluster == category]
    plt.scatter(xs, ys, color=color, alpha=0.3)

    avg_x = xs.mean()
    avg_y = ys.mean()

    plt.scatter(avg_x, avg_y, marker="x", color=color, s=100)
plt.title("Clusters identified visualized in language 2d using t-SNE")

結果画像はこちらになります。

download.png

少し中身を見てみます。

rev_per_cluster = 10
for j in range(n_clusters):
  print(j)
  sample_cluster_rows = df[df.Cluster == j]
  for k in range(rev_per_cluster):
    print(sample_cluster_rows.text.values[k])

かなりわかりやすくて、
0は為替関連で、
1は主に日本とアジア関連で、
2は主に日本とアメリカ関連という結果ですね。

Screenshot from 2024-02-23 19-37-08.png

ソースコード

最後に、GitHubにアップロードしたノートブックも載せます。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2