6
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

医療分野での文埋め込みモデルの比較

Posted at

前書き

ChatGPTなどの大規模言語モデル(LLM)では,Hallucinationが課題の一つです.
医療など内容の正確性が求められる分野では特に重要な課題で,LLMに外部データベースから正確な情報を与えた上で,
生成を行うRetrieval augmentation Generation (RAG)が対策方法の一つになります.
RAGでは,関連する情報を正確に検索する必要があり,文章の正確な意味を反映した埋め込み表現を得ることができる文埋め込みモデルが重要です.

そこで,医療分野の日本語の文章に対して,文埋め込みモデルをSemantic Textual Similarity(STS)タスクで比較・検証してみます.

1. 方法

1.1. 検証対象のモデル

とりあえず目に付いた以下の5つのモデルを使います.
OpenAI/text-embedding-ada-002以外はすべてHuggingface上にあります.
text-embedding-ada-002は有料ですが,今回の検証レベルでは数円程度しか掛かっていません.
いずれのモデルも,学習には医療分野ではない一般的な文章で学習したものと思われます.
(データセットの一部に医療系の文章が含まれている可能性は当然ありますが)

モデル 詳細
sonoisa/sentence-bert-base-ja-mean-tokens-v2 https://huggingface.co/sonoisa/sentence-bert-base-ja-mean-tokens-v2
cl-nagoya/sup-simcse-ja-large https://huggingface.co/cl-nagoya/sup-simcse-ja-large
pkshatech/GLuCoSE-base-ja https://huggingface.co/pkshatech/GLuCoSE-base-ja
intfloat/multilingual-e5-large https://huggingface.co/intfloat/multilingual-e5-large
OpenAI/text-embedding-ada-002 https://platform.openai.com/docs/guides/embeddings

1.2. 使用するデータセット

日本語の医療分野の文章のペアの類似度を評価したこちらのデータセットを利用します.
sociocom/Japanese-Clinical-STS
症例報告や電子カルテテキストから抽出してきた文章のペアの類似度を医学知識のある人が
類似度を0~5の6段階評価したデータセットです.

文章1 文章2 類似度
保定2年5か月を経過するが咬合は安定している 保定治療開始後2年7ヶ月が経過するが,咬合は安定している 4
口側の腸管は拡張し,暗赤色を呈していたが,壊死には陥っていなかった 標本造影では尾側膵管は嚢胞状に拡張していたが、明らかな乳頭状隆起像は描出されなかった 0
これらの症例を報告するとともに,それぞれの臨床像について文献的考察を加え検討した loss型の内耳性難聴であつた.本症例の報告とともに本症候群の難聴成因について文献的考察を行った 0
また,左心耳内には輝度の低い血栓を疑う構造物を認めた また,左心耳には輝度の低い血栓を疑う構造物を認めた 5
8)診断及び治療方針:診断は、下顎骨の過成長による骨格性反対咬合の症例で上顎前歯部の叢生を伴うものであった 7)診断及び治療方針:下顎骨の過成長による骨格性反対咬合の症例で、上顎前歯部の叢生を伴うものであった 5
約3年前より左足底に皮疹出現し漸次拡大したため当科を受診した 約3年前より左足背に皮疹が出現し漸次拡大したため当科を受診した 3

最後のペアは皮疹が足底か足背のどっちに出たかで文字的にはたった1文字の違いですが,類似度は3の評価になっています.("皮疹出現"と"皮疹が出現"の違いもありますが)
確かに同じ病気でも,起きた場所が違うだけでも臨床的には大違いのこともあるでしょうから納得ではありますが,自然言語処理的にはこの違いを捉えるのは中々難しそうですね.
BM25など単語の出現頻度ベースのやり方ではおそらく厳しそうです.

1.3. 検証方法

それぞれの文埋め込みモデルを用いてJapenese Clinical-STSの各ペアからそれぞれの埋め込み表現を得ます.
それらの埋め込み表現のコサイン類似度を求め,Japanese Clinical-STSの類似度(Gold Similarity)との相関係数(Pearson, Spearman)を求めます.

