2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

画像を読み取るAI、Florence2を高速化する

Last updated at Posted at 2024-09-21

Florence2(AIモデル)
画像から意味やテキストを読みとることができる。(2024)。

例えばこの画像にキャプションをつける

'The image shows a small statue sitting atop a moss covered ground, surrounded by a few leaves. The background is slightly blurred, giving the image a dreamy feel.'

「この画像には、苔むした地面の上に小さな像が数枚の葉に囲まれて座っている様子が写っています。背景は少しぼかされており、画像に夢のような雰囲気を与えています。」

もともと高速だが、さらに高速化できる。

speed size
original 2.4 sec 1.08GB
4bit quantized 1.1 sec 267MB

4bit quantized result
'The image shows a small statue sitting atop a moss covered ground, surrounded by leaves. The background is slightly blurred, giving the image a dreamy feel.'

手順

必要なライブラリのインストール

pip -q install flash_attn einops timm quanto accelerate bitsandbytes transformers

普通の初期化

from transformers import AutoProcessor, AutoModelForCausalLM

model_id = 'microsoft/Florence-2-base'
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

高速化:モデルを量子化する

from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image

import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)
quantization_config_4bit = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)

model_id = 'microsoft/Florence-2-base'
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config_4bit,
    trust_remote_code=True
)

実行

def run_example(model, image, task_prompt, text_input=None, quantized = False, num_beams=3):
    if text_input is None:
        prompt = task_prompt
    else:
        prompt = task_prompt + text_input
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    if quantized:
        inputs = inputs.to(dtype=torch.float16)
    generated_ids = model.generate(
      input_ids=inputs["input_ids"],
      pixel_values=inputs["pixel_values"],
      max_new_tokens=1024,
      early_stopping=False,
      do_sample=False,
      num_beams=num_beams,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height)
    )

    return parsed_answer

import time

task_prompt = '<DETAILED_CAPTION>'
s = time.time()
print(run_example(model, image, task_prompt, quantized=False))
print(time.time() - s)
s = time.time()
print(run_example(model_4bit, image, task_prompt, quantized=True))
print(time.time() - s)

描画

import matplotlib.pyplot as plt
import matplotlib.patches as patches
def plot_bbox(image, data):
   # Create a figure and axes
    fig, ax = plt.subplots()

    # Display the image
    ax.imshow(image)

    # Plot each bounding box
    for bbox, label in zip(data['bboxes'], data['labels']):
        # Unpack the bounding box coordinates
        x1, y1, x2, y2 = bbox
        # Create a Rectangle patch
        rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
        # Add the rectangle to the Axes
        ax.add_patch(rect)
        # Annotate the label
        plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))

    # Remove the axis ticks and labels
    ax.axis('off')

    # Show the plot
    plt.show()

from PIL import Image, ImageDraw, ImageFont
import random
import numpy as np
colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
            'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
def draw_polygons(image, prediction, fill_mask=False):
    """
    Draws segmentation masks with polygons on an image.

    Parameters:
    - image_path: Path to the image file.
    - prediction: Dictionary containing 'polygons' and 'labels' keys.
                  'polygons' is a list of lists, each containing vertices of a polygon.
                  'labels' is a list of labels corresponding to each polygon.
    - fill_mask: Boolean indicating whether to fill the polygons with color.
    """
    # Load the image

    draw = ImageDraw.Draw(image)


    # Set up scale factor if needed (use 1 if not scaling)
    scale = 1

    # Iterate over polygons and labels
    for polygons, label in zip(prediction['polygons'], prediction['labels']):
        color = random.choice(colormap)
        fill_color = random.choice(colormap) if fill_mask else None

        for _polygon in polygons:
            _polygon = np.array(_polygon).reshape(-1, 2)
            if len(_polygon) < 3:
                print('Invalid polygon:', _polygon)
                continue

            _polygon = (_polygon * scale).reshape(-1).tolist()

            # Draw the polygon
            if fill_mask:
                draw.polygon(_polygon, outline=color, fill=fill_color)
            else:
                draw.polygon(_polygon, outline=color)

            # Draw the label text
            draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)

    # Save or display the image
    #image.show()  # Display the image
    display(image)

def draw_ocr_bboxes(image, prediction):
    scale = 1
    draw = ImageDraw.Draw(image)
    bboxes, labels = prediction['quad_boxes'], prediction['labels']
    for box, label in zip(bboxes, labels):
        color = random.choice(colormap)
        new_box = (np.array(box) * scale).tolist()
        draw.polygon(new_box, width=3, outline=color)
        draw.text((new_box[0]+8, new_box[1]+2),
                    "{}".format(label),
                    align="right",

                    fill=color)
    display(image)

OCRの高速化

ダウンロード (1).png

OCRタスクでは実行時のオプションを変えると高速になる。
num_beamsを1にする。

量子化では高速にならない。

generated_ids = model.generate(
      input_ids=inputs["input_ids"],
      pixel_values=inputs["pixel_values"],
      max_new_tokens=1024,
      early_stopping=False,
      do_sample=False,
      num_beams=1, # 3→1
    )
speed
original 28 sec
4bit quantized 20 sec

🐣


フリーランスエンジニアです。
AIについて色々記事を書いていますのでよかったらプロフィールを見てみてください。

もし以下のようなご要望をお持ちでしたらお気軽にご相談ください。
AIサービスを開発したい、ビジネスにAIを組み込んで効率化したい、AIを使ったスマホアプリを開発したい、
ARを使ったアプリケーションを作りたい、スマホアプリを作りたいけどどこに相談したらいいかわからない…

いずれも中間コストを省いたリーズナブルな価格でお請けできます。

お仕事のご相談はこちらまで
rockyshikoku@gmail.com

機械学習やAR技術を使ったアプリケーションを作っています。
機械学習/AR関連の情報を発信しています。

X
Medium
GitHub

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?