LoginSignup
1
1

【Azure】E5 から得た埋め込みを Cognitive Search で検索できるようにする

Last updated at Posted at 2023-11-20

Azure Cognitive Search を用いて、手元にあるテキストの埋め込み表現(Embeddings)を使った類似文章検索の機能を実装する機会がありましたので、その方法についてざっくりと紹介します。

必要なライブラリ

  • PyTorch
  • Requests
  • tqdm
  • Transformers

1. E5 から埋め込みを得る

今回は埋め込み表現を Multilingual-E5-large から取得します。

E5(EmbEddings from bidirEctional Encoder rEpresentations)は、文章をその意味や文脈に応じたベクトル(埋め込み表現)に変換するためのモデルです。E5を発表した論文では、様々なタスク・データセットにおいて、既存の文章埋め込みモデルと比べ E5 が良好な性能を示すことが報告されています。

Multilingual-E5-large は、多言語に対応した XLM-RoBERTa-large をベースに、多言語の類似文章ペアのデータセットを加えて事前学習・fine-tuning が行なわれた E5 の学習済みモデルです。日本語にも対応しているので、今回はこのモデルを使って日本語テキストの埋め込み表現を取得します。

今回は、埋め込み表現を取得する元のテキストとして、e-Gov 法令検索からXML形式で配布されている法令データを前処理したテキストを使い、「第一条 ~~」「第二条 ~~」といった条項単位で条項文の埋め込み表現を取得します。
前処理では、条項ごとに改行されたテキストファイル(sentences.txt)を用意しました。このとき、E5の入力トークン長の上限(512トークン)を超える条項文については、直前直後のウィンドウと一定の重複範囲を持たせたスライディングウィンドウにより、行を分割しています。

E5 は学習段階において、文章のペアにそれぞれ “query: “ と “passage: “ という接頭辞を付与した上で学習をしています。推論に際してもこれらの接頭辞を付与することが推奨されているので、それぞれの条項文に対して “passage: “ の接頭辞を付与した上でデータセットを構築します。

import os

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
import tqdm
from transformers import AutoTokenizer, AutoModel

os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # 環境に応じて変更

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-large")
model = AutoModel.from_pretrained("intfloat/multilingual-e5-large").to(device).eval()

class SentencesDataset(Dataset):
    """
    Datasetには生のテキストを持たせる(この時点ではtokenizeしない)
    E5用に"passage: "を接頭辞として加える
    """
    def __init__(self, sentences: list[str]):
        self.sentences = [f"passage: {sentence}" for sentence in sentences]

    def __getitem__(self, index):
        return self.sentences[index]

    def __len__(self):
        return len(self.sentences)