1.4 検証に利用したコード

ライブラリの読み込み

import gc
import pandas as pd
import torch
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from transformers import BertJapaneseTokenizer, BertModel,AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
from openai import OpenAI

データセットの読み込み

path = 'your_directory_path'
df = pd.read_excel(path + 'dataset_JA_Clinical_STS.xls',names=['Sentence_A','Sentence_B','Similarity'])
df.head()

sonoisa/sentence-bert-base-ja-mean-tokens-v2(クラスはドキュメントそのままです.)

class SentenceBertJapanese: # HuggingFaceのsonoisa/sentence-bert-base-ja-mean-tokens-v2そのまま
    def __init__(self, model_name_or_path, device=None):
        self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path)
        self.model = BertModel.from_pretrained(model_name_or_path)
        self.model.eval()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model.to(device)

    def _mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0] #First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    @torch.no_grad()
    def encode(self, sentences, batch_size=8):
        all_embeddings = []
        iterator = range(0, len(sentences), batch_size)
        for batch_idx in iterator:
            batch = sentences[batch_idx:batch_idx + batch_size]

            encoded_input = self.tokenizer.batch_encode_plus(batch, padding="longest", truncation=True, return_tensors="pt").to(self.device)
            
            model_output = self.model(**encoded_input)
            sentence_embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"]).to('cpu')

            all_embeddings.extend(sentence_embeddings)
            
            del batch, encoded_input, model_output, sentence_embeddings
            gc.collect()  
            torch.cuda.empty_cache()  

        # return torch.stack(all_embeddings).detach().numpy()
        return torch.stack(all_embeddings)

MODEL_NAME = "sonoisa/sentence-bert-base-ja-mean-tokens-v2"
model = SentenceBertJapanese(MODEL_NAME)

# Encode Sentence_A
sentence_a_embeddings = model.encode(df['Sentence_A'].tolist())
# Encode Sentence_B
sentence_b_embeddings = model.encode(df['Sentence_B'].tolist())

# Calculate cosine similarity
cosine_similarities = torch.nn.functional.cosine_similarity(sentence_a_embeddings, sentence_b_embeddings)

# Add cosine similarities as a new column in df
df['Cosine_Similarity_sbert'] = cosine_similarities.tolist()

cl-nagoya/sup-simcse-ja-large

model = SentenceTransformer("cl-nagoya/sup-simcse-ja-large")
sentence_a_embeddings = model.encode(df['Sentence_A'].tolist())
sentence_b_embeddings = model.encode(df['Sentence_B'].tolist())
# ndarrayをTensorに変換
sentence_a_embeddings = torch.from_numpy(sentence_a_embeddings)
sentence_b_embeddings = torch.from_numpy(sentence_b_embeddings)
cosine_similarities = torch.nn.functional.cosine_similarity(sentence_a_embeddings, sentence_b_embeddings)
df['Cosine_Similarity_sup_simcse'] = cosine_similarities.tolist()

pkshatech/GLuCoSE-base-ja

model = SentenceTransformer("pkshatech/GLuCoSE-base-ja")
sentence_a_embeddings = model.encode(df['Sentence_A'].tolist())
sentence_b_embeddings = model.encode(df['Sentence_B'].tolist())
# ndarrayをTensorに変換
sentence_a_embeddings = torch.from_numpy(sentence_a_embeddings)
sentence_b_embeddings = torch.from_numpy(sentence_b_embeddings)
cosine_similarities = torch.nn.functional.cosine_similarity(sentence_a_embeddings, sentence_b_embeddings)
df['Cosine_Similarity_GLuCoSE'] = cosine_similarities.tolist()

intfloat/multilingual-e5-large

