LoginSignup
2
2

Flask×QdrantベクトルDBでCLIPベースの画像検索をやってみた

Last updated at Posted at 2023-12-24

目次

1.背景

 VMware Explore 2023に参加する機会があり、Retrieval-Augmented Generation(RAG)という外部情報を用いて信頼性を向上するという生成AIモデルが注目されていることを知り、非構造データであってもベクトル化すればマルチモーダルに扱えるというのが当たり前になっていることに衝撃受けた、、、
(今後、LLM×RAGで機密情報を扱える自社GPTなるものをどの会社も作る時代が来るかも?とすら思ったほど)

 一方で、個人的にはサービス化とか本番環境への適応とかを考えるとRAGで生成された回答を鵜呑みにすることないはずなので、ソース情報も何らか分かる形にする必要があり、、、となると「RAG」というより「RAGの中の検索AI」こそ大切だし、むしろ検索AIだけでも実は十分役立つものが作れて事足りているのではないか?とか思えた。なので、検索AIを使う何かを作ってみたかった。

 加えて、最近カメラを買ったこともあり大量の画像がNASに蓄積され続けている状況なので、今回は画像に対して意味検索を行うシステムを作ってみた。

2.やりたいこと

  • [1] テキストベースの関連画像検索

    • キーワードを与えたら、意味的に関連する画像が出力される

usecase1.png

  • [2] 画像ベースの類似画像検索

    • 画像を与えたら、関連する画像が出力される
      (そんな機能いるか?感はあるが・・・)

usecase2.png

3.全体構成

 簡単なwebアプリケーションにはしてみたかったので、WEB/APL/DBをQdrant×Flask×Nginx使ってdockerコンテナ上で動かす構成に。

構成図.png

3-1.処理の流れ

  • [1] テキストベースの関連画像検索

    • 1.クライアントPC上のブラウザでNginxにアクセスして"キーワード"を入力
    • 2.NginxがFlaskにフォワード
    • 3.CLIP学習モデルを用いて与えられたキーワードを512次元の埋め込みベクトル化
    • 4.得られたベクトルをQdrant上のコレクション上で評価して関連する画像のメタデータを取得
    • 5.NASから画像データを取得してクライアントPCに検索結果を出力
  • [2] 画像ベースの類似画像検索

    • 1.クライアントPC上のブラウザでNginxにアクセスして"画像"をアップロード
    • 2.NginxがFlaskにフォワード
    • 3.CLIP学習モデルを用いて与えられた画像を512次元の埋め込みベクトル化
    • 4.得られたベクトルをQdrant上のコレクション上で評価して関連する画像のメタデータを取得
    • 5.NASから画像データを取得してクライアントPCに検索結果を出力

3-2.CLIP 埋め込みベクトル抽出

 キーワード/画像から埋め込みベクトルを抽出するために学習済みのモデルが必要なので、とりあえずHugging Faceで「画像⇔テキスト」のマルチモーダルに対応したモデルを探してみた。色々出てきたが使いやすそうだったので今回はCLIPを使うことに。使い方は以下公式ページ参照。

3-3.Qdrant ベクトルDB

 ベクトルDBは山ほどあるが、主要なものは下記サイトで比較がまとめられていて参考になった。VMware Exploreでも紹介されていたかつOSSかつDeveloper experienceの評価が高め、といった理由で今回はQdrantを使ってみた。
 >> Picking a vector database: a comparison and guide for 2023<<

 Qdrantについての説明は以下の公式ページ参照。Classmethodの解説記事もあって分かりやすい。

ざっくりだが

  • Point:id,埋め込みベクトル,Payloadで構成される。RDBだとレコードに相当するもの。
  • Payload:埋め込みベクトルに対するメタデータ。json形式で自由度高く付与できる。
  • Collection:Pointの集合体。RDBだとテーブルに相当するもの。

という理解。

Qdrant説明図.png

3-3-1.NAS画像データの埋め込みベクトル抽出

 事前準備としてNAS上の画像データから、ベクトルDBにインサートするための埋め込みベクトルを抽出しておく。

import os

import torch
import numpy as np
import pickle
import open_clip
from PIL import Image
from qdrant_client.http.models import PointStruct

def load_image(file_path):
    try:
        image = Image.open(file_path)

        return image
    except (IOError, SyntaxError) as e:
        print(f"Failed to load image: {e}")
        return None

def process_image(file_path):
    if file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
        loaded_image = load_image(file_path)
        return loaded_image
    else:
        print(f"skip: {file_path} (not an image file)")