class SentencesCollator():
    """
    DataLoaderが呼ばれた際にDatasetの生テキストをtokenizeする
    """
    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, sentences):
        encoding = self.tokenizer(
            sentences,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return encoding

# 改行区切りのテキストデータを読み込み、DataLoaderを定義する
with open("sentences.txt", "r", encoding="utf-8") as f:
    sentences = tuple(map(lambda text: text.rstrip("\n"), f.readlines()))

dataset = SentencesDataset(sentences)
sentences_collator = SentencesCollator(tokenizer)
dataloader = DataLoader(dataset, collate_fn=sentences_collator, batch_size=4, shuffle=False, num_workers=1)

# 読み込んだテキスト
# 出典:e-Govポータル (https://www.e-gov.go.jp)を基に作成
for sentence in dataset[:3]:
    print(sentence)
# [Output]
# passage: 日本学術会議法 第一条 この法律により日本学術会議を設立し、この法律を日本学術会議法と称する。日本学術会議は、内閣総理大臣の所轄とする。日本学術会議に関する経費は、国庫の負担とする。
# passage: 日本学術会議法 第二条 日本学術会議は、わが国の科学者の内外に対する代表機関として、科学の向上発達を図り、行政、産業及び国民生活に科学を反映浸透させることを目的とする。
# passage: 日本学術会議法 第三条 日本学術会議は、独立して左の職務を行う。一 科学に関する重要事項を審議し、その実現を図ること。二 科学に関する研究の連絡を図り、その能率を向上させること。

# DataLoaderが正しく定義できていそうか確認
batch = next(iter(dataloader))
print(batch.keys())
print(batch['input_ids'].shape)
print(batch['attention_mask'].shape)
# [Output]
# dict_keys(['input_ids', 'attention_mask'])
# torch.Size([4, 114])
# torch.Size([4, 114])

良い感じに DataLoader を定義できたので、E5 に入力して最終層のアベレージプーリングを計算し、埋め込み表現を取得します。

def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

sentence_embedding_pairs = []
for index, batch in enumerate(tqdm.tqdm(dataloader)):
    inputs = sentences[index * dataloader.batch_size : (index + 1) * dataloader.batch_size]
    outputs = model(**batch.to(device))
    embeddings = average_pool(outputs.last_hidden_state, batch["attention_mask"]).detach().to("cpu")
    embeddings = F.normalize(embeddings, p=2, dim=1) 
    # テキストと埋め込みの対応関係を保存する
    sentence_embedding_pairs += list(map(lambda values: {"sentence": values[0], "embedding": values[1]}, zip(inputs, embeddings)))

assert len(sentence_embedding_pairs) == len(sentences)

これで、sentence_embedding_pairs リストの中には、sentence と embedding の2つのkeyをもつ辞書が、条項文の数だけ格納されました。

2. Azure Cognitive Search の操作

ここからは、今回の主役である Cognitive Search を操作します。

Search サービスの作成とインデックスの定義

ポータルで Search サービスを作成する - Azure Cognitive Search | Microsoft Learn を参照しながら Azure Portal を操作します。無料の Free Tier でも類似ベクトル検索を行うことは可能なので、とりあえずのお試しであれば Free Tier で大丈夫です。ただし、ストレージサイズが 50MB しかないので、アップロードするデータサイズに応じて Basic 以上の価格レベルも検討しましょう。

次に、立ち上げた Search サービスに検索インデックスを登録します。
Azure Portal からも登録可能ですが、REST API から操作した方がコードを使いまわしたりできて楽なので、今回は API を介して操作します。

今回は、文章の埋め込み表現を格納する SentenceEmbedding 、埋め込みの基となった文章を格納する Sentence、レコードを一意に特定するためのキーとなる Id の3種類のフィールドを定義しました。
なお、フィールドの定義やリクエストパラメタの設定は、クイックスタート: REST API を使用して検索インデックスを作成する - Azure Cognitive Search | Microsoft Learn や、GitHub - Azure-Samples/azure-search-postman-samples を参考に作成しました。必要に応じて、これらを参照しつつインデックスの定義を変更してください。

import json
import os
import requests

index_name = "<Searchサービスに設定するインデックスの名称。文字種は小文字アルファベット、数字、ダッシュ(-)が利用可能、ただし最初と最後の文字は必ずアルファベット>"
search_service_name = "<Azure Portal上で設定したSeachサービスの名称>"
search_api_version = "2023-11-01"
search_api_key = "<SearchサービスのAPIキー>"

# パラメタの設定
data = {
    "name": index_name,  # index の名称
    "fields": [
        {
            "name": "Id",  # フィールドの名称
            "type": "Edm.String",  # データ型
            "searchable": "false",
            "filterable": "true",
            "retrievable": "true",
            "sortable": "false",
            "facetable": "false",
            "key": "true"  # keyとなるフィールドは必ず1つ含む必要がある
        },
        {
            "name": "Sentence",
            "type": "Edm.String",
            "searchable": "true",
            "filterable": "false",
            "retrievable": "true",
            "sortable": "true",
            "facetable": "false"
        },
        {
            "name": "SentenceEmbedding",
            "type": "Collection(Edm.Single)",  # vector 用のデータ型
            "searchable": "true",
            "retrievable": "true",
            "dimensions": 1024,  # 埋め込み表現の次元数を指定する
            "vectorSearchProfile": "my-vector-profile"  # 以下で定義するベクトル検索の設定
        }
    ],
    "vectorSearch": {
        "algorithms": [
            {
                "name": "my-hnsw-vector-config-1",
                "kind": "hnsw",
                "hnswParameters":
                {
                    "m": 4,
                    "efConstruction": 400,
                    "efSearch": 500,
                    "metric": "cosine"
                }
            }
        ],
        "profiles": [
            {
                "name": "my-vector-profile",
                "algorithm": "my-hnsw-vector-config-1"
            }
      ]
    }
}

# putリクエストでindexを作成
endpoint = f"https://{search_service_name}.search.windows.net/indexes/{index_name}"
headers = {
    "api-key": search_api_key,
    "Content-Type": "application/json"
}
params = {
    "api-version": search_api_version
}

response = requests.put(endpoint, headers=headers, params=params, data=json.dumps(data))
response.raise_for_status()
print(response.status_code)
# [Output]
# 201

埋め込みのアップロード

Search サービスでインデックスの作成ができたら、E5 で取得した埋め込みとそのテキストのペアを格納します。
先に作成した sentence_embedding_pairs の中身を REST API を介して Search サービスに登録していきます。一度にたくさんのデータを登録しようとするとリクエストに失敗するので、今回は2500件ごとにデータを分割して Search サービスへアップロードしました。

import json
import requests
import tqdm

index_name = "<Searchサービスに設定したインデックスの名称>"
search_service_name = "<Azure Portal上で設定したSeachサービスの名称>"
search_api_version = "2023-11-01"
search_api_key = "<SearchサービスのAPIキー>"

endpoint = f"https://{search_service_name}.search.windows.net/indexes/{index_name}/docs/index"
headers = {
    "api-key": search_api_key,
    "Content-Type": "application/json"
}
params = {
    "api-version": search_api_version
}

chunk_size = 2500
for loop_iter, index in enumerate(tqdm.tqdm(range(0, len(sentence_embedding_pairs), chunk_size))):
    # split data into chunks
    data = json.dumps({
        "value": [
            {
                "Id": str(index + chunk_size * loop_iter),
                "Sentence": pair["sentence"],
                "SentenceEmbedding": pair["embedding"].tolist(),
                "@search.action": "upload"
            }
            for index, pair in enumerate(sentence_embedding_pairs[index:index+chunk_size])
        ]
    })
    response = requests.post(endpoint, headers=headers, params=params, data=data)
    response.raise_for_status()
    print(response.status_code)
    # [Output]
    # 200

登録した埋め込みを使って検索する

これでベクトル検索の準備は完了です。さっそく類似文章の検索をやってみます。
クエリの埋め込み表現を Multilingual-E5-large で取得し、REST API を介して Search サービス上で類似ベクトルを検索することで、クエリと似た条項文を取得します。

import json
import requests
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

query = "日本学術会議の第三部はどのような会員によって成り立っていますか?"

def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

input_texts = [f"query: {query}"]

tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large')
model = AutoModel.from_pretrained('intfloat/multilingual-e5-large').eval()

# Tokenize the input texts
batch_dict = tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt')

outputs = model(**batch_dict)
embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask']).detach().to("cpu")
embeddings = F.normalize(embeddings, p=2, dim=1).squeeze()

assert embeddings.shape[0] == 1024

# postリクエストで検索リクエスト
data = {
    "count": "true",
    "select": "Id, Sentence",
    "vectorQueries": [
        {
            "vector": embeddings.tolist(),
            "k": 5,  # 結果の数
            "fields": "SentenceEmbedding",
            "kind": "vector",
            "exhaustive": "true"
        }
    ]
}

index_name = "<Searchサービスに設定したインデックスの名称>"
search_service_name = "<Azure Portal上で設定したSearchサービスの名称>"
search_api_version = "2023-11-01"
search_api_key = "<SearchサービスのAPIキー>"
endpoint = f"https://{search_service_name}.search.windows.net/indexes/{index_name}/docs/search"
headers = {
    "api-key": search_api_key,
    "Content-Type": "application/json"
}
params = {
    "api-version": search_api_version
}

response = requests.post(endpoint, headers=headers, params=params, data=json.dumps(data))
response.raise_for_status()
print(response.status_code)
# [Output]
# 200

# 検索結果を表示
for result in json.loads(response.text)["value"]:
    print(f"(Score={result['@search.score']:.3f}) {result['Sentence']}")
# [Output] 出典:e-Govポータル (https://www.e-gov.go.jp)
# (Score=0.885) 日本学術会議法 第十一条 第一部は、人文科学を中心とする科学の分野において優れた研究又は業績がある会員をもつて組織し、前章の規定による日本学術会議の職務及び権限のうち当該分野に関する事項をつかさどる。第二部は、生命科学を中心とする科学の分野において優れた研究又は業績がある会員をもつて組織し、前章の規定による日本学術会議の職務及び権限のうち当該分野に関する事項をつかさどる。第三部は、理学及び工学を中心とする科学の分野において優れた研究又は業績がある会員をもつて組織し、前章の規定による日本学術会議の職務及び権限のうち当該分野に関する事項をつかさどる。会員は、前条に掲げる部のいずれかに属するものとする。
# (Score=0.876) 日本学術会議法 第十条 日本学術会議に、次の三部を置く。第一部第二部第三部
# (Score=0.874) 日本学術会議法 第二十三条 日本学術会議の会議は、総会、部会及び連合部会とする。総会は、日本学術会議の最高議決機関とし、年二回会長がこれを招集する。但し、必要があるときは、臨時にこれを招集することができる。部会は、各部に関する事項を審議し、部長がこれを招集する。連合部会は、二以上の部門に関連する事項を審議し、関係する部の部長が、共同してこれを招集する。
# (Score=0.873) 日本学術会議法 第三条 日本学術会議は、独立して左の職務を行う。一 科学に関する重要事項を審議し、その実現を図ること。二 科学に関する研究の連絡を図り、その能率を向上させること。
# (Score=0.871) 日本学術会議法 第七条 日本学術会議は、二百十人の日本学術会議会員(以下「会員」という。)をもつて、これを組織する。会員は、第十七条の規定による推薦に基づいて、内閣総理大臣が任命する。会員の任期は、六年とし、三年ごとに、その半数を任命する。補欠の会員の任期は、前任者の残任期間とする。会員は、再任されることができない。ただし、補欠の会員は、一回に限り再任されることができる。会員は、年齢七十年に達した時に退職する。会員には、別に定める手当を支給する。会員は、国会議員を兼ねることを妨げない。

「日本学術会議の第三部はどのような会員によって成り立っていますか?」という検索クエリに対して、日本学術会議の各部がどのような会員によって成り立つのかを定めた日本学術会議法 第十一条が最も高いスコアで返ってきました。良い感じです。

感想

他にも様々なクエリを用いて類似文章の検索を行ってみましたが、非常に精度よく関連する条項文を抽出できていたような印象です。ローカルで動作し、かつ日本語でこれだけ機能する文章埋め込みモデルがあると、とても重宝しそうです。

また、今回初めて Cognitive Search を使ってみましたが、一度サービスを立ち上げてしまえば REST API で操作でき、データ登録など非常にやりやすかったです。FAISS や Chroma と比べるとあまりベクトルデータベースとしては使われていない印象を持っていましたが、実際に触れてみたところ思っていたよりも手軽に利用できるサービスであることがわかりました。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1