概要
テキストデータの埋め込みをしたら、似ていない記事同士でもcos類似度が結構高い...という事象に出くわすことがあると思います。
今回はそれを解消する手法の紹介とpythonでの実装をしてみました。
問題
RAG等で、埋め込み(テキストをベクトルに変換すること)を使うことがあると思います。ベクトルになるので、テキスト同士がどのくらい似ているかを定量的に計算することができ、便利だと思います。
しかし、埋め込みしたテキスト同士の類似性で疑問を持ったことがありませんか?
「似ていないテキストなのに、cos類似度が結構高い。。。」
実は、異方性と呼ばれる問題です。
この方の説明がわかりやすいのですが、BERTなどのLLMの埋め込みベクトルは同じ方向に埋め込みしがち、という問題を抱えているようです。
結果として、cos類似度などでテキスト同士の類似度を計測した時に、数値が結構大きく出てしまう、ということが起きます。
▼異方性のイメージ:グレーの矢印方向に埋め込みされやすいため、例えば、"魚"と"車"というあまり似ていない(はずの)単語があった場合、下図の⇔方向にembedding(ベクトル)が作られやすくなります。そうすると、赤矢印で記載しているように、cos類似度だと結構大きな値(例えば0.7)が出てきてしまうのです(cos類似度は赤矢印の間の角度が小さいほど似ていると判断する指標のため)。
これを解消する方法として、いくつかの改善方法が提示されています。
筆者が調べた限りの範囲では以下の3つがありました。
- ①BERT-flow:
- ②BERT-whitening:
- ③SimCSE:
このうち、①②はBERT(というか、LLM系のembeddingモデル)のベクトルを後処理によって異方性の解消を試みる、という方向性の手法です。③は、BERT自体をファインチューニングすることで、異方性を解消する方法です。
どれがベスト、ということは無く、一長一短です。
ただ、③はファインチューニング用のデータの準備が必要であり、質と量の高さが求められる点で構築難易度が高いです。
①②は後処理的に行うので、データを用意する必要性はないというメリットがあります。
そこで今回は比較的簡単に構築できそうな、②のBERT-whiteningという手法を試してみることとしました。
結論から言えば、そこそこ上手くいきました。
BERT-whiteningのベースアイデア
超ざっくり言えば、PCAです。
データの分散が最大となる方向に新しい座標(基底、固有ベクトル)を作り出し、その新しい座標空間で、改めてcos類似度をとります。
さきほどの図で言えば、平均ゼロとなるように平行移動し、分散最大の方向となるベクトルを見つけます。そのベクトルが張る空間で改めて"魚"と"車"のベクトルを表現します。
その新しいベクトル同士でcos類似度を計算します。
そうすると、さっきまでは魚∠車が大体30度くらいだったのに、新しい座標で見てみると、魚∠車が90度近くになりました。
cos類似度的にはかなり小さくなりました。つまり、「魚と車は似ていない」という直感により近い形で表現できるようになりました。
くどいようですが、今までだと、cos(魚∠車)=0.7となってしまい、「え?こんなに高いの?」と思ってしまうようなテキストでした。それが新しい座標では、cos(魚∠車)=0.05くらいになっており、感覚的な"似ていない度合"と一致するようになりました。
実装
今回は、livedoorのニュースコーパスを利用します。
Google Colabで作っていますが、Notebookはこちらのgitにおいてます。
https://github.com/KENTAROSZK/bert-whitening
ライブラリのインストールとその他設定など
%pip install transformers
%pip install fugashi
%pip install ipadic
%pip install unidic-lite
%pip install japanize-matplotlib
# 日本語が使えるようにしておく
import japanize_matplotlib
japanize_matplotlib.japanize()
SOURCE_DIR = "./data/"
データのDLとpandasDataFrameへの格納
# データのDLと解凍
# livedoorデータをDLしてなければ、DLする
import os
if not os.path.exists(SOURCE_DIR + "ldcc-20140209.tar.gz"):
import subprocess, shlex
args = shlex.split('wget -P data/ "https://www.rondhuit.com/download/ldcc-20140209.tar.gz"')
ret = subprocess.run(args, stdout = subprocess.PIPE, stderr = subprocess.PIPE)
print(ret.returncode) # 終了コード。'0'であれば正常終了
# 解凍済みのフォルダが無ければ解凍する
if not os.path.exists(SOURCE_DIR + "livedoor"):
import tarfile
tar = tarfile.open(SOURCE_DIR + "ldcc-20140209.tar.gz", 'r:gz')
tar.extractall(SOURCE_DIR + "livedoor")
print("extract done")
# 結構時間かかるので、ファイルの有無を確認して、
# 無い場合のみデータの読み込み処理をさせる形にする
import re
import pickle
import pandas as pd
# カテゴリーのフォルダ名のみを抽出
categories = [name for name in os.listdir(
SOURCE_DIR + "livedoor/text/") if os.path.isdir(SOURCE_DIR + "livedoor/text/"+name)]
print(f"{len(categories)=}\n", f"{categories=}")
is_list_exist = os.path.exists(SOURCE_DIR + "list_df.pkl")
if not is_list_exist:
list_df = []
for category in categories:
path = SOURCE_DIR + "livedoor/text/" + category + "/"
files = os.listdir(path)
files_file = [f for f in files if os.path.isfile(os.path.join(path, f)) and (re.match(category, f))]
for file_name in files_file:
file = path + file_name
with open(file) as text_file:
text = text_file.readlines()
list_df.append([category, text[2], text[3]])
#text[2] : # タイトル
#text[3] : # 本文
with open(SOURCE_DIR + 'list_df.pkl', mode='wb') as f:
pickle.dump(list_df, f)
else:
with open(SOURCE_DIR + 'list_df.pkl', mode='rb') as f:
list_df = pickle.load(f)
df = pd.DataFrame(list_df, columns = ["category", "title", "content"])
df.head()
検証用に記事を選定する
livedoorの記事は9カテゴリあり、各カテゴリで7,800個の記事がある。それらすべてを使うと、確認も難しいことから、一部の記事に限定して検証することとした。
私の主観で、異なるであろう3カテゴリと、その中から内容の方向性が異なる3記事(3×3)を選択した。選ぶときは1カテゴリの記事を一つずつ見ていき、似ていなさそうな記事を選ぶようにした(※sports-watchだけ、野球2記事、サッカー1記事ということであえて似ている記事も混ぜた)。
- dokujo-tsushin
- title:地方女子と東京女子の幸せの価値観\n
- title:今度生まれ変わるなら女or男どちらになりたい?\n
- title:親が倒れた時、独女は何ができるのか?\n
- it-life-hack
- title:iPhoneやiPadが目ざまし時計になるユニークなガジェット登場!\n
- title:高画質と小型・軽量化を両立!キヤノンのミラーレス一眼「EOS M」を見てきました\n
- title:最近使ったファイルを素早くスタートメニューから開く【知っ得・虎の巻】\n
- sports-watch
- title:阿部も頭を抱える原采配、ノムさんは「何にもするな、選手に任せておけ」\n
- title:巨人の拙劣な采配に批判殺到、ノムさんも「根拠がサッパリわからない」\n
- title:なでしこリーグに警鐘「リーグの意味がない」\n
# 選んだ記事のみを抽出して新しいDataFrameを作成する
# - dokujo-tsushin
# - title:地方女子と東京女子の幸せの価値観\n
# - title:今度生まれ変わるなら女or男どちらになりたい?\n
# - title:親が倒れた時、独女は何ができるのか?\n
# - it-life-hack
# - title:iPhoneやiPadが目ざまし時計になるユニークなガジェット登場!\n
# - title:高画質と小型・軽量化を両立!キヤノンのミラーレス一眼「EOS M」を見てきました\n
# - title:最近使ったファイルを素早くスタートメニューから開く【知っ得・虎の巻】\n
# - sports-watch
# - title:阿部も頭を抱える原采配、ノムさんは「何にもするな、選手に任せておけ」\n
# - title:巨人の拙劣な采配に批判殺到、ノムさんも「根拠がサッパリわからない」\n
# - title:なでしこリーグに警鐘「リーグの意味がない」\n
titles = [
"地方女子と東京女子の幸せの価値観\n",
"今度生まれ変わるなら女or男どちらになりたい?\n",
"親が倒れた時、独女は何ができるのか?\n",
"iPhoneやiPadが目ざまし時計になるユニークなガジェット登場!\n",
"高画質と小型・軽量化を両立!キヤノンのミラーレス一眼「EOS M」を見てきました\n",
"最近使ったファイルを素早くスタートメニューから開く【知っ得・虎の巻】\n",
"阿部も頭を抱える原采配、ノムさんは「何にもするな、選手に任せておけ」\n",
"巨人の拙劣な采配に批判殺到、ノムさんも「根拠がサッパリわからない」\n",
"なでしこリーグに警鐘「リーグの意味がない」\n",
]
df_input = df.loc[df["title"].isin(titles)]
display(df_input)
BERTは、512トークンまでしか読み込めないため、選んだ記事の文字数がそれをオーバーしていないか確認しておく。もしオーバーしていればそれ用の対策をする必要がある。
# 各contentの文字数をカウントしておく
# -------------------------------------------------
# BERTモデルでインプットできる最大トークン数は512トークンであるため、
# 各コンテンツの文字数がそれを超えていないかどうかを確認しておく.
# 文字数≠トークン数ではあるが、文字数数えれば、おおよそのトークン数は見積もれるため、
# 文字数をカウントする.
# -------------------------------------------------
df_input["content"].apply(lambda x: len(x))
結果、全て512トークンにならないことを確認できた。
tohoku-nlp/bert-base-japanese-v3
を使った埋め込みをする
# モデルの読み込み
# -------------------------------------------------
# 今回は東北大の日本語bertを使うこととした
# https://huggingface.co/tohoku-nlp/bert-base-japanese-v3
# これが、日本語のBERTを使うときのデファクトスタンダード的立ち位置のモデルだと思う
# -------------------------------------------------
from transformers import BertJapaneseTokenizer, BertModel
import torch
# トークナイザーとモデルの初期化
model_name = "tohoku-nlp/bert-base-japanese-v3"
tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)
テスト実行してみる。
def embed(text: str) -> torch:
tokens = tokenizer(text, return_tensors='pt')
# ベクトルの取得
with torch.no_grad():
outputs = model(**tokens)
# [CLS]トークンの出力を取得(文全体の表現として使用)
embedding = outputs.last_hidden_state[:, 0, :]
return embedding
# テストデータでの実験
text = "日本語のテキストを埋め込みベクトルに変換する。"
# 埋め込みベクトルの表示
embedding = embed(text)
print(embedding.shape)
print(embedding)
実行結果(ちゃんとebmeddingできている模様)
torch.Size([1, 768])
tensor([[-5.3307e-01, 3.1861e-01, -6.8248e-01, 9.5944e-02, -3.3479e-01,
1.0736e+00, -7.9251e-01, -2.1434e-01, -2.2504e-01, -6.3049e-01,
1.5728e-01, -1.2440e-01, -6.8706e-02, -2.2331e-01, 8.1621e-02,
5.9559e-01, 5.6205e-01, 3.7138e-01, -6.6897e-01, -1.1391e-01,
4.1173e-01, 4.9424e-02, -5.9612e-01, 5.0083e-01, -4.3349e-01,
2.5437e-01, -4.8640e-01, 2.2104e-01, 7.5307e-02, -3.5501e-01,
-1.7728e-01, -2.3050e-01, 1.0583e-01, -1.1085e+01, -6.5359e-01,
1.6122e-01, 4.4521e-01, 9.1183e-02, -1.3834e-01, 6.0131e-01,
-3.9717e-01, 6.3918e-02, -2.0906e-01, -5.9148e-02, 6.2810e-01,
1.9183e-01, -6.0775e-01, -1.4797e-01, -4.2523e-02, 5.6888e-02,
7.7806e-02, 4.3937e+00, -3.3134e-01, -5.4770e-01, 3.3916e-01,
-1.7471e-01, -6.4322e-01, 6.9874e-01, 6.6428e-01, 2.7927e-01,
-3.8730e-01, -1.6792e-01, 3.7313e-01, 7.7227e-01, -1.3252e-01,
1.5337e+00, -1.0588e+00, -3.9464e-01, -6.5427e-01, -4.6145e-01,
-6.0219e-01, -5.7748e-01, 3.3798e-01, -3.5778e-02, -1.4076e-01,
-3.1608e-01, -5.7309e-01, -2.6811e-01, -7.9586e-01, -2.4284e-01,
-2.8661e-01, -6.7430e-01, -3.4786e-01, -1.9938e-01, -5.9286e-01,
-6.9561e-02, 3.5561e-01, -5.3236e-01, -1.6805e-01, -3.8833e-01,
7.1183e-01, -2.7842e-01, 6.6965e-01, 5.0270e-01, -1.9478e-01,
-3.7238e-01, -4.8733e-01, 3.2830e-01, -7.5241e-02, 1.4339e-01,
-2.4186e-01, -8.4318e-01, -4.4852e-01, 2.4500e-01, 3.8053e-01,
6.0060e-01, -2.9478e-01, 1.3079e+00, -5.2703e-01, 7.3396e-01,
8.7623e-01, 2.5499e-01, -3.1758e-01, 4.7001e-01, -5.0686e-01,
-6.4036e-01, -8.7439e-02, 1.9865e-01, -2.7933e-01, 2.3291e-01,
-5.0975e-02, -2.1061e-01, -6.9201e-02, -7.7129e-02, 2.9466e-01,
-3.2640e-01, 2.1815e-01, 5.5657e-01, 4.2765e-02, 3.9038e-01,
-8.1045e-01, 4.6795e-01, -5.9212e-03, 4.6339e-02, -4.5700e-01,
-6.7404e-02, -1.4106e-03, -8.1576e-01, -7.4314e-02, -3.0102e-01,
1.3904e-01, -7.0643e-01, -5.3671e-01, -2.3077e-01, -4.5168e-01,
1.2185e-01, -1.9259e-01, -1.0720e+00, -4.7362e-01, 1.2510e+00,
-5.5709e-01, -8.2817e-03, 3.5699e-01, 2.6396e-01, 1.4472e-01,
2.1049e-01, -5.5835e-02, -2.7717e-01, 9.7728e-01, -7.1615e-01,
2.8362e-01, -1.9504e-01, -4.4900e-01, -2.4768e-01, 1.9616e-01,
-1.8515e-01, 2.5913e-02, -1.8008e-01, -6.4997e-01, 7.2332e-01,
1.0359e-01, 6.0998e-01, -2.7732e-01, 3.9309e-01, 5.5059e-01,
5.7893e-02, -2.8963e-01, -1.4359e-01, 5.1991e-01, 1.1901e+00,
8.9249e-02, 3.8274e-04, -9.2657e-03, 4.7826e-01, 1.4523e-01,
-2.9245e-01, 1.0270e-01, 1.4711e-01, 1.2334e-01, -5.3171e-01,
-8.8458e-01, 5.5324e-01, -2.1804e-02, 4.8823e-01, 1.1206e-02,
2.0076e-01, -5.0890e-01, -6.2055e-01, -7.8141e-02, -9.4742e-01,
5.5015e-01, -3.3745e-01, 5.7122e-01, 9.5395e-02, 4.2680e-01,
-3.2716e-01, 3.6107e-01, 3.0756e-01, 4.2148e-01, 3.4682e-01,
-6.3507e-01, -5.3360e-02, -1.9885e-02, 9.5817e-01, 4.3717e-01,
1.5631e-01, 3.0125e-01, 3.4838e-01, 2.3718e-02, 8.4373e-01,
3.9504e-01, 5.0342e-01, -2.0471e-02, -5.0410e-01, -1.1470e+00,
8.3954e-02, 5.5049e-01, 1.7311e-01, 3.5348e-01, 3.8224e-01,
-2.2781e-01, -3.1992e-01, 2.3970e-01, 5.1294e-01, 8.0557e-01,
-4.0761e-02, -4.2108e-01, -3.7161e-01, -2.4280e-01, 3.6341e-01,
6.8811e-01, -4.6937e-01, 2.3234e-01, 2.1915e-01, 8.6298e-02,
-3.0163e-01, 7.6827e-02, 2.2256e-01, -3.0967e-01, 2.9908e-01,
-6.0842e-01, -2.0075e-02, -1.8872e-01, 1.5650e-01, -1.8749e-01,
5.2185e-01, -1.1640e+00, -9.1943e-02, -5.5058e-01, -3.4478e-01,
4.4468e-01, -3.7807e-01, 6.3536e-01, -3.0874e-01, 1.0037e-01,
2.4224e-03, 3.3906e-01, -8.2079e-01, 4.0938e-01, 2.2595e-01,
5.9296e-01, -6.2593e-01, -3.0528e-01, -3.5744e-01, -1.6838e-02,
-6.8278e-02, -2.8923e-01, -3.3441e-01, -9.0121e-01, 2.4165e-01,
-1.1086e-01, 2.3615e-01, -7.6862e-02, -2.5599e-01, 2.2792e-01,
-2.8259e-02, -1.1505e-01, -3.7097e-01, -2.3394e-01, 2.0441e-01,
-2.2138e-01, 4.4121e-01, 9.2581e-03, 2.1912e-01, 2.2905e-01,
3.1738e-01, -4.4068e-01, -5.2371e-01, 2.2733e-01, 1.6074e-02,
-1.0591e+00, 2.6657e-01, 7.9358e-02, -5.8205e-01, -4.0098e-01,
-3.3972e-01, -3.4775e-01, -6.4009e-01, -5.1202e-01, 4.9761e-01,
-5.2391e-01, -7.2391e-01, 5.0904e-01, 5.1530e-02, -4.8038e-01,
-5.4059e-02, 3.5978e-01, -4.2937e-01, 6.7368e-01, -8.2849e-01,
-8.2977e-01, 1.7692e-01, 2.5559e-01, -4.6576e-01, 1.5941e-01,
-5.3035e-02, 1.4016e-02, -1.5104e+00, -3.3907e-01, -4.5578e-01,
1.8424e-01, -5.0961e-01, 9.3177e-02, -3.4389e-01, -8.1815e-01,
-1.8131e-01, 4.9059e-01, -2.1245e-01, 3.6567e-01, 9.5418e-03,
-1.2042e-01, -1.7827e-01, 8.4794e-02, 6.0203e-02, -1.7994e-01,
1.5910e-01, 1.9343e-02, -1.9241e-02, 3.1942e-01, 4.4005e-01,
-3.1631e-01, -2.1627e-01, -6.4677e-01, -8.6953e-01, -1.5044e-01,
-3.1985e-01, -5.5127e-01, 5.7671e-01, -6.5993e-01, 9.5905e-02,
7.5092e-02, 5.7427e-01, -1.9826e-01, 1.1741e-02, -2.9668e-01,
7.7212e-02, 3.7215e-01, -1.0968e-01, -3.0709e-01, 5.7729e-01,
1.2661e-01, -7.5466e-02, -5.5822e-01, 2.8192e-01, 1.6176e-01,
-4.6218e-01, -3.8272e-01, -2.8339e-02, -2.2897e-01, -5.6983e-01,
-2.8512e-01, 2.1405e-01, 1.4244e-01, -2.4765e-01, -9.2285e-01,
-1.5870e-01, 3.1537e-01, 6.5992e-03, -1.0407e-01, 1.2143e-01,
-5.1606e-01, 3.7502e-03, 2.3441e-01, 3.5735e-01, 4.5859e-01,
4.7626e-01, 1.3854e-02, 2.7567e-01, -5.1569e-01, 1.0974e-01,
-8.6881e-01, 1.5044e-02, -8.9788e-01, 1.7667e-01, -2.6201e-01,
-4.1427e-01, -5.5246e-01, 1.1560e-01, -8.8190e-01, 4.5456e-01,
1.9235e-01, 1.2921e-01, 1.5229e-01, -1.9987e-01, -2.6027e-01,
-3.5596e-01, -6.7304e-01, 4.9571e-01, 3.9356e-01, 2.1336e-01,
-4.4343e-01, 1.8197e-01, -8.4461e-02, -8.0682e-01, 5.1841e-01,
-3.9681e-01, -9.6569e-01, 5.3724e-01, 9.0561e-01, 3.3877e-01,
1.2078e-01, -5.5781e-01, -4.8962e-01, 2.8734e-01, 1.6586e-01,
7.1331e-01, 6.1627e-01, -7.2947e-02, -3.1722e-01, -4.4668e-01,
7.1677e-01, 7.0226e-01, -2.5982e-02, 3.9351e-01, 5.3223e-01,
-8.1562e-02, -4.6633e-01, -9.0438e-01, -2.8046e-01, 7.9335e-01,
-2.2949e-01, 4.2579e-01, 5.2176e-01, -8.8465e-01, -1.1302e-01,
-6.7228e-02, -8.5957e-01, -3.3596e-01, 2.3225e-01, 2.1879e-01,
-2.8585e-01, -1.0740e+00, -4.4953e-02, 4.8625e-01, -4.4345e-01,
2.9393e-01, -3.2396e-01, 5.8943e-01, 3.1865e-01, -5.7953e-01,
2.0403e-01, -4.3331e-01, 3.7139e-01, -3.9763e-01, -1.1528e+00,
4.5318e-01, -4.3044e-01, 1.5643e-01, -4.6508e-01, -1.1193e-01,
-4.0867e-01, -1.9700e-02, 6.7668e-01, 2.2033e-01, -1.6736e-01,
1.3936e+00, 3.9112e-01, 8.8215e-02, 4.3380e-01, 8.0549e-01,
4.1429e-01, -7.0422e-01, 4.8665e-01, -7.4544e-02, 1.5231e-02,
-7.6182e-02, 4.3843e-01, -3.1276e-01, 2.0777e-01, 5.3201e-01,
9.3685e-02, 3.1377e-01, -3.5719e-02, -4.9915e-02, -5.2044e-01,
1.1756e-01, 2.8682e-01, 1.3714e-01, -2.2988e-01, -2.4418e-01,
1.6720e-01, 2.9759e-01, 4.3159e-01, 6.8523e-01, 5.6962e-01,
-4.2079e-01, -6.5738e-01, -5.1392e-01, 6.2424e-01, -9.0328e-01,
-9.1176e-02, 3.9679e-01, 9.1325e-02, -1.5583e-01, 5.1478e-01,
2.8920e-01, -5.4417e-01, 3.4201e-01, -5.1361e-01, -4.2437e-01,
3.0082e-01, -7.7538e-01, -7.3559e-01, 8.8083e-01, -2.1521e-03,
-2.1420e-01, 1.4137e-01, -6.6280e-01, 3.2252e-02, 9.5490e-01,
-1.3150e-01, 2.9959e-01, -7.8910e-01, 1.2382e-01, 5.5384e-01,
-1.3550e-01, 9.1347e-01, -1.2900e-01, 1.4668e-01, 7.3439e-02,
-5.6929e-01, -3.2125e-02, 5.8019e-01, 5.9718e-01, -1.2096e-01,
-1.8825e-01, 2.1912e-01, -1.2869e-02, -3.2327e-01, 8.3701e-01,
3.6511e-01, -5.0949e-01, 4.3689e-01, 4.2039e-01, 3.3523e-01,
-1.4535e-01, 1.7775e-01, 2.7162e-01, -1.0384e-01, -6.7248e-01,
4.3916e-01, 4.8487e-01, 2.8187e-01, 5.1775e-02, -1.5917e-01,
-3.2313e-01, -3.9410e-01, -2.0976e-01, 1.2613e+00, -1.8666e-01,
2.6674e-02, 4.3957e-01, -6.7899e-01, 7.7242e-01, -1.2744e-01,
2.0697e-01, 3.0400e-01, -7.0376e-02, 3.0474e-02, -6.7896e-01,
3.8687e-01, -9.3908e-02, 1.1426e-01, -2.3242e-01, 6.5838e-01,
2.3110e-01, 4.3361e-01, 5.3303e-03, -2.1177e-01, 2.9668e-01,
-6.0936e-01, 3.1943e-02, -3.3331e-01, 2.6662e-01, 3.7100e-01,
3.3685e-01, -9.2457e-01, 7.5390e-02, 2.6468e-01, 2.6822e-01,
6.3900e-01, -1.0035e-01, 3.4232e-01, 7.6937e-01, 2.2153e-01,
-4.5401e-01, -4.3744e-01, -3.7021e-01, -3.4516e-01, -4.5415e-01,
-4.1240e-01, 2.3187e-01, -2.4430e-01, -1.3763e-01, -5.5673e-03,
3.8926e-01, -1.5492e-01, 6.5753e-02, -5.1204e-02, 8.4820e-01,
1.1431e-01, 9.5372e-02, -2.3413e-01, 2.4580e-01, 4.2175e-01,
1.5669e-01, -2.6868e-01, -1.1208e-01, -1.4662e-01, 5.3896e-01,
1.2518e-01, 4.1622e-01, -4.4564e-01, -2.6589e-01, 8.1437e-02,
-2.4076e-01, 2.4049e-01, -1.6442e-01, 2.5427e-02, 1.3722e-01,
4.3244e-01, 5.8983e-02, 5.3836e-01, 2.7516e-01, -3.2266e-01,
-3.0879e-01, -9.9352e-02, -7.6987e-02, 4.1110e-01, 1.9473e-01,
-7.1932e-02, -8.9015e-02, -4.1792e-02, -3.8190e-01, -2.0469e-01,
1.7515e-01, 4.0488e-01, 9.2956e-01, -2.8163e-01, -2.4475e-02,
-1.7623e-01, 5.6929e-01, 4.3190e-02, 3.4747e-01, 3.1614e-01,
1.2203e-01, -3.5249e-02, 2.4044e-01, 1.4024e-01, -6.3098e-03,
3.6867e-01, 6.8904e-02, -7.5329e-02, 3.4514e-01, -3.0747e-02,
-1.8248e-01, -4.6234e-01, 1.4077e-02, -4.4797e-01, 5.9934e-01,
-6.6947e-02, 1.4186e-01, -8.3329e-02, 1.3132e-01, -3.1456e-01,
2.1163e-01, 1.8594e-01, -3.3185e-02, -3.0710e-01, 2.7586e-01,
4.6186e-01, 1.4900e-01, -1.1982e-01, 6.3493e-01, -6.7865e-01,
-6.0139e-01, -3.2463e-01, -3.2313e-01, -1.0441e+00, 4.9654e-01,
4.0312e-01, -3.7118e-02, -4.3595e-02, -5.5223e-01, -4.3344e-01,
-3.2709e-01, 3.5101e-01, -3.2086e-01, 2.0470e-02, -6.8950e-01,
-5.1043e-01, 8.2606e-01, -8.7102e-01, 1.0237e+00, -3.0183e-01,
7.5128e-01, 6.3262e-01, -5.7180e-01, 3.4486e-03, -7.3666e-02,
-4.2316e-01, 1.4330e-01, 4.4132e-01, -1.2656e-01, -3.6106e-01,
1.4658e-01, 1.5778e-01, -2.1932e-01, -8.2727e-02, -7.2535e-01,
8.6657e-03, 8.6840e-01, 1.2020e-01, -2.4137e-01, -9.7683e-01,
2.4990e-01, 2.1827e-01, -6.5686e-01, 7.9211e-01, 4.2269e-01,
5.9635e-02, -6.6306e-01, -5.0252e-01, -3.9751e-01, 3.1867e-01,
2.1524e-02, -3.7938e-01, -1.0254e-01, 1.9931e-01, 9.5466e-02,
-3.6782e-01, 2.6247e-02, -6.5561e-01, -4.7974e-02, -3.4380e-01,
-4.5505e-01, 2.2059e-01, 5.3573e-02]])
記事毎にembeddingする
# 各テキストのベクトルを取得する
index_to_id = dict()# DataFrameのオリジナルのindex(key) -> 処理上のindex(value)
id_to_index = dict()# 処理上のindex(key) -> DataFrameのオリジナルのindex(value)
id_to_title = dict()# 処理上のindex(key) -> 記事タイトル(value)
tmp_list_embedding = []
for i, (index,item) in enumerate(df_input.iterrows()):
index_to_id[index] = i
id_to_index[i] = index
id_to_title[i] = item["title"]
content = item["content"]
embedding = embed(content)
tmp_list_embedding.append(embedding)
print(index_to_id)
print(id_to_index)
stacked_tensors = torch.cat(tmp_list_embedding, dim=0)
print(stacked_tensors.shape)
実行結果
{875: 0, 880: 1, 885: 2, 2282: 3, 2283: 4, 2294: 5, 5633: 6, 5637: 7, 5646: 8}
{0: 875, 1: 880, 2: 885, 3: 2282, 4: 2283, 5: 2294, 6: 5633, 7: 5637, 8: 5646}
torch.Size([9, 768])
変換前のベクトルで類似度行列を計算してみる
# 類似度行列(cos類似度)を計算する
def cos_sim(tensors: torch):
import torch.nn.functional as F
# 各ベクトルを L2 ノルムで正規化
normalized_tensors = F.normalize(tensors, p=2, dim=1)
# コサイン類似度行列を計算
similarity_matrix = torch.matmul(normalized_tensors, normalized_tensors.transpose(0, 1))
return similarity_matrix
def visualise_cos_sim_mat(sim_mat: pd.DataFrame):
import matplotlib.pyplot as plt
import seaborn as sns
sns.heatmap(sim_mat, cmap='bwr', annot=True, fmt=".2f")
cs = cos_sim(stacked_tensors)
visualise_cos_sim_mat(sim_mat = pd.DataFrame(cs))
野球系の記事同士は0.86と非常に高いことは期待通りだが、やはり似ていないであろう記事同士、例えば0番の野球の話と6番のdokujo-tsushinとが0.72となっており、違和感を覚える結果となった。
結果を見てみると、全く似ていない記事でも、0.5以上あるような状態であり、筆者の主観であるが、0.7よりも大きければ似ている記事、それ以下なら似ていない記事と判断できそうな状態である。
しかし、これは直感的とはいいがたいと思う。似ていない記事はもっと似ていないことがわかる数値(例えば0)を出して欲しい。次のところで、BERT-whiteningでそれが達成されるか確認する。
BERT-whiteningする
def whitening(tensors: torch) -> (torch, torch, torch):
"""
データ行列(今回は、9×768の行列)を新しい直交空間(座標)での表現に変換する。
"""
# ----------------------------------------------------------
# step1.分散共分散行列を計算する
# 各特徴量の平均を計算
mean = torch.mean(tensors, dim=0)
# 各ベクトルから平均ベクトルを引き算
centered_tensors = tensors - mean
# 分散共分散行列を計算
covariance_matrix = torch.matmul(centered_tensors.transpose(0, 1), centered_tensors) / (tensors.shape[0] - 1)
# ----------------------------------------------------------
# step2.対角化して、(単位ベクトル化済みの)固有ベクトルを並べたテンソル(768×768)を作る
# 固有値固有ベクトルを計算
eigenvalues, eigenvectors = torch.linalg.eig(covariance_matrix)
# 固有ベクトルを単位ベクトルに変換
eigenvectors = eigenvectors / torch.norm(eigenvectors, p=2, dim=0, keepdim=True)
# 固有ベクトルを縦ベクトルと見なしてテンソルを作成
eigenvector_tensor = eigenvectors.real # 複素数を実数に変換 (必要な場合)
# ----------------------------------------------------------
# step3.
# 上記で求めた固有ベクトルが張る空間に
# 中心化済みのオリジナルデータ行列(9×768)を射影する
# 中心化済みのオリジナルデータ行列をcentered_tensorsとして、新しいデータ行列をXとすると、
# X = (centered_tensors @ eigenvector_tensor).transpose()
# テンソル積を計算
X = torch.matmul(centered_tensors, eigenvector_tensor).transpose(0, 1)
X_new = X.transpose(0, 1)
return X_new, mean, eigenvector_tensor
X_new, mean, U = whitening(stacked_tensors)
print(X_new.shape)
cs = cos_sim(X_new)
visualise_cos_sim_mat(sim_mat = pd.DataFrame(cs))
色合いはあまり変わらないように見えるが、2点大きな違いがある。
- さっきよりも値の大きさがダイナミックに異なっている。例えば、野球系の記事同士は0.56で、0番の野球の話と6番のdokujo-tsushinとが-0.1。さっきは0.86と0.72とちょっとの差しかなかったが、今は大きな差となっている。
- カラーマップのスケール(ヒートマップの右にある棒)を見て欲しい。さっきまでは一番小さくても0.5程度だったのに、今回は-0.4くらいまである。
beforeとafterとで見やすくなるようにヒートマップを横並びにする
def visualise_heatmap_horizontally(
sim_mat1: pd.DataFrame,
sim_mat2: pd.DataFrame
) -> None:
import matplotlib.pyplot as plt
import seaborn as sns
# FigureとAxesの作成 (1行2列のsubplots)
fig, axes = plt.subplots(1, 2, figsize=(12, 5)) # figsizeで全体のサイズを調整
sns.heatmap(sim_mat1, ax=axes[0], cmap='bwr', annot=True, fmt=".2f")
axes[0].set_title("before")
sns.heatmap(sim_mat2, ax=axes[1], cmap='bwr', annot=True, fmt=".2f")
axes[1].set_title("after")
# レイアウト調整 (必要であれば)
plt.tight_layout()
plt.show()
# 記事のindexだとわかりづらいので、
# タイトルで表示できるようにする
# ただ、タイトルをそのまま表示しようとすると長すぎるので、
# タイトルの冒頭4文字くらいだけ表示する
id_to_title_new = [s[:4] for s in id_to_title.values()]
print(id_to_title_new)
visualise_heatmap_horizontally(
sim_mat1 = pd.DataFrame(cos_sim(stacked_tensors), index=list(id_to_title_new), columns=list(id_to_title_new)),
sim_mat2 = pd.DataFrame(cos_sim(X_new), index=list(id_to_title_new), columns=list(id_to_title_new))
)
実行結果。
こうやって見ると、before(変換前)は微妙な差しかなかったものが、after(変換後)は、その差がより大きくなるように変換されていることがわかる。これであれば、直感と一致しやすい。
▼おまけ
cos類似度の行列(今回は9記事だから、9×9の行列)の上三角部分だけの値を抽出し、そのヒストグラムを出力させた。異なる記事同士の類似度がどのような分布になっているかを確認できる。
気体としては、変換前は0.6に集中してしまっており、変換後は0に近いところに山があってほしい(ある記事1つが他の8記事とはほぼ似ていない、という状況なので)、という期待で実施してみた。
実験の結果、beforeは期待通り0.6近辺
実験まとめ
以上の結果から、BERT-whiteningをすることである程度期待通りの出力を得られることが分かった。言い換えると、似ている記事と似ていない記事をもっとゆるく判断できるということ。これまでは、cos類似度を使ってテキスト同士の類似性をジャッジするには閾値を設定するなどの必要があったと思うが、その閾値は、0.001のようなオーダーで調整しなくてはいけなかった。しかし、BERT-whiteningをすれば、もっと大きな数値でその閾値の調整ができるようになる。それはつまり、テキストの持つ機微をとらえやすくなった、ともいえるので、そういうメリットがあると思う。
追加検証
ここまでは、すでに手元にあるテキストに対して変換していたが、実務では、新しいテキストが追加されていく。追加されたテキストに対してもBERT-whiteningが機能するかどうかを検証してみる。
- 追加検証1:類似記事での出力が"類似"と判断できる変換ができているか?
- 追加検証2:似ていない記事での出力が"似ていない"と判断できる変換ができているか?
追加検証1
# 追加検証1
# 似ている記事同士で似ていると判断できるか?
titles = [
"高画質と小型・軽量化を両立!キヤノンのミラーレス一眼「EOS M」を見てきました\n", # オリジナル
"スティックタイプのコンデジ登場! Optio新モデルは気分に合わせて着せ替え可能\n", # 類似してる(孫社長がいるという意味で)記事
]
# DataFrameを作成する
df_input1 = df.loc[df["title"].isin(titles)]
display(df_input1)
# embedding
tmp_list_embedding = []
for i, (index,item) in enumerate(df_input1.iterrows()):
content = item["content"]
embedding = embed(content)
tmp_list_embedding.append(embedding)
stacked_tensors1 = torch.cat(tmp_list_embedding, dim=0)
# データ行列(2×768行列)を変換する
# さっき計算して算出していた`mean`と固有ベクトルを並べたテンソル `U` を利用する
tensors1 = stacked_tensors1 - mean
X_new1 = torch.matmul(tensors1, U).transpose(0, 1)
X_new1 = X_new1.transpose(0,1)
print(X_new1.shape)
# 可視化
#cs = cos_sim(X_new1)
#visualise_cos_sim_mat(sim_mat = pd.DataFrame(cs))
visualise_heatmap_horizontally(
sim_mat1 = pd.DataFrame(cos_sim(stacked_tensors1)),
sim_mat2 = pd.DataFrame(cos_sim(X_new1))
)
もともとのcos類似度が0.9だったものが0.81とやや下がってしまっているが、依然として類似性が高いと言える水準であるので、問題ない。
追加検証2
# 追加検証2
# 似ていない記事を似ていないと判断できるか?
titles = [
"地方女子と東京女子の幸せの価値観\n", # オリジナル記事
"今度生まれ変わるなら女or男どちらになりたい?\n", # オリジナル記事
"親が倒れた時、独女は何ができるのか?\n", # オリジナル記事
"ダイエット不要!ぽっちゃり系の時代が到来!?\n", # dokujo-tsushinカテゴリの中で似ていない記事
]
# DataFrameを作成する
df_input2 = df.loc[df["title"].isin(titles)]
display(df_input2)
# embedding
tmp_list_embedding = []
for i, (index,item) in enumerate(df_input2.iterrows()):
content = item["content"]
embedding = embed(content)
tmp_list_embedding.append(embedding)
stacked_tensors2 = torch.cat(tmp_list_embedding, dim=0)
# データ行列(2×768行列)を変換する
# さっき計算して算出していた`mean`と固有ベクトルを並べたテンソル `U` を利用する
tensors2 = stacked_tensors2 - mean
X_new2 = torch.matmul(tensors2, U).transpose(0, 1)
X_new2 = X_new2.transpose(0,1)
print(X_new2.shape)
# 可視化
visualise_heatmap_horizontally(
sim_mat1 = pd.DataFrame(cos_sim(stacked_tensors2)),
sim_mat2 = pd.DataFrame(cos_sim(X_new2))
)
もともとの埋め込みだと、似ていないはずの記事(3番とそれ以外)が最低でも0.69と出ており、かなり高く出ている。上の方で述べていた『「0.7」以上だと似ていると判断できそう』という基準で考えると、今のままだとかなり似ていると判断できてしまうだろう。
一方で、変換後であれば、最も似ていない記事との類似度が0.01となっており、最も似ている記事でも0.44である。
まとめ
- 埋め込みした時に、似ていないはずのテキストもcos類似度で見ると似ているように見えてしまう、という問題に対して、解決策を調べた。
- その中で、理論や実装がしやすいBERT-whiteningという手法に関してなぜそれが問題解決できるのかの概念的な説明をした。
- livedoorのニュース記事コーパスをりようしてBERT-wthiteningという手法の実装も行い、実験によってその有用性を確かめた。