model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')

folder_path_list = os.listdir("./")
img_cnt = 1
point_list = list() # embedded point_list for qdrant

### Directory Structure ###
#./
#└ yyyymmdd_XXXX
# └ Qdrant
#  └ developped
#    └ xxx.jpg, xxx.png
#   └ other
#    └ xxx.jpg, xxx.png

for i,folder_id in enumerate(folder_path_list):
    if folder_id != "Readme_qdrant-img-search.txt":
        for img_type in ["developped","other"]:
            relative_path = "./"+folder_id+"/"+"Qdrant/"+img_type+"/"
            img_list = os.listdir(relative_path)
        for img_id in img_list:
            img_path = relative_path+img_id
            image = process_image(img_path)

            with torch.no_grad(), torch.cuda.amp.autocast():
                image_features = model.encode_image(preprocess(image).unsqueeze(0))
                image_features /= image_features.norm(dim=-1, keepdim=True) #normalization
                
                point = PointStruct(id=img_cnt, vector=image_features[0], payload={"path":img_path})
                point_list.append(point)
                #print(img_cnt,len(point_list))
        
            img_cnt += 1

f = open('img_vec_PointStruct_list.txt', 'wb')
pickle.dump(point_list, f)

3-3-2.Qdrant Collection作成、Pointインサート

 dockerでQdrantコンテナlocalhost:6333で待ち受けるように設定して起動。起動したコンテナに対して、pythonスクリプトでCollection「img_to_vec_clip_collection」を作成する。Pointとしてインサートするベクトルの次元やベクトル空間の距離尺度もここで予め指定する仕様になっている。

from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

client = QdrantClient(host="localhost", port=6333)#, https=False)

client.create_collection(
    collection_name= "img_to_vec_clip_collection",
    vectors_config=VectorParams(size=512, distance=Distance.COSINE)
)

Collectionが作成できたらPointをインサートする。

import pickle
from qdrant_client.http.models import PointStruct

IMG_DIR_PATH = '[Set according to your environment]'
f = open("img_vec_PointStruct_list.txt","rb")
loaded_point_list = pickle.load(f)

operation_info = client.upsert(
    collection_name="img_to_vec_clip_collection",
    wait=True,
    points=loaded_point_list,
)
print(operation_info)

3-4.Flask Webアプリケーション

 ブラウザからアクセスしてqdrant/clipのapiが呼び出せるようにwebアプリケーション化する。各種構成ファイルは以下の通り。

アプリケーションディレクトリ構成
/qdrant_project
└app.py
└static
 └styles.css
 └img_dir_link
└templates
 └index.html
 └search-by-text.html
 └search-by-image.html

※img_dir_linkは/mnt/yへのHyperLink
  /mnt/yでは[NAS IPaddress]:[画像格納フォルダパス]をマウント
app.py
from flask import Flask, render_template, request, send_from_directory
import torch
import open_clip
from qdrant_client import QdrantClient

from PIL import Image
from io import BytesIO

app = Flask(__name__)

#画像置き場
IMG_BASE_PATH = "img_dir_link/[Set according to your environment]"

#CLIPモデルロード
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')

#ベクトルDBコンテナへの接続
client = QdrantClient(host="qdrant", port=6333)

#テキストベースの関連画像検索
@app.route('/search-by-text', methods=['GET', 'POST'])
def search_by_text():
    global text_features

    if request.method == 'POST':
        keyword = request.form['keyword']
        text = tokenizer([keyword])
        display_top_N = request.form['display_top_N']

        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = model.encode_text(text)
            text_features /= text_features.norm(dim=-1, keepdim=True)

        image_list = search_images_in_qdrant(text_features, display_top_N)
        template_data = {
            'keyword': keyword,
            'image_list': image_list,
         }
        return render_template('search-by-text.html', **template_data)

    # GETリクエスト時の処理
    return render_template('search-by-text.html', keyword='', image_list=[])

#画像ベースの類似画像検索
@app.route('/search-by-image', methods=['GET', 'POST'])
def search_by_image():
    global image_features

    if request.method == 'POST':
        uploaded_file = request.files['image']
        display_top_N = request.form['display_top_N']

        if uploaded_file.filename != '':
            image = preprocess(Image.open(uploaded_file)).unsqueeze(0)

            with torch.no_grad(), torch.cuda.amp.autocast():
                image_features = model.encode_image(image)
                image_features /= image_features.norm(dim=-1, keepdim=True)

        image_list = search_images_in_qdrant(image_features, display_top_N)
        template_data = {
                'keyword': "test-image",
                'image_list': image_list,
                }
        return render_template('search-by-image.html', **template_data)

    # GETリクエスト時の処理
    return render_template('search-by-image.html')

