概要
Sentence transformerを使って、ファインチューニングするためのコードを書きました。
例として、livedoorニュースデータを使っています。
sentence trasnformersと、ファインチューニングについて簡単に紹介した後、実装のコードを掲載します。
Google Colabで実行できます。コードはこちらです。
Sentence transformerとは何か?
BERT + pooling層の2層構造のモデルです。
何ができるのか?
インプットをテキスト(文字)として、アウトプットはベクトルを出してくれます。
要は、テキストの埋め込みをしてくれるモデルとなっています。
何に使えるの?
テキストを埋め込むということはテキストがベクトル化されます。ベクトル化されていると、ベクトル同士の類似度や距離の計算ができるようになるので、テキスト検索やレコメンドなどにつかえるようになります。
- テキスト検索
- ユーザが入力した検索文(クエリ)や文字情報を埋め込みし、既存のDBの中にあるテキストと類似度を計算する。その類似度の高いテキストを検索結果として表示する。
- レコメンド
- ユーザが過去に読んだテキストを使って、そのユーザの過去の読み込んだ記事と類似度の高いテキストを表示する。
最近流行のChatGPTを自社ドメイン知識に合わせた回答を作り出す際も、このようなテキストのベクトル化(埋め込み)という技術を使っています。(テキストの埋め込みはいろんな手法があります。Sentence transformersはその手法のうちの1つです。ChatGPTはまた別の方法でテキストの埋め込みを実現しています。)
課題点
sentence transformerモデルは学習済みのBERTモデルに対して、pooling層を結合するだけで精度の良いテキスト埋め込みをしてくれるモデルです。しかし、学習済みのBERTモデルがベースとなっているため、特定のドメイン知識を持った埋め込みができません。
詳しく説明するために例を挙げて説明します。
実際のアプリケーションでは、自分の会社が保有しているテキストデータに特化したテキスト埋め込みをしたいときがあります。
例えば、マーケティングに特化した記事を配信する事業をしている場合は、一般的な知識を基にした埋め込みをするよりも、マーケティングに特化した埋め込みをしてくれる方がうれしいです。
マーケティング用語に「CV(コンバージョン)」という言葉が頻出します。(マーケターが設定している目的のアクションのことをCVと言います。)
アニメや邦訳した海外作品などでは、「CV」と言えば、キャラクターボイスという意味になります。
要するに、扱っている領域(ドメイン)に応じて単語が持つ意味が変わってくるということです。これを反映した形でテキストの埋め込みをした方が、より精度の良い検索、レコメンドができるような気がしますよね。
これを実現するために、ファインチューニングという方法が必要です。
ファインチューニング
なぜファインチューニングするのか?
ファインチューニングでは、次の2つのものが必要になります。
- 学習済みのBERTモデル
- チューニング用のデータ
学習済みのモデルは、いろんな方がBERTモデルを学習し、そのモデルを公開してくれています。
日本語の場合は、東北大が公開してくれているBERTモデルが有名です。
今回も東北大のモデルを使わせてもらいます。
なぜ学習済みのモデルを使うのかと言えば、学習時間を削減するためです。BERTモデル(を含め、Transoformerベースの言語モデル)は自分で学習するのが大変です。
言語処理をやっているので、コーパスを作ったり、適切に分かち書きをしたり、、、が大変です。
またモデルを学習するには、モデルのパラメータのチューニングや学習のために非常に長い時間がかかります。時間を削減するには性能の良いGPUが必要です。
といったように、自前でモデルを作るには時間やお金がかかります。
その手間を軽くしつつ、自分が持っているデータ(ドメイン)に特化させたモデルを作るためにファインチューニングという方法を採用します。
ファインチューニングするには?
今回は次のような設定でファインチューニングします。
- データ:livedoorニュースデータをTripletデータに加工する。
- 損失関数はTriplet Loss関数で距離関数としてはユークリッド距離を使用しメトリック学習する
Tripletデータとは?
元々は画像認識系で使われていた手法ですが、それをテキストデータに対して応用した方法です。
anchor、positive、negativeという3つのカラムを持つデータを用意します。
- anchor:基準となるデータです。
- positive:anchorと同じカテゴリのデータです
- negative:anchorとは異なるカテゴリのデータです。
livedoorニュースでは、カテゴリには次の9つがあります。
- dokujo-tsushin
- it-life-hack
- kaden-channel
- livedoor-homme
- movie-enter
- peachy
- smax
- sports-watch
- topic-news
例えば、anchorとして選んだ記事が「dokujo-tsushin」だったとします。この場合、positiveなのはanchorと同じカテゴリを持つ「dokujo-tsushin」の記事だけとなります。
一方で、negative記事となるのは、「dokujo-tsushin」以外のカテゴリの記事です。
要するに、anchorとpositiveは同じカテゴリ同士の記事で、negativeはanchorとは異なるカテゴリであれば良いのです。
※Tripletデータは、ファインチューニング時に随時作成する方法などもありますが、わかりやすさ優先で今回はこのような説明と実装にしています。
Triplet学習とは?
メトリック学習の一つの方法です。
数式や文字での説明よりも図で見た方がイメージがわかりやすいので、↓に載せています。
anchor-positiveの距離を近づけ、anchor-negativeの距離を遠ざけるように学習する方法です。
距離の指標としてはいくつかの指標がありますが、一般的にはユークリッド距離を使うことが多いようです。今回もユークリッド距離を使います。
似ている記事同士(同じカテゴリ同士)は近づけて、似ていない記事(別カテゴリの記事同士)は遠ざけるように学習します。
注意点
sentence transformersモデルは扱えるテキストの長さに上限があります。(これはChatGPTでトークンサイズに上限があるのと同じ理由です。)
上限サイズは使っているBERTモデルに依存します。BERTモデルが256トークンまでであれば、256トークンまでを使ってテキスト埋め込みを実装します。257トークン目以降のトークンはすべて無視されてしまいます。
これを回避するには、いくつかの方法があります。
- テキストをうまく切り貼りしてテキストの要素をトークン上限に収める。例えば、文章というのは文章の冒頭と末尾に重要な情報があるのが一般的です。そこで、テキストの冒頭と末尾だけを抽出し、それを使ってテキストの埋め込みを実施する。という方法があります。今回実装するのはこの方法に近いです。今回は冒頭の512トークンだけ使います。
- 何らかの方法でテキストを要約する。ChatGPTなどは要約が上手なので、ChatGPTなどでテキストデータをうまく要約させてトークンサイズの上限に引っかからないようにします。
- テキストを複数のチャンクに分割して、分割したチャンクで埋め込みを実施する。各チャンクのベクトルを平均化する。例えば、1,300トークンあるテキストがあり、トークン上限が512のBERTモデルがあったとすると、テキストを500、500、300に分割し、それぞれのチャンクで埋め込みを作ります。そのあと、各ベクトルを平均化して、そのテキストの埋め込みとする方法です。
今回は東北大が提供しているBERTモデルを使いますが、これは上限が512トークンなので、ファインチューニングでは文章の冒頭512トークンだけを使ってチューニングしていることになります。今回の場合は、ニュース記事なので、冒頭に重要な文章が集中していると仮定し、文章全体の要素は冒頭512トークンで表現されきっているものと考えることにします。(本来はこの仮定が正しいかどうかの検証などが必要になりますが、方法やコードの解説が主目的なので、今回はその確認は行っていないです。)
実装
まずはlivedoorニュースをDLし、それをTripletデータの形式に直します。
そのあと、Sentencetransformerモデルを構築して、ファインチューニングするという流れで実装しています。
データの用意
データのDL
ライブドアニュースのデータのDL
データを読み込んだ後は、解凍させる データのDL自体はコマンドを使った方がラクなので、"wget"コマンドでDLする。
DLしたデータはtar.gz形式で圧縮されているので、それを解凍する必要があるので、"tarfile"ライブラリを使って、解凍している。
# tar.gzファイルの解凍のためのライブラリのインポート
import tarfile
# データのDL
!wget "https://www.rondhuit.com/download/ldcc-20140209.tar.gz"
# ファイルの解凍
tar = tarfile.open(SOURCE_FILE, 'r:gz')
tar.extractall(TARGET_FILE)
# フォルダのファイルとディレクトリを確認
files_folders = [name for name in os.listdir("livedoor/text/")]
print(files_folders)
# カテゴリーのフォルダのみを抽出
categories = [name for name in os.listdir(
"livedoor/text/") if os.path.isdir("livedoor/text/"+name)]
print(categories)
# ファイルの中身を確認してみる
file_name = "livedoor/text/movie-enter/movie-enter-6255260.txt"
with open(file_name) as text_file:
text = text_file.readlines()
print("0:", text[0]) # URL情報
print("1:", text[1]) # タイムスタンプ
print("2:", text[2]) # タイトル
print("3:", text[3]) # 本文
フォルダ名称が「カテゴリ」で、フォルダの中に大量にテキストファイルが存在している。
テキストファイルは、
1行目:URL情報
2行目:タイムスタンプ(記事が公開された日時?)
3行目:タイトル
4行目:本文 というデータで入っている。
テキストファイルをカテゴリ別に読み込む
import re
import pickle
import pandas as pd
# フォルダのファイルとディレクトリを確認
files = os.listdir("livedoor/text/")
files_dir = [f for f in files if os.path.isdir(os.path.join("livedoor/text/", f))]
list_df = []
for category in files_dir:
path = "livedoor/text/" + category + "/"
files = os.listdir(path)
files_file = [f for f in files if os.path.isfile(os.path.join(path, f)) and (re.match(category, f))]
for file_name in files_file:
file = path + file_name
with open(file) as text_file:
text = text_file.readlines()
list_df.append([category, text[2], text[3]])
#text[0] : # URL情報
#text[1] : # タイムスタンプ
#text[2] : # タイトル
#text[3] : # 本文
with open('list_df.pkl', mode='wb') as f:
pickle.dump(list_df, f)
df = pd.DataFrame(list_df, columns = ["category", "title", "text"])
df.head()
出力結果はこちら。
Tripletデータの用意
- 学習とテスト用のデータに分割してファインチューニング前後の差を見れるようにするために、10,000ペアのデータを作る。
- つまり、anchor-positiveとanchor-negativeペアを10,000個作る
anchor-positiveペアの作成
- ①7,367記事のインデックスで全組み合わせを作る。$7,367C_2 = 27,132,661$
- ②その組み合わせの中からランダムに1ペアを選ぶ
- ③選んだ組み合わせが同じカテゴリであれば、anchor-positiveとして登録する
- ④anchor-positiveとして登録された組み合わせは、組み合わせリストから削除する
- ①~④を繰り返して、10,000ペア作り出す
import itertools
import random
import numpy as np
random.seed(0)
# ①7,367記事のインデックスで全組み合わせを作る。
# ペアの総数の定義
N_pair = 10_000
# 記事のレコード数を計算
N_RECORDS = df.shape[0]
list_records = [i for i in range(N_RECORDS)]
# インデックスの組み合わせを計算する
combinations = list(itertools.combinations(list_records, 2))
count = 0
list_anchor_text = []
list_positive_text = []
list_anchor_cat = []
list_positive_cat = []
while count < N_pair:
selected = random.choice(combinations)
# ②その組み合わせの中からランダムに1ペアを選ぶ
# インデックスの候補を取得
anchor_candidate_ind = selected[0]
positive_candidate_ind = selected[1]
# カテゴリの候補を取得
anchor_candidate_cat = df['category'][anchor_candidate_ind]
positive_candidate_cat = df['category'][positive_candidate_ind]
# 同じカテゴリかどうかを判定するための変数 if文で使う。
is_same_cat = anchor_candidate_cat==positive_candidate_cat
# ③選んだ組み合わせが同じカテゴリであれば、anchor-positiveとして登録する
# 同じカテゴリの場合の処理
if is_same_cat:
# それぞれのテキストとカテゴリをanchor-positiveとして登録する
anchor_text = df['text'][anchor_candidate_ind]
positive_text = df['text'][positive_candidate_ind]
list_anchor_text.append( anchor_text)
list_positive_text.append(positive_text)
list_anchor_cat.append( anchor_candidate_cat)
list_positive_cat.append( positive_candidate_cat)
# ④anchor-positiveとして登録された組み合わせは、組み合わせリストから削除する
# 過去に抽出したインデックスを再度選ばないようにcombinationsリストを一部削除
combinations.remove(selected)
# countを増やす
count += 1
# DataFrameで格納する。
dict_df_anc_pos = {
"anchor" : list_anchor_text,
"positive" : list_positive_text,
"anchor_cat" : list_anchor_cat,
"positive_cat" : list_positive_cat
}
df_anc_pos = pd.DataFrame(dict_df_anc_pos)
df_anc_pos
anchor-negativeペアの作成
- ①7,367記事の中から1つ記事を選ぶ
- ②anchor-positiveと別のカテゴリの場合は、anchor-negativeとして登録する
count = 0
list_negative_text = []
list_negative_cat = []
while count < N_pair:
# ネガティブのインデックス候補を取得
negative_candidate_ind = random.sample(range(0, N_RECORDS), 1)[0]
# ネガティブカテゴリの候補を取得
negative_candidate_cat = df['category'][negative_candidate_ind]
# anchorのカテゴリを取得
positive_cat = df_anc_pos['anchor_cat'][count]
# anchorカテゴリとネガティブカテゴリが違うかどうかを判定する変数を定義
is_different = negative_candidate_cat != positive_cat
# anchorと異なるカテゴリの場合にanchor-negativeとして登録する
if is_different:
negative_text = df['text'][negative_candidate_ind]
list_negative_text.append(negative_text)
list_negative_cat.append( negative_candidate_cat)
# countを増やす
count += 1
# DataFrameで格納する。
dict_df_neg = {
"negative" : list_negative_text,
"negative_cat" : list_negative_cat,
}
df_neg = pd.DataFrame(dict_df_neg)
df_neg
anchor-positive-negativeを結合しTripletデータを保存する
df_data = pd.concat([df_anc_pos, df_neg], axis=1)
df_data = df_data[["anchor", "positive", "negative", "anchor_cat", "negative_cat"]]
# 保存する
df_data.to_csv("/largest_data.csv", encoding='utf-8-sig', index=False)
ファインチューニングする
ライブラリのインストール
sentence transformersをインストールすると自動的にtransformersなどもインストールされる。
※Google Colabを使っている場合は、ランタイムをGPUに変更しておくと良いです。
# sentence transformersを入れると、huggingfaceも自動で入るので、ラク
!pip install -U sentence-transformers
# 東北大のbertを使うには、fugachiが必要
!pip install fugashi
# 東北大のbertを使うには、unidic-liteが必要
!pip install fugashi[unidic-lite]
# matplotlibの日本語文字化け対策
!pip install japanize-matplotlib
ファインチューニングや、結果の可視化などで使うためのライブラリをインポートします。
# ライブラリのインポート
# transformer系のライブラリ
import transformers
transformers.BertTokenizer = transformers.BertJapaneseTokenizer
# sentence transformersライブラリ
import sentence_transformers
from sentence_transformers import SentenceTransformer
from sentence_transformers import models
from sentence_transformers.losses import TripletDistanceMetric, TripletLoss
from sentence_transformers.readers import TripletReader
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.datasets import SentencesDataset
from sentence_transformers import InputExample, losses
from torch.utils.data import DataLoader
# その他ライブラリ
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import japanize_matplotlib #日本語化matplotlib
import seaborn as sns
import warnings
warnings.simplefilter('ignore')
モデルの定義
今回は、東北大がhugging faceで提供しているBERTモデルを使用する。
# モデル名を定義
# 今回は、東北大がhugging faceで提供しているBERTモデルを使用する
MODEL_NAME = 'cl-tohoku/bert-base-japanese-v3'
# BERTモデルのロード
bert = models.Transformer(MODEL_NAME)
# プーリング層の定義
pooling = models.Pooling(
bert.get_word_embedding_dimension(),
pooling_mode_mean_tokens=True, # 埋め込みは mean で実施する
)
# BERTモデルとプーリング層を使って、Sentence BERTまたはSentence Transformersモデルを定義
model = SentenceTransformer(modules=[bert, pooling])
# 後で次の処理をしたいので、一旦上で定義したほぼ生の状態のsentence transformersモデルを保存しておく
# <処理>
# ファインチューニング前と後のモデルとで埋め込み精度の比較
with open(DIR + 'model_before_fit.pickle', 'wb') as p:
pickle.dump(model, p)
学習用とテスト用にデータを分割する
- 学習:8.5割
- テスト:1.5割
で分割する。
※ Google Colabの無料版だと、10,000ペアのデータを使うと、時間がかかりすぎてしまうので、今回は、10,000ペアのうち、5,000ペアを学習、テストで使うことにする。
Sentence transformersを使うときはデータをInputExample
を使ってラッピングする。
# 学習用、検証用に分割する
df_ = df_data.head(5000)
df_train, df_eval = train_test_split(df_, test_size=0.15, random_state=42)
# 学習用データのセッティング
train_dataset = SentencesDataset([InputExample(texts=[row["anchor"], row["positive"], row["negative"]]) for index,row in df_train.iterrows()]
,model)
ハイパーパラメータの調整
無料枠のGPUのメモリでは、バッチサイズを8以上にするとメモリ不足になってしまうので、↓の設定にすることとしました。
# ハイパーパラメータの定義
# バッチサイズやウォームアップは Sentence BERT の論文の “3.1 Training Details” の記述に従いたいが、
# GPUのメモリを使い切ってしまう問題などが出たため、論文の設定よりもかなり小さくしている点に注意。
BATCH_SIZE = 4
NUM_EPOCH = 10
EVAL_STEPS = 100
WARMUP_STEPS = int(len(train_dataset) // BATCH_SIZE * 0.1)
ファインチューニングの各種設定
evaluatorやDataloader、損失関数を定義する
evaluator = TripletEvaluator(df_eval.anchor.values.tolist()
,df_eval.positive.values.tolist()
,df_eval.negative.values.tolist()
)
# Dataloaderの定義
train_dataloader = DataLoader(
train_dataset,
shuffle = False,
batch_size = BATCH_SIZE,
)
# 損失関数の定義(ユークリッド距離で定義)
train_loss = TripletLoss(
model = model,
distance_metric = TripletDistanceMetric.EUCLIDEAN,
triplet_margin = 1,
)
学習の実行
model.fit(
train_objectives = [(train_dataloader, train_loss)],
epochs = NUM_EPOCH,
evaluator = evaluator,
evaluation_steps = EVAL_STEPS,
warmup_steps = WARMUP_STEPS,
output_path = DIR +"/sbert",
checkpoint_path = DIR + "/check_points",
checkpoint_save_steps = 100,
checkpoint_save_total_limit = 999_999
)
# 学習済みモデルの保存
with open(DIR + 'model_after_fit.pickle', 'wb') as p:
pickle.dump(model, p)
学習結果を評価する
ファインチューニング前後でモデルが学習をうまくできたかどうかを確認します。
ファインチューニング前後のモデルをロードする
# ファインチューニング前後のモデルをロードする
# 前
with open(DIR + 'model_before_fit.pickle', 'rb') as p:
model_before_fit = pickle.load(p)
# 後
with open(DIR + 'model_after_fit.pickle', 'rb') as p:
model_after_fit = pickle.load(p)
評価する
- 2つの指標を使って評価する
- 『accuracy』
- 『diff』
accuracy
positiveとnegativeとで類似度を比較したときに、類似度が高い方をpositiveと判定する2値分類のように考えると、positiveの方が類似度の高い時(=diff > 0の時)、positiveを正しくpositiveと判定していることになるので、その回数をカウントアップするような処理をしている。
いわゆる「accuracy」に近い考え方。
diff
diffは大きくなればなるほど良い指標。
そもそもとして、理想は、posとの類似度が高く、negとの類似度が低いこと。
ここで定義している「diff」はposとの類似度が高いかつnegとの類似度が低い時にdiffがより大きくなるように定義している。
したがって、diffはで書ければでかいほど良いという評価になる。
# 評価のための関数の定義
# accuracyとdiffを計算する
def evaluate(test_df, model):
"""
評価指標を計算するための関数
:param : test_df : DataFrame. anchor(str) | positive(str) | negative(str) のカラム構成
:model : sentence transformer
:return : dict. {"accuracy" : accuracy score, "diff", diff score}
"""
correct = 0
avg_diff = 0
# 各レコードのベクトル化と評価指標の計算を実行する
for index, row in df_test.iterrows():
text_anchor = row.anchor
text_positive = row.positive
text_negative = row.negative
# ベクトル化(anchor - positive)
vec_anchor = model.encode(text_anchor)
vec_positive = model.encode(text_positive)
vec_negative = model.encode(text_negative)
# cos類似度計算(anchor - positive、anchor - negative)
score_pos = cosine_similarity([vec_anchor], [vec_positive])[0][0]
score_neg = cosine_similarity([vec_anchor], [vec_negative])[0][0]
# 類似度の差分の計算
diff = score_pos - score_neg
# 類似度差分の足し算
avg_diff += diff
# 類似度の高い文章ペアをちゃんと区別できている件数のカウント
if diff > 0:
correct += 1
accuracy_score = correct / len(test_df)
diff_score = avg_diff / len(test_df)
return {"accuracy" : accuracy_score,
"diff" : diff_score}
# ファインチューニング 前 のスコアを計算する
score_before = evaluate(test_df = df_test
,model = model_before_fit)
# ファインチューニング 後 のスコアを計算する
score_after = evaluate(test_df = df_test
,model = model_after_fit)
diffとaccuracyを可視化する
data1 = [score_before["accuracy"], score_after["accuracy"]]
data2 = [score_before["diff"] , score_after["diff"]]
labels = ['before', 'after'] # ラベル
f, ax1 = plt.subplots()
ax2 = ax1.twinx()
# 棒グラフの幅
bar_width = 0.35
# X軸の位置
r1 = range(len(data1))
r2 = [x + bar_width for x in r1]
# 棒グラフの描画
ax1.bar(r1, data1, color='blue', width=bar_width, edgecolor='black', label='accuracy')
ax2.bar(r2, data2, color='orange', width=bar_width, edgecolor='black', label='diff')
# X軸の目盛りとラベル
ax1.set_xlabel('Data')
# Y軸の目盛りとラベル
ax1.set_ylabel("Accuracy")
ax2.set_ylabel("diff")
ax1.set_ylim([0.0, 1.0])
ax2.set_ylim([0.0, 0.5])
# 凡例の表示
h1, l1 = ax1.get_legend_handles_labels()
h2, l2 = ax2.get_legend_handles_labels()
ax1.legend(h1+h2, l1+l2, bbox_to_anchor=(1.15, 1), loc='upper left', borderaxespad=0.3, fontsize=9)
# グラフのタイトル
plt.title('compare fine tuning results with raw model')
# グラフの表示
plt.show()
結果はこちら。
左の青とオレンジの棒グラフがチューニング前で、右の青とオレンジの棒グラフがチューニング後の結果です。
accuracyは今回の定義ではもともと75%位と悪くない数値でしたが、チューニング後はそれがグンと上昇して、95%くらいになっているのが見て取れます。言い換えると、同じカテゴリの記事のベクトル同士を近づけることができていて、異なるカテゴリの記事を離すことに成功していると言えます。
また、オレンジの棒グラフで比較すると、チューニング前は0.05くらいだったものが、チューニング後は0.38くらいにまで向上しています。言い換えると、チューニング後は、同じカテゴリ同士の記事の距離と異なる記事同士の距離の差分を広げることができていることと言えます。
t-SNEでファインチューニング前後でモデルが学習できているかどうかを確認する
# textをベクトル化
def get_vector_and_tsne(df, model, perplexity = 5):
"""
テキストデータをベクトル化して、それをt-SNEで2次元に圧縮するための関数
:df : DataFrame. text(str) カラムを持つDataFrame
:model : sentence transformers
:perplexity : perplexity. 一般的には、5~50に設定する
:return : DataFrame. テキストをベクトル化し、x1 | x2 にそれぞれ格納したDataFrame
"""
from sklearn.manifold import TSNE
# t-SNEを定義
tsne = TSNE(n_components=2, random_state = 0, perplexity = perplexity)
# 2次元に変換
X_embedded = tsne.fit_transform(model.encode(df.text.values.tolist()))
# DataFrameに格納
df_x_embedded = pd.DataFrame(X_embedded, columns = ["x1", "x2"])
# オリジナルのDataFrameに結合
df_result = pd.concat([df, df_x_embedded], axis=1)
return df_result
散布図で可視化する際に、必要なカラムだけに絞り込む。
# categoryカラムとtextカラムだけにする
df_vis = df[["category", "text"]]
# ファインチューニング 前
# t-SNEの実行
df_after_tsne = get_vector_and_tsne(df_vis, model_before_fit, perplexity = 50.0)
list_category = df_after_tsne.category.unique()
colors = ["r", "g", "b", "c", "m", "y", "k", "orange","pink"]
plt.figure(figsize = (10, 10))
for i , cat in enumerate(list_category):
tmp_df = df_after_tsne[df_after_tsne.category == cat]
plt.scatter(tmp_df['x1'],
tmp_df['x2'],
color = colors[i],
alpha=0.5,
label = cat)
plt.legend()
結果はこちら。
いろんなカテゴリの記事同士が重なり合ってしまっていて、分離できていないことがわかります。
# ファインチューニング 後
# t-SNEの実行
df_after_tsne = get_vector_and_tsne(df_vis, model_after_fit, perplexity = 50.0)
list_category = df_after_tsne.category.unique()
colors = ["r", "g", "b", "c", "m", "y", "k", "orange","pink"]
plt.figure(figsize = (10, 10))
for i , cat in enumerate(list_category):
tmp_df = df_after_tsne[df_after_tsne.category == cat]
plt.scatter(tmp_df['x1'],
tmp_df['x2'],
color = colors[i],
alpha=0.5,
label = cat)
plt.legend()
結果はこちら。
まだ、記事同士に重複がありますが、チューニング前の画像と比較すると分離できているように見えます。
より分離するためには
- 学習データをもっと増やす。今回は学習時間の関係上、Tripletデータを5,000ペアしか使わなかったですが、本来はもっとたくさんのデータペアで学習することができれば精度を上げることができます。
- 損失関数を変更する。Triplet学習にはいくつかのバリエーションがあります。
まとめ
今回は、setence transformersをファインチューニングして、livedoorニュースに特化させてテキスト埋め込みさせるモデルを作りました。
個人的には、古典的な機械学習しか勉強していなかったので、メトリック学習という手法が斬新でした。古典的なラベリングでは、今回のようなタスクを実施する場合は、マルチラベルの分類問題として解くことになると思うのですが、次のような問題にぶつかります。
学習データに存在しないラベルを予測することができない。しかし、今回のsentence transformerで実装したようにメトリック学習であれば、学習データに無いデータに対しても似ているか似ていないかを推測することができます。これが古典的な手法との大きな違いで大きなメリットだと感じました。 今回の場合、もともとのlivedoorニュースにないカテゴリのテキストデータを投入したとしても、既存記事とどの程度類似しているかを計算することができます。