1. 初めに
以前、ローカルでRAGモデルを作成した際に「最近の文章埋め込みモデルはどのような仕組みなのだろう?」という疑問から埋め込みモデルの調査をしていたところ、
- 多言語対応
- いろいろな埋め込み方式について知ることが出来る
- 短文から長文の埋め込みで高性能を出す手法
などを学べるという理由から、M3-Embeddingに関する論文を読んでみました。
また、RAGモデルを作成した際の候補となったというのも大きいです。
論文リンクです。
2. 論文概要
M3-Embedding という新しい文章埋め込みモデルの提案に関する論文です。
M3-Embedding のM3は: Multi-Linguality, Multi-Functionality, Multi-Granularityの3つのMです。それぞれ下記のような意味のようです。
- Multi-Linguality: 多言語(100以上)
- Multi-Functionality: Dense, Sparse, Multi vecの3つの検索方式
- Multi-Granularity: Sentence level, Passage level, Document levelなど様々なトークン長の文章に対応(最大で8,192トークン)
要するにM3-Embeddingは、多言語対応で様々な検索方式を使い分け、短文から長文まで幅広い文章長に対応できる汎用的ないい感じの文章埋め込みモデルということです。
基本的には、この3つを達成するための手法を提案したというのが本論文の新規性と考えるとよいかと考えています。
3. そもそも文章埋め込みモデルとは
そもそも文章埋め込みとは、どのようなものか自分自身も深く理解できていなかったので整理したいと思います。
どのような技術か
文章埋め込みモデルは、文章全体の意味を固定長の数値ベクトルに変換する技術です。
具体的にやりたいこと
「最近は暑くて溶けそうだ」と「ここ1か月は気温が高くてきつい」のような、表現は異なるが意味の似た文章を、数値的に近いベクトルに変換することがやりたいことです。
応用例
- 文章検索: 意味の近い文章を検索する技術です。RAGなどで最近注目を集めているかと思います
- 文章クラスタリング: SNSで自社の口コミ分析などに応用できるかと思います
- 文章分類: ニュース記事などを自動でどの分野かを分類するなどに応用できると思います
4. 文章埋め込みの歴史
今回の論文では、埋め込み手法について理解しておくことが重要です。そのため、埋め込みがどのような発展をしてきたかをまとめてみました。
文章埋め込み技術は、大きく3つのフェーズに分けて発展してきました。
初期:カウントベースの手法
- TF-IDFやBM25などの統計的手法が主流でした
- 単語の出現頻度や文書頻度を基に文章の類似性を計算
- これがM3-EmbeddingでいうところのSparse検索の原型です
中期:単語埋め込みの時代
- Word2VecやFastTextなどの単語レベルの埋め込み手法が登場
- ニューラルネットワークを用いて単語の意味をベクトルで表現
- ただし、単語レベルの埋め込みを文章レベルにどう拡張するかが課題でした
後期(近年):Transformer系の文章埋め込み
- Sentence-BERTのような文章を直接1つのベクトルに埋め込む手法が登場(M3-EmbeddingでいうDense検索)
- ColBERTのように文章内の全トークンのベクトルを保持して類似度を計算する手法(M3-EmbeddingでいうMulti-vector検索)
- これらTransformer系の手法により、文脈を考慮した高精度な文章埋め込みが可能になりました
M3-Embeddingは、これら3つの検索方式(Sparse、Dense、Multi-vector)を統合し、それぞれの長所を活かせるハイブリッドな文章埋め込みモデルです。
さらにまとめると、下記を全て持つモデルとなります。
- Sparse: BM25ライクな検索方法。文章中の単語の出現頻度などを使用
- Dense: Sentence-BERTを使用した手法。アーキテクチャというよりかは学習手法による改善
- Multi-Vector検索: ColBERTライクな手法
5. アーキテクチャと各検索手法の実現方法について
アーキテクチャ全体像
M3 Embeddingのアーキテクチャについて説明します。
M3 EmbeddingはベースモデルにXLM-RoBERTa-largeをRetroMAE実施したモデル使用しています。
RetroMAEの日本語記事は
がありますので、参考になればと思います。
概念図としては、下記となります
トークン化した文章の数だけ隠れ状態ベクトルが出力されます。
Dense検索
結論:アーキテクチャの全体図において、特殊トークンである[CLS]
の出力を使用します。
具体的には、下記の図のH_[CLS]
を文章全体を表すベクトルとして取得する仕組みです。
Sparse(Lexical)検索
Sparse埋め込みは、少し複雑ですが下記の手順で実現します。
-
XLM-RoBERTaからの各トークンの隠れ状態を取得する
-
各トークンの重みを下記のように計算します
$$w_t^q = \text{ReLU}(W_{\text{lex}}^T \cdot H_q[i])$$
ここで:
- $w_t^q$:クエリ内のトークン$t$の重み
- $W_{\text{lex}} \in \mathbb{R}^{d \times 1}$:学習可能な線形変換行列
- $H_q[i] \in \mathbb{R}^d$:$i$番目のトークンの隠れ状態
- $\text{ReLU}$:活性化関数(負の値を0にする)
-
計算後に、同じトークンは重みが最も大きいものに置き換えるという方法でSparse埋め込みを実現します
注意点として、特殊トークンは無視する点が上げられます。
上記を疑似コードで表すと下記のようになります。
import torch
import torch.nn.functional as F
def compute_term_weights(text, model, tokenizer, W_lex):
"""
テキストの各単語の重要度(重み)を計算する
"""
# 1. テキストをトークン化
inputs = tokenizer(text, return_tensors="pt")
# 2. XLM-RoBERTaからの各トークンの隠れ状態を取得する(手順1に該当)
outputs = model(inputs) # [batch_size, seq_len, hidden_size]
# 3. 各トークンの重みを計算(手順2に該当)
term_weights = F.relu(outputs @ W_lex.T).squeeze(-1) # [batch_size, seq_len]
# 4. トークンIDと重みの辞書を作成
weights_dict = {}
for i, token_id in enumerate(inputs["input_ids"][0]):
token_id = token_id.item()
# 特殊トークン([CLS], [SEP], [PAD]など)は無視
if token_id in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]:
continue
weight = term_weights[0, i].item()
# 同じトークンが複数回出現する場合は最大値を保持(手順3に該当)
if token_id in weights_dict:
weights_dict[token_id] = max(weights_dict[token_id], weight)
else:
weights_dict[token_id] = weight
return weights_dict
また、類似度スコアは下記のように計算します
- 文章ペアをそれぞれSparse方式でベクトル化する
- それぞれで共通しているトークンの重みを掛け算して足し合わせる
上記を疑似コードで表すと下記のようになります。
def compute_lexical_score(query, passage, model, tokenizer, W_lex):
"""
クエリと文書のLexical(語彙的)類似度スコアを計算
"""
# クエリと文書それぞれの単語重みを取得
query_weights = compute_term_weights(query, model, tokenizer, W_lex)
passage_weights = compute_term_weights(passage, model, tokenizer, W_lex)
# 共通して出現する単語を見つける
common_tokens = set(query_weights.keys()) & set(passage_weights.keys())
# 共通単語の重みの積の合計がスコア
score = 0.0
for token_id in common_tokens:
score += query_weights[token_id] * passage_weights[token_id]
return score
Multi-vecの実現方法
Multi-vec埋め込みは、Dense埋め込みを拡張した手法で、下記の手順で実現します。
-
XLM-RoBERTaからの全トークンの隠れ状態を取得する
-
各トークンの埋め込みを下記のように計算する
$$E_q = \text{norm}(W_{\text{mul}}^T \cdot H_q)$$
$$E_p = \text{norm}(W_{\text{mul}}^T \cdot H_p)$$ここで:
- $E_q, E_p$:クエリと文書の全トークン埋め込み行列
- $W_{\text{mul}} \in \mathbb{R}^{d \times d}$:学習可能な射影行列
- $H_q, H_p \in \mathbb{R}^{N \times d}, \mathbb{R}^{M \times d}$:クエリと文書の隠れ状態
- $\text{norm}$:正規化関数
-
類似度スコアを下記の式で計算する
$$s_{\text{mul}} = \frac{1}{N} \sum_{i=1}^{N} \max_{j=1}^{M} E_q[i] \cdot E_p[j]^T$$
ここで $N, M$ はクエリと文書の長さを表します。
図で表すと、
のような形です。
また、ColBERTの論文内のFigure 2も参考になると思いますので、掲載しておきます。
注意点として、計算コストが高くメモリ効率の悪い点が上げられます
上記を疑似コードで表すと下記のようになります。
import torch
import torch.nn.functional as F
def compute_multi_vector_embeddings(text, model, tokenizer, W_mul):
"""
テキストの全トークンの埋め込みを計算する
"""
# 1. テキストをトークン化
inputs = tokenizer(text, return_tensors="pt")
# 2. XLM-RoBERTaからの全トークンの隠れ状態を取得する(手順1に該当)
outputs = model(inputs) # [batch_size, seq_len, hidden_size]
hidden_states = outputs.last_hidden_state
# 3. 射影変換と正規化(手順2に該当)
embeddings = torch.matmul(hidden_states, W_mul) # [batch_size, seq_len, d]
embeddings = F.normalize(embeddings, dim=-1) # L2正規化
return embeddings
def compute_multi_vector_score(query, passage, model, tokenizer, W_mul):
"""
クエリと文書のMulti-Vector類似度スコアを計算
"""
# クエリと文書の埋め込みを取得
query_embeds = compute_multi_vector_embeddings(query, model, tokenizer, W_mul)
passage_embeds = compute_multi_vector_embeddings(passage, model, tokenizer, W_mul)
# Late Interaction: 各クエリトークンに対して最も類似する文書トークンを見つける(手順3に該当)
scores = torch.matmul(query_embeds, passage_embeds.transpose(-1, -2)) # [1, N, M]
max_scores = torch.max(scores, dim=-1)[0] # 各クエリトークンの最大スコア [1, N]
# 平均スコアを計算
final_score = torch.mean(max_scores).item()
return final_score
Multi-Vector Retrievalの特徴として、各トークン同士の細かい相互作用を考慮できる点があります。これにより、Dense Retrievalでは捉えられない細かな語彙レベルの関連性を検出できます。
hybrid process
上記の3手法のアンサンブル的な手法としてhybrid processについて説明します。
これは、単純に最終類似度スコアの重み付き和で加算したものとなる。具体的な式は下記です。
$$s_{\text{rank}} = w_{\text{1}} \cdot s_{\text{dense}} + w_{\text{2}} \cdot s_{\text{lex}} + w_{\text{3}} \cdot s_{\text{mul}}$$
ここで:
- $w_{\text{1}}, w_{\text{2}}, w_{\text{3}}$:それぞれの手法に対する重み
- $s_{\text{dense}}$:Dense検索の類似度
- $s_{\text{lex}}$:Sparse検索の類似度
- $s_{\text{mul}}$:Multi-Vector検索の類似度
6. 学習方法について
6.1 埋め込みモデルの学習方法について
通常のクラス分類と埋め込みモデルの入出力(入力とラベル)の比較をしたいと思います。
- クラス分類
- 入力と出力:1対1。例えば、犬の画像とラベル(犬、猫、ウサギ、亀、…)
- 出力:固定の長さ(予測したいクラス数分の出力を事前に用意)
- 目的:入力に対して教師と同じ出力を出すこと
という特徴があるので、入力に対して正解ラベルを学習することが簡単です。
コード風に書くと
inputs = 画像 # 入力
labels = 犬 # 出力
outputs = model(inputs)
loss = loss_fn(outputs, labels) # これを小さくするように学習する。negative cross entropyなどを使う
のようになります。
- 埋め込みモデル
- 入力と出力:1対多
- 出力:毎回変化してしまう
- 目的:入力に対して正解の選択肢を選ぶこと
という特徴があるので、正解をモデルに教えることが難しいです。
そこで、埋め込みモデルは
$$queryに対する正解を1つにして、他を全て不正解として扱う$$
という方法を使用することでクラス分類のような方法で学習を可能にしています。
コード風に書くと
query = "パリの観光地は?" # 入力
positive = "パリにはエッフェル塔やルーブル美術館があります" # これを正解にする。
negatives = [
"東京の人口は1400万人です", # 負例1
"プログラミングの基礎を学ぼう", # 負例2
"明日は雨の予報です" # 負例3
]
query_outputs = model(query)
positive_outpus = model(positive)
negative_outpus = model(negatives)
positive_scores = score_fn(query_outputs, positive_outpus) # スコアを高くしたい
negative_scores = score_fn(query_outputs, negative_outpus) # スコアを低くしたい
loss = loss_fn(positive_scores, negative_scores) # これを小さくするように学習する
のような形となります。
上記のloss_fn
を具体的に式で記載すると下記のようになります(InfoNCE
と呼ばれているそうです。)。
$$\mathcal{L}_{s(\cdot)} = -\log \frac{\exp(s(q, p^*)/\tau)}{\sum_{p∈\text({p
∗,P′})} \exp(s(q, p)/\tau)}$$
ここで:
- $\mathcal{L}_{s(\cdot)}$:損失関数(ex. $\mathcal{L}_{s(dense)}$ の場合はDense検索の損失関数)
- $q$:入力クエリ(ex. パリの観光地は?)
- $p^*$:正例(ex. パリにはエッフェル塔やルーブル美術館があります)
- $P′$:不例(ex. 東京の人口は1400万人です, プログラミングの基礎を学ぼう, 明日は雨の予報です)
- $\tau$:温度係数
これを最小化することで埋め込みモデルの学習を行います。
余談:
-
この仕組みを一般的に対照学習と呼びます
-
下記のpytorchで実装されたcross entropyと比較してみると面白いかもしれないです
- 下記のリンクの図も分かりやすいと思いますので、参考になるかと思います
6.2 学習フロー
本節では、学習のフローについて説明します。
M3 Embeddingの学習フローは大きく分けて以下の2つに分けられます。
- 教師無し学習:このステージでは
Dense埋め込み
のみを学習する - ファインチューニング:このステージで
Dense埋め込み
、Sparse埋め込み
およびmulti-vector埋め込み
の3つを学習する
全体図としては、論文のFigure 2掲載の図が分かりやすいかと思います。
教師無し学習
基本的な仕組み
教師無し学習は、1.2B(12億)のテキストペアデータセットを使用して実施されます。「単なるテキストペアからどのように学習データを作るのか?」という疑問が出ると思います。こちらは、以下のように考えると理解しやすいかもしれないです。
$$\text{前提:queryに対して、1つの正例と複数の負例が欲しい}$$
$$\text{方法:1つのテキストペア(queryと正例)にとって、他のテキストペアは負例となる}$$
たとえば、下記のようなペアがあるとします。
- パリの観光地は? <=> パリにはエッフェル塔やルーブル美術館があります
- 東京の人口は? <=> 東京の人口は1400万人です
- プログラミング初心者におすすめの本は? <=> プログラミングの基礎を学ぼう
- 明日の大阪の天気は? <=> 明日は雨の予報です
このとき、パリの観光地は? <=> パリにはエッフェル塔やルーブル美術館があります
のペアにとって
正例:パリにはエッフェル塔やルーブル美術館があります
負例:その他(東京の人口は1400万人です、プログラミングの基礎を学ぼう、明日は雨の予報です)
という風に見ることが出来るはずです。
この仕組みを利用することにより、バッチ内で「クエリ」に対する「正例」と「負例」を作成することでInfoNCE
を計算します。
データセットのについて
データセットは、
- Wikipedia記事の「タイトル」<=> 「本文」
- 論文の「タイトル <=> abstruct」
- 指示 <=> 応答
のように関連性の高いペアで構成されています。
また、クロスリンガル検索用に下記のような翻訳データセットも使用されています(ccmatrixデータセットより抜粋)。
"nl": "En we moeten elke waarheid vals noemen die niet minstens door een lach vergezeld ging.”",
"en": "And we should call every truth false which was not accompanied by at least one laugh.”"
具体的なデータセットは下記です(Table 8より抜粋)。
Data Source | Language | Size |
---|---|---|
MTP | EN, ZH | 291.1M |
S2ORC, Wikipedia | EN | 48.3M |
xP3, mC4, CC-News | Multi-Lingual | 488.4M |
NLLB, CCMatrix | Cross-Lingual | 391.3M |
CodeSearchNet | Text-Code | 344.1K |
Total | – | 1.2B |
ファインチューニング
基本的な仕組み
ファインチューニングは以下のような設定で実施されます。
$$\text{1つのクエリに対して、正例1件、負例7件}$$
こちらに対して、InfoNCE
を計算することで学習を行います。
この段階でdense埋め込み
, sparse埋め込み
, multi-vector埋め込み
を学習するのですが、「これらを同時に学習してしまうと競合により性能が悪化する可能性」があります。
こちらを防ぐために、self-knowledge distillation
という仕組みが導入されています。
self-knowledge distillationについて
こちらは、簡単に言うと「普通の埋め込み学習時の損失」に加えて「3手法を重みづけしたスコアを正とした損失」を追加することで「各手法のバランスを保つ」という手法となります。
まず、通常の埋め込みの損失について説明します。
普通の埋め込み学習時の損失は、加重和スコアを下記の式で表すと
$$s_{inter}=w_{1} \cdot s_{dense} + w_{2} \cdot s_{lex} + w{3} + s_{mul}$$
下記のように書けます。
$$\mathcal{L} = (\lambda_1 \cdot \mathcal{L_{dense}} + \lambda_2 \cdot \mathcal{L_lex} + \lambda_3 \cdot \mathcal{L_mul} + \mathcal{L_inter}) / 4 \tag{1}$$
ここで:
- $s_{inter}, s_{dense}, s_{lex}, s_{mul}$:加重和スコア, denseスコア, sparseスコア, multi-vectorスコア
- $w_{1}, w_{2}, w_{3}$:各スコアに対する係数
- $\mathcal{L}, \mathcal{L_{dense}}, \mathcal{L_lex}, \mathcal{L_mul}, \mathcal{L_inter}$ :全体損失, dense損失, sparse損失, multi-vector損失および加重スコア損失
- $\lambda_1,\lambda_2,\lambda_3$:各手法に対する損失係数
です。
論文では、下記が設定されています。
- $w_{1}=1.0, w_{2}=0.3, w_{3}=1.0$
- $\lambda_1=1,\lambda_2=0.1,\lambda_3=1$
次に、この損失に$s_{inter}$を正解とした各手法の損失を下記のように定義すると、
$$\mathcal{L_{*}^{'}} = -p(s_{inter}) * \log p(s_{*}))$$
$$\mathcal{L^{'}} = (\lambda_1 \cdot \mathcal{L_{dense}^{'}} + \lambda_2 \cdot \mathcal{L_{lex}^{'}} + \lambda_3 \cdot \mathcal{L_{mul}^{'}}) / 3 \tag{2}$$
のような損失が計算できます。
ここで、$p$はソフトマックス関数です。
式(1)と式(2)を用いると
$$\mathcal{L_{final}} = (\mathcal{L} + \mathcal{L^{'}}) / 2 \tag{3}$$
のように書けます。
式(3)を最終的な損失とすることで、「全ての手法の性能を落とすことなく、バランスの良い学習」を実現することが可能となっています。
データセットのについて
使用データセットは下記です。
Data Source | Language | Size |
---|---|---|
MS MARCO, HotpotQA, NQ, NLI, etc. | EN | 1.1M |
DuReader, T2-Ranking, NLI-zh, etc. | ZH | 386.6K |
MIRACL, Mr.TyDi | Multi-Lingual | 88.9K |
MultiLongDoc | Multi-Lingual | 41.4K |
MultiLongDocは合成データで、下記のような手順で作成されているそうです。
- Wikipedia、Wudao、mC4データセットから長文記事をサンプリング
- サンプリングした記事の段落からGPT3.5により質問を生成
- 質問と記事をテキストペアにする
7. 実験結果
実験の概要を下記にまとめました。
-
下記の4実験を実施
- マルチリンガル検索
- クロスリンガル検索
- マルチリンガルの長文検索
- ablation study
- Self-knowledge distillationの有効性
- マルチステージ学習の有効性
-
M3-Embeddingモデル: 下記の5つの組み合わせを検討
- Dense
- Sparse
- Multi-vec
- Dense + Sparse
- All(Dense + Sparse + Multi-vec)
-
比較モデル
- BM25:Sparse検索
- TF-IDFの改良版。ElasticSearchなどで使用されている
- mDPR:Dense検索
- 2023年登場
- Multilingual Dense Retireval Modelの略
- 詳しくは、こちら
- mContriever:Dense検索
- 2022年登場
- Metaの研究
- 詳しくは、こちら
- $\text{mE5}_{\text{large}}$:Dense検索
- 2022年登場
- Microsoftの研究
- M3-Embeddingと学習方法がよく似ているので参考になるかもしれない
- 詳しくは、こちら
- $\text{E5}_{\text{mistral-7b}}$:Desnse検索
- 2023年登場
- Microsoftの研究
- 事前学習済みの言語モデルであるMistral 7Bを使用
- 合成データのみで学習しているそうなので面白いかも
- jina-embeddings-v2-base-en
- 英語のみ対応
- 長文検索で高い性能なので長文検索評価時のみ使用
- OpenAIのモデル
- Text-Embedding-3-Large
- text-embedding-ada-002(長文検索で高い性能なので長文検索評価時のみ使用)
- BM25:Sparse検索
7.1 マルチリンガル検索
データセットと評価方法
結果
Dense検索に関して
-
ほとんどの言語におけるスコアおよび全体の平均値で提案手法が他のモデルを上回る性能となる
-
$\text{E5}_\text{mistral-7b}$と比較して
-
英悟において同程度の性能(わずかに性能が低い)
-
他の言語でははるかに高い性能となる
-
Sparse検索に関して
- Sparseモデルの一種であるBM25と比較して全ての言語で提案手法の性能が高い
組み合わせ
- Dense + Sparseで性能が向上
- 全て組み合わせた結果が最も高性能となる
この実験からは、下記のような知見が得られると考えています。
- $\text{E5}_\text{mistral-7b}$のようなかなり大きいモデルにも性能面で勝つことが出来る
- => 大きなモデルほど性能が高いということが必ずしも成立するわけではない
- => タスクに応じて適切な学習方法やデータを選択することが大事だということ
7.2 クロスリンガル検索
データセットと評価方法
-
データセット:queryと検索対象で別々のデータを使用
-
タスク:25言語のクエリで検索対象を検索
-
評価指標:Recall@100を使用
結果
基本的には、マルチリンガル検索の結果と同様の下記のような結果が得られました。
-
Denseの時点でかなり高いスコア
-
全て組み合わせると最高スコア
ただし、このベンチマーク特有の結果も得られたようです。
-
他のモデルでは、平均性能はそこそこ高いが一部言語で低いスコアとなる(具体的には下記)。一方で、M3-Embeddingでは安定した性能がでる
- ar: アラビア語
- km: クメール語
- he: ヘブライ語
-
Sparse検索がBM25よりも良い結果である。しかし、他の実験と比較してDense, Multi-vec, Dense + Sparse, ALLとの乖離が激しい
クエリと検索対象文章が異なる言語なので、単語の出てくる頻度を使用するSparse検索の性能が低くなるのではないか?と記載されています。要するに、クエリと検索対象が別言語なので同じ単語が出現しにくいためではないかという意味かと思われます。
この実験からは、下記のような知見が得られると考えています。
- クロスリンガル検索ではSparse検索は性能が高くならない
- => 検索の仕組みを考えると納得できる
- => アルゴリズムの中身を理解して使用することが大切であること
7.3 マルチリンガルの長文検索
データセットと評価方法
-
データセット: 2つのベンチマークを使用
-
MLDR(Multi-lingual Long-Doc Retrieval) => Wikipedia、Wudao、mC4から収集した多言語記事で構成
-
NarrativeQA => 英語のみのデータセット
-
-
タスク: 多言語(MLDR)と英語(NarrativeQA)での文章検索
-
評価指標: nDCG@10を使用
結果
MLDRとNarrativeQAで分けて説明します。
MLDR
以下のような特徴がみられます。
- Sparse方式による検索が単一手法では一番性能が高い
- Denseと比較すると10ポイント以上の性能差
- 次いでMulti-vecとなっている
- 同様なSparse方式であるBM25も高い性能を発揮していることがわかる
- 他の実験と同様に全ての手法を組み合わせると最も性能が高くなる
- アブレーション分析として「ファインチューニング時に長いドキュメントを除去」を実施すると下記のような結果が得られた
- 長いドキュメントを除去すると、元の性能から10ポイント程度低下することが分かる。 => 長いドキュメントを学習させることの重要性の証明
- Denseにおいてほとんどの手法で性能が高い(平均で41.2程度) => 事前学習の段階で長文処理の能力をある程度取得していると考えられる
- MCLS(ドキュメントを256トークンの単位でチャンク分割して
[CLS]
トークンを挿入し、[CLS]
トークンの隠れ状態の平均値をとる方法)を使用すると、平均で45.0のスコアに上昇する
NarrativeQA
MLDRと同様の傾向が得られます。また、Max Lengthが短いモデルではあまり性能がでないようです。
この実験からは、下記のような知見が得られると考えています。
- 長文を扱う必要があるタスクでは、Sparse検索の有効性高いこと
- => タスクによってはBM25のような古典的なアルゴリズムも十分に役に立つ
- => いつでも最新の手法が活躍するわけではなく、適材適所が大事
7.4 ablation study
この実験は、「Self-knowledge distillation」と「多段学習」の有効性に関する実験です。
Self-knowledge distillationの有効性
学習時のSelf-knowledge distillation(skdと略して記載)を有効化した場合と無効化した場合でマルチリンガル検索の実験結果がどのように変化するかを調査した結果です。
結果から、
- skdを入れた場合:Sparseも性能が高くなる
- skdを入れない場合:Sparseの性能だけが極端に悪くなる
ということが分かります。
このことから、「全ての手法の性能を落とすことなく、バランスの良い学習」を実現できていることが分かります。
多段学習の有効性
上記から、
- RetroMAWが最も性能向上に寄与している
- 教師無し学習を実施することでさらなる精度向上を実現している
ということが分かります。
8. 実際に動かしてみる
色々述べてきましたが、やはり動かしてみることが理解の近道だと思いますので、実際に動かしてみましょう
ソースコードは下記です。
まず、必要なライブラリをimportします
import pprint
import numpy as np
import pandas as pd
from FlagEmbedding import BGEM3FlagModel
from transformers import AutoTokenizer
続いて、モデルとtokenizerをロードします(tokenizerについては、元モデルがXML-Roberta-Largeを使用しているそうなので、そちらを使用しています。)
model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-large")
簡単な埋め込みを試してみます。
inputs = '私は機械学習アルゴリズムを勉強している'
embedding = model.encode(
inputs,
return_dense=True,
return_sparse=True,
return_colbert_vecs=True
)
print('=' * 50)
pprint.pprint(embedding)
print('=' * 50)
print(f'dense_vec: {embedding["dense_vecs"].shape}, [CLS]の埋め込みベクトル')
print('=' * 50)
print(f'num token: {len(embedding["lexical_weights"])}')
print(f'lexical_weights: {dict(embedding["lexical_weights"])}, tokenごとの重みが入っている')
for token_id, weight in embedding["lexical_weights"].items():
decoded_str = tokenizer.decode(int(token_id))
print(f'weight of {token_id}({decoded_str}): {weight}')
print('=' * 50)
print(f'colbert_vecs: {embedding["colbert_vecs"].shape}, tokenごとの重みが入っている(恐らく、[CLS]のものも含んでいる?)')
print('=' * 50)
出力は、下記のようになります。
==================================================
{'colbert_vecs': array([[-1.25820469e-02, -1.68223102e-02, -2.11406145e-02, ...,
1.68656670e-02, 2.47999001e-02, 1.58858094e-02],
[-5.84731577e-03, -3.35712358e-02, 2.85127796e-02, ...,
2.68597087e-05, 4.02074121e-02, -1.27925277e-02],
[-3.02599519e-02, -3.30640301e-02, 1.42784305e-02, ...,
-3.01567335e-02, -2.67333519e-02, 3.57820909e-03],
...,
[-3.21515314e-02, 3.26354895e-03, -1.04450397e-02, ...,
-1.02367287e-03, 1.03019569e-02, -4.77643870e-03],
[-1.91142410e-02, -7.63488840e-03, -2.83934474e-02, ...,
2.47496981e-02, 3.30407694e-02, 1.39188105e-02],
[-1.00769214e-02, -1.48883052e-02, -3.07244081e-02, ...,
3.08297183e-02, 2.16413569e-02, 1.54675143e-02]],
shape=(10, 1024), dtype=float32),
'dense_vecs': array([-0.0332 , 0.01227 , -0.03906 , ..., -0.00479 , 0.01143 ,
-0.0001532], shape=(1024,), dtype=float16),
'lexical_weights': defaultdict(<class 'int'>,
{'107528': np.float16(0.1737),
'223867': np.float16(0.1641),
'251': np.float16(0.10046),
'40554': np.float16(0.2423),
'40601': np.float16(0.1482),
'4130': np.float16(0.2083),
'50866': np.float16(0.1576),
'65579': np.float16(0.27),
'66281': np.float16(0.0764)})}
==================================================
dense_vec: (1024,), [CLS]の埋め込みベクトル
==================================================
num token: 9
lexical_weights: {'65579': np.float16(0.27), '50866': np.float16(0.1576), '107528': np.float16(0.1737), '66281': np.float16(0.0764), '40601': np.float16(0.1482), '223867': np.float16(0.1641), '251': np.float16(0.10046), '40554': np.float16(0.2423), '4130': np.float16(0.2083)}, tokenごとの重みが入っている
weight of 65579(私は): 0.27001953125
weight of 50866(機械): 0.1575927734375
weight of 107528(学習): 0.1737060546875
weight of 66281(アル): 0.076416015625
weight of 40601(ゴ): 0.148193359375
weight of 223867(リズム): 0.1640625
weight of 251(を): 0.1004638671875
weight of 40554(勉強): 0.2423095703125
weight of 4130(している): 0.208251953125
==================================================
colbert_vecs: (10, 1024), tokenごとの重みが入っている(恐らく、[CLS]のものも含んでいる?)
=================================================
こちらから、以下のようなことがわかります。
- Dense埋め込み:[CLS]の埋め込みベクトル(1024次元)を使用している
- Sparse埋め込み:トークンごとの重みを保持している(XLM-Robetaのトークナイザーで上手くデコードできているので、読みは正しそうです。)
- multi-vector埋め込み(colbert_vecs):(トークン数, 埋め込み次元)のベクトルとなっている。 => トークンごとの埋め込みベクトルが格納されている
続いて、マルチリンガル検索とクロスリンガル検索を試してみます。
まず、マルチリンガル検索です。
検索を簡易にするためのクラスを定義しておきました。
class SearchWithM3Embedding:
def __init__(
self,
model: BGEM3FlagModel,
weights_for_different_modes: tuple[float]=[1.0, 0.3, 1.0] # スコア計算時のweight
):
self.model = model
self.weights_for_different_modes = weights_for_different_modes
def search(self, query: list[str], targets: list[str], k=10):
"""queryに対してそれぞれの検索手法に対して上位k件の結果をデータフレームにして返す"""
score_dict = self.compute_score(query, targets)
result = {}
for method, score_list in score_dict.items():
sorted_indices = np.argsort(score_list, )[::-1]
result[method] = [targets[sorted_id] for sorted_id in sorted_indices]
return pd.DataFrame(result).iloc[:k]
def compute_score(self, query: list[str], targets: list[str]) -> dict[str, list[float]]:
sentence_pairs = [[i,j] for i in query for j in targets]
score_dict = model.compute_score(
sentence_pairs,
max_passage_length=128, # a smaller max length leads to a lower latency
weights_for_different_modes=self.weights_for_different_modes
)
return score_dict
search_with_m3_embedding = SearchWithM3Embedding(model)
まずは、日本語で検索してみましょう
# 日本語 => 日本語
documents = [
"機械学習には、RandomForestやSVMなど様々なモデルがある",
"CNNやViTは画像をクラス分類する技術です。",
"画像認識には、ResNetが有効",
"yoloは物体検出を行うためのモデルである",
"SwinTransformerは画像分類でかなり性能が高い",
"音声認識には、1DのCNNが有効",
"音声認識には、LSTMが有効",
"データベースは情報を効率的に管理するシステムです。"
]
query = ["画像分類に有効なモデルは?"]
search_with_m3_embedding.search(query, documents)
結果は、
のような形で概ね良好です。
続いて、英語、中国語を試します。基本的に、先ほどのプログラム例を翻訳したものを入力しているので、結果のみを計算します。
英語の結果です。概ねいい感じだと思います。
中国語の結果です。sparseの結果では、「SwinTransformerは画像分類でかなり性能が高い」に該当する結果が低い順位にきています。中国語は分かりませんが、queryと若干違う単語が多いのだと思います。
他にも、「フランス語」、「ロシア語」、「アラビア語」、「クメール語」、「ヘブライ語」、「韓国語」を試しています。興味があればコードをみてみるとよいかもです。
続いて、クロスリンガル検索を実施しました。
# 英語
documents = [
"機械学習には、RandomForestやSVMなど様々なモデルがある",
"CNNやViTは画像をクラス分類する技術です。",
"画像認識には、ResNetが有効",
"yoloは物体検出を行うためのモデルである",
"SwinTransformerは画像分類でかなり性能が高い",
"音声認識には、1DのCNNが有効",
"音声認識には、LSTMが有効",
"データベースは情報を効率的に管理するシステムです。"
]
query = ["What models are effective for image classification?"]
search_with_m3_embedding.search(query, documents)
上記のように、検索対象:日本語、クエリ:他言語 という構成となります。
同じように「英語」と「中国語」の結果を示します。
英語で検索した結果です。
sparse以外は良好です。sparse検索の性能が低い理由は、論文でも述べられていた通りで「共通する単語がクエリと検索対象に存在しないため」と考えられます。
中国語で検索した結果です。概ね良い結果で、sparseについては英語と同様に性能が低いです。
他にも、「フランス語」、「ロシア語」、「アラビア語」、「クメール語」、「ヘブライ語」、「韓国語」を試しています。興味があればコードをみてみるとよいかもです。
総じて言えるのは、「クロスリンガル検索では、sparse検索の性能が低くなる」という論文で示唆されていたことが起きているということです。
9. 感想
今回はM3 Embeddingの論文を読んでみました。
感想としては、
- 埋め込みモデルの分野において、BERT系は以前として強い
- タスクによっては、BM25などの非深層学習系のアルゴリズムがかなり高性能になることがある
- タスクに応じて適切なアルゴリズムを選ぶことが重要
- アルゴリズムの中身を理解して技術選定をすることが重要
- 最近は、アーキテクチャというよりかは学習方法やデータの工夫で性能を上げる論文が多い印象を受けた
- 一口に検索といっても手法が複数あるので、その手法について知ることが出来てよかった
- 概要としてしか理解していなかった対照学習についても学べた
- 実際に動かしてみることで多言語対応の面白さを感じた
などがあります。
総じて、得るものが多い論文だと感じました。