#QdrantベクトルDB上で埋め込みベクトルを評価
def search_images_in_qdrant(emb_features,display_top_N):
    search_result_list = client.search(collection_name="img_to_vec_clip_collection", query_vector=emb_features[0])
    score_list = list()
    image_paths = list()

    for i in range(len(search_result_list[:int(display_top_N)])):
        score_list.append(search_result_list[i].score)
        image_paths.append(IMG_BASE_PATH + search_result_list[i].payload['path'][2:])
        image_list = list(zip(score_list, image_paths))
    return image_list

#Toppage
@app.route('/', methods=['GET'])
def welcom():
    # welcomページ
    return render_template('index.html')

if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0')
index.html
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Image Search Results</title>
    <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='styles.css') }}">
</head>
<body>
    <h1>Semantic Image Search Web Application</h1>
    <p>NAS上の画像データに対して、画像検索を行います。</p>
    <a href="/search-by-text"><button>テキストベースの関連画像検索</button></a>
    <a href="/search-by-image"><button>画像ベースの類似画像検索</button></a>
</body>
</html>
search-by-text.html
<!-- search.html -->
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Image Search By Text</title>
    <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='styles.css') }}">
</head>
<body>
    <h1>Text-based Image Search</h1>
    <p>NAS上の画像に対して文字による意味検索を行います。<br>
    キーワードを入力してください。</p>
    <div id="search-form">
        <form action="/search-by-text" method="post">
            <label for="keyword">スコア上位何位まで表示するか(e.g.3):</label>
            <input type="text" id="display_top_N" name="display_top_N" required>
            <br>
            <label for="keyword">キーワード(e.g.bowing man):</label>
            <input type="text" id="keyword" name="keyword" required>
            <br>
            <button type="submit">Search</button>
        </form>
    </div>
    <div id="search-results">
        <h2>Entered Keyword: {{ keyword }}</h2>
        {% for score, img_path in image_list %}
            <h4>Score: {{ score }}</h4>
            <h4>Path: {{ img_path }}</h4>
            <img src="{{ url_for('static', filename=img_path)}}" alt="static path image" style="max-height: 500px; width: auto;">
        {% endfor %}
    </div>
    <!-- トップページへ戻る -->
    <div id="return-top">
        <a href="/"><button>トップページへ戻る</button></a>
    </div>
</body>
</html>
search-by-image.html
<!-- search.html -->
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Image Search By Image</title>
    <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='styles.css') }}">
</head>
<body>
    <h1>Image-based Simiar Image Search</h1>
    <div id="image-search">
        <p>NAS上の画像に対して画像による類似画像検索を行います。<br>
           画像を選択してください。</p>
        <form action="/search-by-image" method="post" enctype="multipart/form-data">
            <label for="keyword">スコア上位何位まで表示するか(e.g.3):</label>
            <input type="text" id="display_top_N" name="display_top_N" required>
            <br>
            <!-- 画像をアップロードするための入力フィールド -->
            <input type="file" name="image" accept="image/*" onchange="previewImage(this);">
            <!-- アップロードされた画像のプレビュー表示 -->
            <img id="imagePreview" src="#" alt="Uploaded Image" style="max-height: 500px; width: auto; display: none;">
            <!-- 画像をアップロードするボタン -->
            <input type="submit" value="Upload Image">
        </form>
        <script>
            // 画像のプレビューを表示する
            function previewImage(input){
                var preview = document.getElementById('imagePreview');
                if(input.files && input.files.length > 0){
                    var file = input.files[0];
                    var reader = new FileReader();
                    reader.onload = function(e) {
                        preview.src = e.target.result;
                        preview.style.display = 'block'; // プレビューを表示
                    }
                    reader.readAsDataURL(file);
                } else{
                    preview.style.display = 'none'
                }
            }
        </script>
    </div>
    <div id="search-results">
        {% for score, img_path in image_list %}
            <h4>Score: {{ score }}</h4>
            <h4>Path: {{ img_path }}</h4>
            <img src="{{ url_for('static', filename=img_path)}}" alt="static path image" style="max-height: 500px; width: auto;">
         {% endfor %}
    </div>
    <!-- トップページへ戻る -->
    <div id="return-top">
        <a href="/"><button>トップページへ戻る</button></a>
    </div>
