2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

LlamaIndexとGeminiを用いた半構造化画像の検索

Posted at

こちらをウォークスルーします。結構苦戦しました。

このノートブックでは画像に対する半構造化情報検索の実行方法を説明します。

これらの構造化アウトプットをベクトルデータベースでインデックス化することができます。そして、auto-retrievalを用いたセマンティック検索 + メタデータフィルターの完全な機能を活用します。これによって、このデータに対して構造化された質問とセマンティックな質問の両方を尋ねることができます!

ライブラリのインストール

%pip install llama-index-multi-modal-llms-gemini
%pip install llama-index-vector-stores-qdrant
%pip install llama-index-embeddings-gemini
%pip install llama-index-llms-gemini
%pip install pydantic==1.10.11

%pip install pydantic==1.10.11を実行しないと、こちらのエラーに遭遇します。

%pip install llama-index 'google-generativeai>=0.3.0' matplotlib qdrant_client
dbutils.library.restartPython()

セットアップ

Google APIキーの取得をします。また、当該プロジェクトでGenerative Language APIを有効化します。
Screenshot 2024-02-27 at 12.55.27.png

import os

GOOGLE_API_KEY = "<Google APIキー>"  # GOOGLE APIキーをこちらに設定
os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY

画像のダウンロード

ファイルをVolumeにアップロード

以下からダウンロードしたarchive.zipをVolumeにアップロードします。

解凍します。パスは適宜変更してください。

%sh
unzip /Volumes/takaakiyayoi_catalog/llama_index/data/archive.zip -d /Volumes/takaakiyayoi_catalog/llama_index/data/

画像ファイルの取得

from pathlib import Path
import random
from typing import Optional
def get_image_files(
    dir_path, sample: Optional[int] = 10, shuffle: bool = False
):
    dir_path = Path(dir_path)
    image_paths = []
    for image_path in dir_path.glob("*.jpg"):
        image_paths.append(image_path)

    random.shuffle(image_paths)
    if sample:
        return image_paths[:sample]
    else:
        return image_paths

Google API呼び出しでエラーが結構起きるので、意図的にサンプル数は減らしています。

image_files = get_image_files("/Volumes/takaakiyayoi_catalog/llama_index/data/SROIE2019/test/img", sample=5) # 100

構造化アウトプットを抽出するためにGeminiを使用

ReceiptInfo pydantic クラスの定義

from pydantic import BaseModel, Field


class ReceiptInfo(BaseModel):
    company: str = Field(..., description="Company name")
    date: str = Field(..., description="Date field in DD/MM/YYYY format")
    address: str = Field(..., description="Address")
    total: float = Field(..., description="total amount")
    currency: str = Field(
        ..., description="Currency of the country (in abbreviations)"
    )
    summary: str = Field(
        ...,
        description="Extracted text summary of the receipt, including items purchased, the type of store, the location, and any other notable salient features (what does the purchase seem to be for?).",
    )

pydantic_gemini 関数の定義

from llama_index.multi_modal_llms.gemini import GeminiMultiModal
from llama_index.core.program import MultiModalLLMCompletionProgram
from llama_index.core.output_parsers import PydanticOutputParser

prompt_template_str = """\
    Can you summarize the image and return a response \
    with the following JSON format: \
"""


async def pydantic_gemini(output_class, image_documents, prompt_template_str):
    gemini_llm = GeminiMultiModal(
        api_key=GOOGLE_API_KEY, model_name="models/gemini-pro-vision"
    )

    llm_program = MultiModalLLMCompletionProgram.from_defaults(
        output_parser=PydanticOutputParser(output_class),
        image_documents=image_documents,
        prompt_template_str=prompt_template_str,
        multi_modal_llm=gemini_llm,
        verbose=True,
    )

    response = await llm_program.acall()
    return response

画像に対して処理を実行

from llama_index.core import SimpleDirectoryReader
from llama_index.core.async_utils import run_jobs


async def aprocess_image_file(image_file):
    # 1つのファイルをロードすべき
    print(f"Image file: {image_file}")
    img_docs = SimpleDirectoryReader(input_files=[image_file]).load_data()
    output = await pydantic_gemini(ReceiptInfo, img_docs, prompt_template_str)
    return output


async def aprocess_image_files(image_files):
    """画像ファイルのメタデータを処理"""

    new_docs = []
    tasks = []
    for image_file in image_files:
        task = aprocess_image_file(image_file)
        tasks.append(task)

    outputs = await run_jobs(tasks, show_progress=True, workers=5)
    return outputs
outputs = await aprocess_image_files(image_files)
outputs[4]
ReceiptInfo(company='UNIHAAKA INTERNATIONAL SDN BHD', date='07 Mar 2018 18:22', address='12, Jalan Tampoi 7/4 Kawasan Perindustrian\nTampoi, 81200 Johor Bahru, Johor', total=8.2, currency='MYR', summary='The receipt is from a restaurant called BAR WANG RICE @ ERMAS JAYA. The total amount is 8.20 MYR. The items purchased include 1 meat and 3 vege.')