class Embeddingmodel:
    def __init__(self, device=None):
        self.tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large')
        self.model = AutoModel.from_pretrained('intfloat/multilingual-e5-large')
        self.model.eval()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model.to(device)

    def _average_pool(self,last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    @torch.no_grad()
    def encode(self, sentences, mode = "query" ):
        # Queryの場合は"query: "を,それ以外の場合は"passage: "を先頭につける必要があるための処理 
        if mode == "query":
            input_texts = ["query: "+sentense for sentense in sentences]
        else:
            input_texts = ["passage: "+sentense for sentense in sentences]
        # Tokenize the input texts
        batch_dict = self.tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt').to(self.device)

        outputs = self.model(**batch_dict)
        embeddings = self._average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

        # normalize embeddings
        embeddings = F.normalize(embeddings, p=2, dim=1) 

        return embeddings

model = Embeddingmodel()
sentence_a_embeddings = model.encode(df['Sentence_A'].tolist(), mode = "query")
sentence_b_embeddings = model.encode(df['Sentence_B'].tolist(), mode = "passage")
cosine_similarities = torch.nn.functional.cosine_similarity(sentence_a_embeddings, sentence_b_embeddings)
df['Cosine_Similarity_e5'] = cosine_similarities.tolist()

OpenAI/text-embedding-ada-002

client = OpenAI(api_key=os.getenv('OPENAI_API_KEY')) #環境変数にAPI keyを仕込んでください

sentence_a_embeddings = []
for sentence_a in tqdm(df['Sentence_A']):
    _emb = client.embeddings.create(input = sentence_a, model="text-embedding-ada-002").data[0].embedding
    sentence_a_embeddings.append(_emb)

sentence_a_embeddings = torch.from_numpy(np.array(sentence_a_embeddings))

sentence_b_embeddings = []
for sentence_b in tqdm(df['Sentence_B']):
    _emb = client.embeddings.create(input = sentence_b, model="text-embedding-ada-002").data[0].embedding
    sentence_b_embeddings.append(_emb)

sentence_b_embeddings = torch.from_numpy(np.array(sentence_b_embeddings))
cosine_similarities = torch.nn.functional.cosine_similarity(sentence_a_embeddings, sentence_b_embeddings)
df['Cosine_Similarity_OpenAI'] = cosine_similarities.tolist()

相関係数

pearson_correlation = df[['Similarity', 'Cosine_Similarity_sbert', 'Cosine_Similarity_sup_simcse', 'Cosine_Similarity_GLuCoSE', 'Cosine_Similarity_e5','Cosine_Similarity_OpenAI','BM25']].corr(method='pearson')
display(pearson_correlation)

spearman_correlation = df[['Similarity', 'Cosine_Similarity_sbert', 'Cosine_Similarity_sup_simcse', 'Cosine_Similarity_GLuCoSE', 'Cosine_Similarity_e5','Cosine_Similarity_OpenAI','BM25']].corr(method='spearman')
display(spearman_correlation)

ヒストグラム

fig, axs = plt.subplots(2, 3, figsize=(15, 10))

axs[0, 0].hist(df['Similarity'])
axs[0, 0].set_xlabel('Similarity')
axs[0, 0].set_title('Gold Similarity')

axs[0, 1].hist(df['Cosine_Similarity_sbert'])
axs[0, 1].set_xlabel('Cosine Similarity')
axs[0, 1].set_title('sonoisa/sentence-bert-base-ja-mean-tokens-v2')
axs[0, 1].set_xlim([0, 1])

axs[0, 2].hist(df['Cosine_Similarity_sup_simcse'])
axs[0, 2].set_xlabel('Cosine Similarity')
axs[0, 2].set_title('cl-nagoya/sup-simcse-ja-large')
axs[0, 2].set_xlim([0, 1])

axs[1, 0].hist(df['Cosine_Similarity_GLuCoSE'])
axs[1, 0].set_xlabel('Cosine Similarity')
axs[1, 0].set_title('pkshatech/GLuCoSE-base-ja')
axs[1, 0].set_xlim([0, 1])

axs[1, 1].hist(df['Cosine_Similarity_e5'])
axs[1, 1].set_xlabel('Cosine Similarity')
axs[1, 1].set_title('intfloat/multilingual-e5-large')
axs[1, 1].set_xlim([0, 1])

axs[1, 2].hist(df['Cosine_Similarity_OpenAI'])
axs[1, 2].set_xlabel('Cosine Similarity')
axs[1, 2].set_title('OpenAI/text-embedding-ada-002')
axs[1, 2].set_xlim([0, 1])

fig.tight_layout()
fig.show()

2. 検証結果

2.1 相関係数

早速ですが相関係数を見てみましょう.

PearsonとSpearmanどちらの相関係数でもtext-embedding-ada-002が最も優れていました.
概ね,text-embedding-ada-002 > GLuCoSE-base-ja, multilingual-e5-large > sup-simcse-ja-large, sentence-bert-base-ja-mean-tokens-v2という結果でした.
正解であるデータセットの類似度が6段階の離散値で評価していること,後述の通りコサイン類似度の分布が正規分布ではないことを考えると,Spearmanの順位相関係数で見る方がより適切と思われます.

Spearmanではtext-embedding-ada-002とGLuCoSE-base-jaがほぼ同等のため,無料で中身も触れるGLuCoSE-base-jaが使いやすそうです.
とはいえ,今回の検証でのtext-embedding-ada-002の費用は数円程度のため,かなりヘビーユースしない限り,
気にするほどでもないかもしれませんが...

モデル Pearson Spearman
sonoisa/sentence-bert-base-ja-mean-tokens-v2 0.756214 0.770135
cl-nagoya/sup-simcse-ja-large 0.740156 0.782137
pkshatech/GLuCoSE-base-ja 0.793014 0.821318
intfloat/multilingual-e5-large 0.797392 0.803562
OpenAI/text-embedding-ada-002 0.817546 0.825844

2.2 コサイン類似度の分布

今回,相関係数を見ることが目的でしたが,念のためコサイン類似度の分布を確認しています.
histgram.png

Japanese Clinical-STSの類似度(Gold Similarity)の分布は類似度2がやや少なく,0がやや多いようですが,概ね一様な分布をしています.

いずれのモデルも類似度1 (完全に一致)に偏っていますが,multilingual-e5-largeとtext-embedding-ada-002は0.8~1の狭い範囲に分布しており,
sentence-bert-base-ja-mean-tokens-v2,sup-simcse-ja-large,GLuCoSE-base-jaは比較的広く分布していることが分かります.

今回検証に利用したデータセットのペアは臨床的には意味が違うものも入っていますが,
一般的な文章から見れば医療分野という意味では似ているので,1に偏っているのは当然かもしれません.
なんとなく類似度の分布が広いモデルのほうが,類似度の評価がメリハリが効いて相関係数も高くなりやすいかな?と思っていたのですが,
実際には分布が狭いmultilingual-e5-large,text-embedding-ada-002が比較的性能がよいというのが不思議な気がします.
医療分野という意味で,大体のベクトルの向きは同じだけど,より細かくも捉えられていたということになるんですかね?

3. あとがき

医療分野の日本語文章でのSemantic Textual Similarity(STS)性能はtext-embedding-ada-002が最も優れており,続いてGLuCoSE-base-jaでした.
一般的な文章のSTSタスクではmultilingual-e5-largeが優れている場合もあるようですが,
ここらへんのモデルはほぼ同等程度で検証に用いるデータ次第で変わるレベルなのかもしれません.

素の性能としてはtext-embedding-ada-002がトップでしたが,モデルの中身を触ってさらに追加の訓練などもできることを考えると,他のモデルの方が性能が良くなる可能性も大いにあります.
とりあえずのお試しであればtext-embedding-ada-002,本格的にチューニングするのであれば他のモデルがいいという感じでしょうか.

今回は特に医療分野に限らない一般のコーパスで学習したモデルを用いましたが,医療分野のコーパスで学習したモデルで性能が上がるのか,余力があれば検証してみたいです.
また,類似度の指標として簡単にコサイン類似度を用いましたが,その他の類似度判定のためのアルゴリズムの検証もしてみたいです.

バグやミスがあればお気軽にコメントください.

最後に検証したモデル・データセットの開発者さまに感謝します!

6
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?