概要
Stable Diffusionと類似画像検索を併用することにより、大量の画像から気に入ったものをあたかもキーワード検索しているように使えるのでは、と思いやってみました。
Stable Diffusion用のサーバ
以前やった、Google Colab上のflaskで動かしたサーバをそのまま使用しました。
検索対象の画像について
Windy.comが提供するWebカメラのAPIサービスからとってきた、約6万5000枚のWebカメラのサムネイル画像を検索対象としました。APIでは、サムネイル以外にも、位置情報や動画のURLもとってこれます。
類似検索サーバについて
事前に、6万枚の画像の特徴量のリストと、ファイル名のリストをファイルに保存しています。
以下のような動作です。
1.Base64のイメージを受け取り、受信画像の特徴量を生成。
2.生成した特徴量と、サムネイルの特徴量それぞれに対して、コサイン類似度を計算。
3.類似度が大きいもの上位3件のファイル名、および類似度の値を返す。※並べかえはしていません。
また、ブラウザで確認させるため指定たファイル名の画像を取得する機能もつくりました。
【app.py】
from flask import Flask,request,send_from_directory
from flask_cors import CORS
import torch
from torchvision.models import ResNet50_Weights
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import pickle
from io import BytesIO
import base64
app = Flask(__name__)
CORS(app)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1).to(device)
model.eval()
cos_sim = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
# Define the image preprocessing steps
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
IMAGE_DIR = "c:/temp/images/" # directory where the images are stored
FILE_LIST = "c:/temp/filelist" # list of image files
FEATURE_LIST = "c:/temp/features.pt" # list of image features
topK = 3 # number of similar images to return
# Load the pre-computed image features and file list
features_list = torch.load(FEATURE_LIST )
print(len(features_list))
with open(FILE_LIST, 'rb') as f:
filelist = pickle.load(f)
print(len(filelist))
@app.route('/getSimilarImageFromBase64', methods=['POST'])
def getSimilarImagesFromBase64():
imageString = request.form['base64Image']
base64String = imageString.split(",")[1] # remove the header 'data:image/jpeg;base64,'
# Convert the base64 string to an image
f = BytesIO()
f.write(base64.b64decode(base64String))
f.seek(0)
image = Image.open(f)
image = preprocess(image)
image = image.unsqueeze(0).to(device)
# Get the image features
with torch.no_grad():
queryFeature = model(image)
score_list = [] # list of cosine similarity scores
result_files = [] # list of image file names
# Compare the query image with all the images in the dataset
for i, feature in enumerate(features_list):
# Compute the cosine similarity score
score_t = cos_sim(queryFeature, feature)
score = score_t.item()
# Add the score to the list of scores
if len(score_list) < topK :
score_list.append(score)
result_files.append(filelist[i])
else:
# Replace the smallest score with the current score
# if current score is larger
if min(score_list) < score:
min_idx = score_list.index(min(score_list))
score_list[min_idx] = score
result_files[min_idx] = filelist[i]
# Create a dictionary of the results
resutltObj = {}
for i, file in enumerate(result_files):
resutltObj[file] = score_list[i]
return resutltObj
# Download the image file
@app.route('/download/<string:filename>', methods=['GET'])
def download(filename):
IMAGE_DIR = "c:/temp/images/"
return send_from_directory(IMAGE_DIR, filename, as_attachment=True,mimetype = "image/jpeg")
if __name__ == "__main__":
app.run(debug=True)
確認用HTML
create Imageボタンで、画像生成。search Imageボタンで、画像検索を行います。
<html>
<head>
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.6.0/jquery.min.js"></script>
</head>
<body>
<input id="prompt" type="text" style="width:500px" value="" /><BR>
<input type="hidden" id="base64Image" value="" />
<button type="button" id="createImage" onclick="sendPrompt()" />create image</button>
<BR>
<img id="createImg" src="" width="200px" height="200px">
<BR>
<button type="button" id="searchImage" onclick="searchImage()" />search image</button>
<BR>
<div style="float: left;">
<label>file1</label><input id="file1" type="text" value="" style="border:0px" readonly/>
<label>similarity</label><input id="similarity1" type="text" value="" style="border:0px" readonly/>
<div>
<img id="result1" border="0" src="" style="display:none" />
</div>
<label>file2</label><input id="file2" type="text" value="" style="border:0px" readonly/>
<label>similarity</label><input id="similarity2" type="text" value="" style="border:0px" readonly/>
<div>
<img id="result2" border="0" src="" style="display:none" />
</div>
<label>file3</label><input id="file3" type="text" value="" style="border:0px" readonly/>
<label>similarity</label><input id="similarity3" type="text" value="" style="border:0px" readonly/>
<div>
<img id="result3" border="0" src="" style="display:none" />
</div>
</div>
<script>
async function sendPrompt(){
let prompt = document.getElementById('prompt').value ;
const obj = {"prompt" : prompt};
const method = "POST";
const body = Object.keys(obj).map((key)=>key+"="+encodeURIComponent(obj[key])).join("&");
const headers = {
'Content-Type': 'application/x-www-form-urlencoded'
};
const url = "<stable diffusionのサーバ>"
fetch(url,
{method, headers, body})
.then( (res) => res.text() )
.then((text) => {
document.getElementById('createImg').src = text;
document.getElementById('base64Image').value = text;
})
.catch(error=>{
console.log(error);
});
}
function searchImage(){
let base64Image = document.getElementById('base64Image').value ;
const obj = {"base64Image" : base64Image};
const method = "POST";
const body = Object.keys(obj).map((key)=>key+"="+encodeURIComponent(obj[key])).join("&");
const headers = {
'Content-Type': 'application/x-www-form-urlencoded'
};
const url = "http://localhost:5000/getSimilarImageFromBase64"
fetch(url,
{method, headers, body})
.then( (res) => res.json() )
.then((data) => {
resultSet(data);
})
.catch(error=>{
console.log(error);
});
}
function setImage(imgid,simid,image,similarity,fileid) {
$(imgid).attr("src", 'http://localhost:5000/download/' + image);
$(imgid).show();
$(fileid).val(image);
$(simid).val(similarity);
}
function resultSet(data) {
let jsonObj = data;
console.log(jsonObj);
let keys = Object.keys(jsonObj);
setImage("#result1","#similarity1", keys[0], jsonObj[keys[0]],"#file1");
setImage("#result2","#similarity2", keys[1], jsonObj[keys[1]],"#file2");
setImage("#result3","#similarity3", keys[2], jsonObj[keys[2]],"#file3");
}
</script>
</body>
</html>
結果
試してみた結果です。1番上が、Stable Diffusionで作成した画像。下の3つが検索した結果です。