</body>
</html>
styles.css
#search-form {
    border-bottom: 2px solid #000;  /* 下線を追加 */
    padding-bottom: 10px;  /* 下線とフォームの間にスペースを追加 */
}

3-5.Nginx リバースプロキシ

 無くても問題ないが一応リバプロを作っておく。nginx設定ファイルには転送先としてflaskのAPLコンテナを設定。

/etc/nginx/nginx.conf
worker_processes 1;
events {
    worker_connections 1024;
}
http {
    upstream flask-app {
        server flask-app:5000;
    }

    server {
        listen 80;
        server_name localhost;
        client_max_body_size 20M;

        location / {
            proxy_pass http://flask-app;
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;
        }
    }
}

3-6.dockerデプロイ

 Flaskのコンテナイメージをビルドするために、Dockerfile、requirements.txtを作成する。
また、コンテナはQdrant→Flask→Nginxという依存関係を考慮して起動するべきなので、docker composeできるようにdocker-compose.ymlも作成する。

アプリケーションディレクトリ構成
/qdrant_project
└ Dockerfile
└ requirements.txt
└ docker-compose.yml
Dockerfile
#ベースイメージの指定
FROM python:3.8
# 作業ディレクトリの指定
WORKDIR /app
# プロジェクトのファイルをコピー
COPY . .
# Flask アプリケーションの依存関係のインストール
RUN pip install --no-cache-dir -r requirements.txt
# Flask アプリケーションの実行
CMD ["python", "app.py"]
requirements.txt
Flask==2.3.3
torch==2.1.1
qdrant-client==1.7.0
Pillow==10.1.0
open_clip_torch==2.23.0
docker-compose.yml
version: '3'
services:
    qdrant:
      image: qdrant/qdrant:latest
      ports:
        - "6333:6333"
      volumes:
        - [ホストOS上のqdrant DBファイルがあるフォルダパス]

    flask-app:
      build: .
      ports:
        - "5000:5000"
      depends_on:
        - qdrant
      volumes:
        - /mnt/y:/mnt/y:rw  # ホストOS側でマウントしているdirをrwできるようマウント
 
    nginx:
      image: nginx:latest
      ports:
        - "80:80"
      depends_on:
        - flask-app
      volumes:
        - /etc/nginx/nginx.conf:/etc/nginx/nginx.conf:ro  # WSL上のNGINX設定ファイルをコンテナ側でマウント

4.画像検索してみた結果

 サーバ上でdocker compose upして、クライアントPCからNginxにブラウザアクセスして動作を確認。

4-1.テキストベースの関連画像検索

 想像以上に意味検索ができている...!
Qdrant Collectionへインサートされている埋め込みベクトル数(画像数)は約1300程度だがちゃんとそれっぽい画像が拾えている印象。starbucksでドトールの画像がtop2に出てきたのは結構驚き。

  • キーワード:「airplain」→ 合致する飛行機の画像のみ表示された
  • キーワード:「industry」→ 合致する工業地帯の画像のみ表示された
  • キーワード:「starbucks」→ 合致するスターバックスorドトールの画像のみ表示された
    • 同業店と解釈できた?スターバックスとドトールの雰囲気が似ていた?
  • キーワード:「bowing man」→ 合致するお辞儀をした男?の画像が何枚か表示された
  • キーワード:「bird by the garbage」→ 合致するカラスの画像のみ表示された

4-2.画像ベースの類似画像検索

 こちらもそれなりに意味検索ができている...!

  • 画像:「白いアヒル」→ 野生のカモの画像のみ表示された
    • データベース上に「白いアヒル」はないので最も近い画像引けていそう
  • 画像:「生のエビ」→ 焼きエビの写ったBBQの画像のみ表示された
    • データベース上に「生のエビ」はないので最も近い画像引けていそう
  • 画像:「秋の兼六園」→ 秋の庭園の画像のみ表示された
    • データベース上に「兼六園」はないので最も近い画像引けていそう

5.最後に

 モデル依存ではあるが思ったよりちゃんと検索エンジンとして機能してくれて驚いた。
また、ベクトルDBは単にベクトルデータを格納するだけなので低容量で済むということも重要なポイントと思った。今回だと8Kサイズ画像(約15MB)を約1300枚を扱ったので画像データとしては計19GBだが、埋め込みベクトルのサイズとしては73MB程度で約1/250と低容量に...!

 今回はGPU搭載PCで実行していたが、キーワード1つなり画像1枚なりをベクトル化する程度ならCPUオンリーでも十分実行できそうなので今度はCPUのみでも動くか?見てみたい。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2