構造化表現をTextNodeオブジェクトに変換

from llama_index.core.schema import TextNode
from typing import List


def get_nodes_from_objs(
    objs: List[ReceiptInfo], image_files: List[str]
) -> TextNode:
    """オブジェクトからノードを取得"""
    nodes = []
    for image_file, obj in zip(image_files, objs):
        node = TextNode(
            text=obj.summary,
            metadata={
                "company": obj.company,
                "date": obj.date,
                "address": obj.address,
                "total": obj.total,
                "currency": obj.currency,
                "image_file": str(image_file),
            },
            excluded_embed_metadata_keys=["image_file"],
            excluded_llm_metadata_keys=["image_file"],
        )
        #print(node)

        nodes.append(node)

    return nodes
nodes = get_nodes_from_objs(outputs, image_files)
print(nodes[0].get_content(metadata_mode="all"))

レシートの情報が取得できています。

company: BHPetrol
date: 03/04/2018
address: LOT PTD 101051
Jalan Permas 10/10
81750 Masai, Johor
total: 50.0
currency: RM
image_file: /Volumes/takaakiyayoi_catalog/llama_index/data/SROIE2019/test/img/X51006414512.jpg

Purchase of petrol.

ベクトルストアでこれらのノードのインデックスを作成

import qdrant_client
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.llms.gemini import Gemini
from llama_index.core import Settings

# ローカルのQdrantベクトルストアを作成
client = qdrant_client.QdrantClient(path="qdrant_gemini")

vector_store = QdrantVectorStore(client=client, collection_name="collection")

エラーになったので、少し変更しています。

# グローバル設定
Settings.embed_model = GeminiEmbedding(
    model_name="models/embedding-001", api_key=GOOGLE_API_KEY
)
Settings.llm = Gemini(api_key=GOOGLE_API_KEY)

storage_context = StorageContext.from_defaults(vector_store=vector_store)

index = VectorStoreIndex(
    nodes=nodes,
    storage_context=storage_context,
)

Auto-Retrieverの定義

こちらはMetadataInfoやVectorStoreInfoのモジュールが見つからないエラーが出たので別途定義しています。

class MetadataInfo(BaseModel):
    """Information about a metadata filter supported by a vector store.

    Currently only used by VectorIndexAutoRetriever.
    """

    name: str
    type: str
    description: str

class VectorStoreInfo(BaseModel):
    """Information about a vector store (content and supported metadata filters).

    Currently only used by VectorIndexAutoRetriever.
    """

    metadata_info: List[MetadataInfo]
    content_info: str
vector_store_info = VectorStoreInfo(
    content_info="Receipts",
    metadata_info=[
        MetadataInfo(
            name="company",
            description="The name of the store",
            type="string",
        ),
        MetadataInfo(
            name="address",
            description="The address of the store",
            type="string",
        ),
        MetadataInfo(
            name="date",
            description="The date of the purchase (in DD/MM/YYYY format)",
            type="string",
        ),
        MetadataInfo(
            name="total",
            description="The final amount",
            type="float",
        ),
        MetadataInfo(
            name="currency",
            description="The currency of the country the purchase was made (abbreviation)",
            type="string",
        ),
    ],
)
from llama_index.core.retrievers import VectorIndexAutoRetriever

retriever = VectorIndexAutoRetriever(
    index,
    vector_store_info=vector_store_info,
    similarity_top_k=2,
    empty_query_top_k=10,  # メタデータのフィルターが指定された場合、こちらが制限になります
    verbose=True,
)
# PILからImageをインポート
import requests
from io import BytesIO
import matplotlib.pyplot as plt
from IPython.display import Image


def display_response(nodes: List[TextNode]):
    """Display response."""
    for node in nodes:
        print(node.get_content(metadata_mode="all"))
        # img = Image.open(open(node.metadata["image_file"], 'rb'))
        display(Image(filename=node.metadata["image_file"], width=200))

いくつかクエリーを実行

合計が25未満の麺の注文を教えてください

nodes = retriever.retrieve(
    "Tell me about some restaurant orders of noodles with total < 25"
)
display_response(nodes)

おおー、結果が返ってきます。
Screenshot 2024-02-27 at 13.11.08.png

雑貨の購入をいくつか教えてください

nodes = retriever.retrieve("Tell me about some grocery purchases")
display_response(nodes)

Screenshot 2024-02-27 at 13.16.19.png

レシート画像に対して意味に関する検索と定量的なフィルターの両方が適用できています。面白い。

はじめてのDatabricks

はじめてのDatabricks

Databricks無料トライアル

Databricks無料トライアル

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?