背景
OpenAIが最近いろいろリリースされたけど、新しいEmbeddingsモデルについてはあまり知られていないようなので、ちょっとドキュメントに沿って、モデルを使って・クラスタリング・可視化までやってみました。
そもそもEmbeddingsとは
OpenAIの説明により、テキストを数字のベクトルに変換して、さまざまな統計や可視化に使うことができます。
実際に使ってみる
参考はこちらのドキュメントです。
Embeddingsモデルを使う
get_embedding
の関数を定義
import pandas as pd
from openai import OpenAI
from typing import List, Optional
client = OpenAI()
def get_embedding(text: str, model="text-embedding-3-small", **kwargs) -> List[float]:
text = text.replace("\n", " ")
response = client.embeddings.create(input=[text], model=model, **kwargs)
return response.data[0].embedding
今回はネットで金融ニュースを使います。
i_list = []
for i in text.split("\n"):
i = i[11:].replace(" ","")
i_list.append(i[:(i.find('['))])
print(len(i_list))
i_dict = {"text": i_list}
df = pd.DataFrame(i_dict)
print(df)
内容の一部はこのようになります。
次はベクトルに変換します。
embedding_model = "text-embedding-3-large"
df["embedding"] = df["text"].apply(lambda x: get_embedding(x, model=embedding_model))
print(df)
クラスタリング・可視化
import numpy as np
matrix = np.vstack(df.embedding.values)
from sklearn.cluster import KMeans
n_clusters = 3
kmeans = KMeans(n_clusters=n_clusters, init="k-means++", random_state=42)
kmeans.fit(matrix)
labels = kmeans.labels_
df["Cluster"] = labels
from sklearn.manifold import TSNE
import matplotlib
import matplotlib.pyplot as plt
tsne = TSNE(n_components=2, perplexity=15, random_state=42, init="random", learning_rate=100)
vis_dims2 = tsne.fit_transform(matrix)
x = [x for x, y in vis_dims2]
y = [y for x, y in vis_dims2]
for category, color in enumerate(["purple", "green", "red"]):
xs = np.array(x)[df.Cluster == category]
ys = np.array(y)[df.Cluster == category]
plt.scatter(xs, ys, color=color, alpha=0.3)
avg_x = xs.mean()
avg_y = ys.mean()
plt.scatter(avg_x, avg_y, marker="x", color=color, s=100)
plt.title("Clusters identified visualized in language 2d using t-SNE")
結果画像はこちらになります。
少し中身を見てみます。
rev_per_cluster = 10
for j in range(n_clusters):
print(j)
sample_cluster_rows = df[df.Cluster == j]
for k in range(rev_per_cluster):
print(sample_cluster_rows.text.values[k])
かなりわかりやすくて、
0は為替関連で、
1は主に日本とアジア関連で、
2は主に日本とアメリカ関連という結果ですね。
ソースコード
最後に、GitHubにアップロードしたノートブックも載せます。