OpenSearch には k-NN インデックスをサポートしております。
また、k-NN インデックスに対してコサイン類似度を計算することができます。
そのため、下記のブログのように画像とテキストの Embedding を index に保存して、類似画像を検索することができます。
そこで、ここでは類似画像の検索を行う手順を記録しておきます。
前提
- Amazon OpenSearch Service を用意しておきます。(ローカルOpenSearchやElasticSearchでもk-nn関連の機能がサポートされていたりプラグインがあれば動くと思います)
- 今回は Sagemaker の notebook 環境で行いましたが、後述する GPU を使用しない方法でも可能だと思います
- 各種必要なパッケージはインストールしてください。
- (頻繁に機能改善が行われる分野でもあるので)間違いや修正した方が良い点があれば気軽にご連絡ください。
OpenSearch で類似画像検索の手順
-
今回は"cats_vs_dogs"データセットから画像データとラベルをダウンロードします
from datasets import load_dataset # データセットをロード dataset = load_dataset("cats_vs_dogs") dataset
-
OpenSearch の index に格納する画像データをデータセットから抽出します
# 犬と猫の画像をそれぞれ10枚ずつ取得 cat_images = [dataset["train"][idx]["image"] for idx in range(10) if dataset["train"][idx]["labels"] == 0] dog_images = [dataset["train"][idx]["image"] for idx in range(20000, 20010) if dataset["train"][idx]["labels"] == 1]
dog_images[0]などで実際に画像データを確認できます。
-
画像から Embedding を抽出します
import torchvision.transforms as transforms import torch import torchvision import timm # load model of EfficientNetb0 model = timm.create_model('efficientnet_b0', pretrained=True) model.classifier = torch.nn.Identity() model.eval() # preprocesser preprocess = transforms.Compose([ transforms.Resize(256), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) def make_embed(image, preprocess=preprocess, model=model): inputs = preprocess(image) inputs = inputs.unsqueeze(0) embed = model(inputs) # 正規化 embed = embed / torch.linalg.norm(embed, dim=1, keepdim=True) embed = embed[0, :] #(1, 1280) -> (1280) embed_list = embed.detach().cpu().numpy().tolist() return embed_list # create embedding and save embedding to list embed_dogs = [] embed_cats = [] for i in range(10): with torch.no_grad(): dog_img = dog_images[i] cat_img = cat_images[i] dog_embed = make_embed(dog_img) cat_embed = make_embed(cat_img) embed_dogs.append(dog_embed) embed_cats.append(cat_embed)
今回は Efficientnetb0 を使用します。
model.classifier = torch.nn.Identity()
とすることで、モデルのヘッドであるクラス分類層を削除して、embedding のみ取得することができます。 -
OpenSearch クライアントに接続する
from opensearchpy import OpenSearch, RequestsHttpConnection from requests_aws4auth import AWS4Auth from requests.auth import HTTPBasicAuth host = <host># e.g. my-test-domain.us-east-1.es.amazonaws.com awsauth=HTTPBasicAuth(<マスターユーザー名>, <マスターユーザーのパスワード>) # Create OpenSearch client client = OpenSearch( hosts = [{'host': host, 'port': 443}], http_auth = awsauth, use_ssl = True, verify_certs = True, connection_class = RequestsHttpConnection ) # check to connect OpenSearch client client.info()
今回は OpenSearch-py を OpenSearch クライアントとして使用しましたが、クライアントは何でも良いです。
-
index 作成
index_settings = { "settings": { "index.knn": True, "index.knn.space_type": "cosinesimil" # using cosine similarity }, "mappings": { "properties": { "<任意の名前(今回は「embeddings_name」とします。)>": { "type": "knn_vector", "dimension": 1280 } } } } # create index client.indices.create( index=<index 名>, body=index_settings )
-
indexing
# save animal images. corr_table = {} # using for corr_table's key ids = 1 for i in range(10): dog_embed = embed_dogs[i] cat_embed = embed_cats[i] # body_data of dog dog_embedding_data = { 'ids': ids, 'label': 0, 'embeddings_name': dog_embed } corr_table[ids] = dog_images[i] ids += 1 # body_data of cat cat_embedding_data = { 'ids': ids, 'label': 1, 'embeddings_name': cat_embed } corr_table[ids] = cat_images[i] ids += 1 # indexing client.index( index=<index 名>, body=dog_embedding_data, refresh=True ) client.index( index=<index 名>, body=cat_embedding_data, refresh=True )
-
適当な画像データを用いてコサイン類似度で検索する
# sampling image! input_image = dataset["train"][100]["image"] # create embedding with torch.no_grad(): output_embed = make_embed(input_image) # search query body query = { "size": 20, "query": { "knn": { "embeddings_name": { "vector": output_embed, "k": 20, } } }, } # searching response = client.search( body=query, index=<index 名>, ) # display the result of search for hit in response['hits']['hits']: score = hit['_score'] doc = hit['_source'] print(f"Score: {score}, ID: {doc['ids']}, Label: {doc['label']}")
-> 以下のように検索されました。Label を確認すると、label==1である猫が上位に検索されており、類似度検索ができました!!!
犬と猫は似ている気がするのでスコアが近いですが、きちんと猫の画像で検索すると猫が上位に来ており、正確に検索できてることが確認できました。
Score: 0.68988866, ID: 8, Label: 1 Score: 0.6791921, ID: 6, Label: 1 Score: 0.6593978, ID: 2, Label: 1 Score: 0.65817195, ID: 16, Label: 1 Score: 0.65811276, ID: 10, Label: 1 Score: 0.6543478, ID: 18, Label: 1 Score: 0.6213816, ID: 20, Label: 1 Score: 0.61332446, ID: 12, Label: 1 Score: 0.60994273, ID: 14, Label: 1 Score: 0.58420175, ID: 4, Label: 1 Score: 0.5604887, ID: 17, Label: 0 Score: 0.5598484, ID: 13, Label: 0 Score: 0.5588222, ID: 19, Label: 0 Score: 0.5570838, ID: 3, Label: 0 Score: 0.54466915, ID: 1, Label: 0 Score: 0.53117865, ID: 15, Label: 0 Score: 0.51422846, ID: 5, Label: 0 Score: 0.50965375, ID: 9, Label: 0 Score: 0.5043833, ID: 7, Label: 0 Score: 0.5031408, ID: 11, Label: 0
-
実際に画像を比較します(input_imageとcorr_table[8]を比較してください)
実際に類似してる猫の画像が検索できました!!!
まとめ
ここまでの手順で OpenSearch で類似画像検索ができました。
かなり簡単に検索するところまでいけました。例えば、テキストと画像のマルチモーダルでの類似してるものを検索することも簡単にできそうですね。