日本語CLIPでの画像検索
CLIPを使って画像検索をしている記事を参考に、日本語CLIPで動くようにしてみました。試してみたところ画像を増やすとメモリが足りなくなったり、本家CLIPとの使い方の違いなどがあったのでメモとして残しておきます。
日本語CLIPを使った記事はいくつか見つかったのですが、実際に手元の画像を入れて試してみると検索が行えていないことがありました。今回参考にさせていただいた記事のソースコードを手元の画像で試したところ正しく検索されたので、こちらをベースに日本語CLIPに置き換えてみました。
公開されているペット画像のdatasetなどでは「それっぽい」結果が出ているようでも、画像を入れ替えると全然動かないことがあるので、やはり生データで試すのが大事かなと思いました。
参考にさせていただいた記事
CLIPを使って、大量の画像の中から自分が探したい画像をテキストで検索する
http://cedro3.com/ai/clip-search/
こちらの記事とソースコードを参考にさせていただきました。
動作環境
- Ubuntu20.04
- RTX3080
- torch==1.10.2+cu113
- dockerイメージ: nvidia/cuda:11.3.0-devel-ubuntu20.04
requirements.txtはこちら
numpy
tensorboard
matplotlib
tqdm
ftfy
regex
git+https://github.com/openai/CLIP.git
実際のソースコード
日本語CLIPの初期化と画像ファイル名を取得する部分
画像はあらかじめ縦サイズが256になるように縮小してRGBのpng画像としてsample_imageフォルダに突っ込んであります。今回は100,000枚の画像で試しましたが7分弱で実行できました。画像を特徴量変換するところだけ遅いので、最後の部分だけnotebookで実行すればテキストを変えながら素早く試すことができます。
import torch
import numpy as np
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
import glob
from tqdm import tqdm
import japanese_clip as ja_clip
import matplotlib.pyplot as plt
device = "cuda" if torch.cuda.is_available() else "cpu"
# 日本語CLIPの読み込み
model, preprocess = ja_clip.load("rinna/japanese-clip-vit-b-16", cache_dir="/tmp/japanese_clip", device=device)
tokenizer = ja_clip.load_tokenizer()
# 画像の前処理
preprocess = Compose([
Resize(224, interpolation=Image.BICUBIC),
CenterCrop(224),
ToTensor()
])
image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()
# 画像の読み込み
# 検索したい画像がsample_image/以下に格納されているものとする
files = glob.glob('sample_image/*.png')
print('file num', len(files))
files.sort()
画像を分割して特徴量に変換する部分
元記事のソースコードのように画像を全部一度に読み込むとメモリが足りなくなるのでバッチ単位で読み込んで特徴量変換しています。これなら10GB以下のマシンでも動きます。
# 一度に画像を読み込むとメモリが足りなくなるので分割して処理する
batch_size = 256
image_features = []
with torch.no_grad():
images = []
for i, file in enumerate(tqdm(files)):
# 画像ファイルを読み込む
image = preprocess(Image.open(file).convert("RGB"))
images.append(image)
# バッチの単位で変換
if i != 0 and (i%batch_size) == 0:
image_input = torch.tensor(np.stack(images)).cuda()
image_input -= image_mean[:, None, None]
image_input /= image_std[:, None, None]
if len(image_features) == 0:
image_features = model.encode_image(image_input.float())
else:
image_features = torch.cat([image_features, model.encode_image(image_input.float())], dim=0)
images = []
検索テキストと画像との類似度を求めて結果を表示する部分
ここのtokenize()とget_text_features()のところを変更しています。日本語CLIPだと扱い方が変わるみたいです。ここは以下のサイトのスライドを参考にしました。
参考にさせていただいた記事の結果画像を表示する部分がすごく使いやすくて助かりました。matplotlibで複数画像を表示するのが面倒すぎますね。
# 検索テキスト
text = '黒髪の女性'
print('text = ', text)
# 特徴ベクトルを抽出(テキスト)
with torch.no_grad():
# 日本語CLIPは本家CLIPとはテキストの変換処理が少し違うので注意
encodings = ja_clip.tokenize(text, tokenizer=tokenizer)
text_features = model.get_text_features(**encodings)
text_features /= text_features.norm(dim=-1, keepdim=True)
# COS類似度を計算
text_probs = torch.cosine_similarity(image_features, text_features)
# COS類似度の高い順にインデックスをソート
x = np.argsort(-text_probs.cpu(), axis=0)
# COS類似度が高い順に画像を表示
fig = plt.figure(figsize=(30, 40))
disp_num = 100 # 表示する数
for i in range(disp_num):
index = x[i].item()
filename = files[index]
img = Image.open(filename)
images = np.asarray(img)
ax = fig.add_subplot(10, 10, i+1, xticks=[], yticks=[])
image_plt = np.array(images)
ax.imshow(image_plt)
cos_value = round(text_probs[x[i].item()].item(), disp_num)
ax.set_xlabel(cos_value, fontsize=12)
plt.show()
plt.close()
今回使ったのがインターネットから収集した画像なので結果画像は載せていませんが問題なく検索できています。
CLIPが動く環境があれば手元の画像を突っ込むだけで試せるのでお手軽です。ただ画像の枚数は多めのほうが結果がわかりやすいと思います。1000枚以下だと「それっぽいけど、本当に検索できているのかな?」と思うことがありました。CLIPによる画像検索はある程度間違えるので、多めの画像で試して全体の傾向を見てみたほうがよいです。