PaliGemma 2 Mix モデルで画像解析を実践する
こんにちは、しゅんです!
今回は、Google の最新ビジョン・ランゲージモデル PaliGemma 2 mix を使い、画像から質問応答、物体検出、セグメンテーションタスクを実行する方法を紹介します。
PaliGemma 2 mix は、画像キャプショニング、OCR、画像 Q&A、物体検出、セグメンテーションなど、複数のタスクに対応しており、3B~28B パラメータのモデルから選択できます。
環境構築手順
まずは作業ディレクトリを作成し、環境構築および Big Vision リポジトリのクローンを行います。以下の手順に従ってください。
# 作業ディレクトリの作成とスクリプトファイルの用意
mkdir PaliGemma
cd PaliGemma
touch main_3b_mix_224.py
# 仮想環境の作成と有効化
python3 -m venv .venv
source .venv/bin/activate
# Big Vision リポジトリのクローン(事前セットアップ済みの場合は不要)
git clone --quiet --branch=main --depth=1 https://github.com/google-research/big_vision big_vision_repo
# 必要なライブラリのインストール
pip install flax tensorflow pillow matplotlib
pip install "overrides" "ml_collections" "einops~=0.7" "sentencepiece"
pip install transformers torch
※ Big Vision リポジトリは、後述のセグメンテーションタスク用の再構築関数を利用するために必要です。
また、big_vision_repo
を Python の import パスに追加するコードを後ほど記載しています。
コード例
以下は、PaliGemma 2 mix を用いて "answer"、"detect"、"segment" タスクおよびバッチプロンプトを実行する完全なコード例です。
import io
import requests
import numpy as np
import PIL
from PIL import Image
import re
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from transformers.image_utils import load_image
import torch
import sys
# --- 公式ユーティリティ関数 ---
# Big Vision リポジトリのコードを Python の import パスに追加(事前にクローン済みの前提)
if "big_vision_repo" not in sys.path:
sys.path.append("big_vision_repo")
def crop_and_resize(image, target_size):
width, height = image.size
source_size = min(image.size)
left = width // 2 - source_size // 2
top = height // 2 - source_size // 2
right, bottom = left + source_size, top + source_size
return image.resize(target_size, box=(left, top, right, bottom))
def read_image(url, target_size):
contents = io.BytesIO(requests.get(url).content)
image = Image.open(contents)
image = crop_and_resize(image, target_size)
image = np.array(image)
if image.shape[2] == 4:
image = image[:, :, :3]
return image
def parse_bbox_and_labels(detokenized_output: str):
matches = re.finditer(
r'<loc(?P<y0>\d{4})><loc(?P<x0>\d{4})><loc(?P<y1>\d{4})><loc(?P<x1>\d{4})> (?P<label>.+?)( ;|$)',
detokenized_output,
)
labels, boxes = [], []
fmt = lambda x: float(x) / 1024.0
for m in matches:
d = m.groupdict()
boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])
labels.append(d['label'])
return np.array(boxes), np.array(labels)
def display_boxes(image, boxes, labels, target_image_size):
h, w = target_image_size
fig, ax = plt.subplots()
ax.imshow(image)
for i in range(boxes.shape[0]):
y, x, y2, x2 = boxes[i] * np.array([h, w, h, w])
width = x2 - x
height = y2 - y
rect = patches.Rectangle((x, y), width, height,
linewidth=1, edgecolor='r', facecolor='none')
ax.add_patch(rect)
ax.text(x, y, labels[i], color='red', fontsize=12, backgroundcolor='white')
plt.axis("off")
plt.show()
def display_segment_output(image, bounding_box, segment_mask, target_image_size):
full_mask = np.zeros(target_image_size, dtype=np.uint8)
target_width, target_height = target_image_size
for bbox, mask in zip(bounding_box, segment_mask):
y1, x1, y2, x2 = bbox
x1 = int(x1 * target_width)
y1 = int(y1 * target_height)
x2 = int(x2 * target_width)
y2 = int(y2 * target_height)
if not isinstance(mask, np.ndarray):
mask = np.array(mask.tolist())
if mask.ndim == 3:
mask = mask.squeeze(axis=-1)
mask = Image.fromarray(mask)
mask = mask.resize((x2 - x1, y2 - y1), resample=Image.NEAREST)
mask = np.array(mask)
binary_mask = (mask > 0.5).astype(np.uint8)
full_mask[y1:y2, x1:x2] = np.maximum(full_mask[y1:y2, x1:x2], binary_mask)
cmap = plt.get_cmap('jet')
colored_mask = cmap(full_mask / 1.0)
colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)
if isinstance(image, Image.Image):
image = np.array(image)
blended_image = image.copy()
mask_indices = full_mask > 0
alpha = 0.5
for c in range(3):
blended_image[:, :, c] = np.where(mask_indices,
(1 - alpha) * image[:, :, c] + alpha * colored_mask[:, :, c],
image[:, :, c])
fig, ax = plt.subplots()
ax.imshow(blended_image)
plt.axis("off")
plt.show()
# セグメンテーション用パース関数(公式コードに準ずる)
import big_vision.evaluators.proj.paligemma.transfers.segmentation as segeval
reconstruct_masks = segeval.get_reconstruct_masks('oi')
def parse_segments(detokenized_output: str) -> tuple[np.ndarray, np.ndarray]:
pattern = (
r'<loc(?P<y0>\d{4})><loc(?P<x0>\d{4})><loc(?P<y1>\d{4})><loc(?P<x1>\d{4})>' +
''.join(f'<seg(?P<s{i}>\d{{3}})>' for i in range(16))
)
matches = re.finditer(pattern, detokenized_output)
boxes, segs = [], []
fmt_box = lambda x: float(x) / 1024.0
for m in matches:
d = m.groupdict()
boxes.append([fmt_box(d['y0']), fmt_box(d['x0']),
fmt_box(d['y1']), fmt_box(d['x1'])])
segs.append([int(d[f's{i}']) for i in range(16)])
boxes = np.array(boxes)
segs = np.array(segs)
seg_masks = reconstruct_masks(segs)
return boxes, seg_masks
# --- Transformers を用いた PaliGemma の利用例 ---
# モデルとプロセッサの初期化(使用するモデル: 3B, 224×224)
model_id = "google/paligemma2-3b-mix-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image_transformers = load_image(url)
target_size = (224, 224)
image_np = read_image(url, target_size)
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="auto"
).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)
# --- 1. "answer" タスク ---
prompt = "answer en where is the car standing?\n"
model_inputs = processor(text=prompt, images=image_transformers, return_tensors="pt")
model_inputs = model_inputs.to(torch.bfloat16).to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded_answer = processor.decode(generation, skip_special_tokens=True)
print("Answer Output:", decoded_answer)
# --- 2. "detect" タスク ---
prompt = "detect car\n"
print("Detect Prompt:", prompt)
model_inputs = processor(text=prompt, images=image_transformers, return_tensors="pt")
model_inputs = model_inputs.to(torch.bfloat16).to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded_detect = processor.decode(generation, skip_special_tokens=True)
print("Detect Output:", decoded_detect)
boxes, labels = parse_bbox_and_labels(decoded_detect)
print("Parsed Boxes:", boxes)
print("Parsed Labels:", labels)
display_boxes(image_np, boxes, labels, target_image_size=target_size)
# --- 2-1. "segment" タスク ---
prompt = "segment car\n"
print("Segment Prompt:", prompt)
model_inputs = processor(text=prompt, images=image_transformers, return_tensors="pt")
model_inputs = model_inputs.to(torch.bfloat16).to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded_segment = processor.decode(generation, skip_special_tokens=True)
print("Segment Output:", decoded_segment)
boxes_seg, seg_masks = parse_segments(decoded_segment)
print("Parsed Boxes (Segment):", boxes_seg)
display_segment_output(image_np, boxes_seg, seg_masks, target_image_size=target_size)
# --- 3. バッチプロンプト ---
prompts = [
'answer en where is the car standing?\n',
'answer en what color is the car?\n',
'describe ja\n',
'detect car\n',
]
images = [image_transformers] * len(prompts)
batch_inputs = processor(
text=prompts,
images=images,
return_tensors="pt",
padding=True,
truncation=True
)
batch_inputs = batch_inputs.to(torch.bfloat16).to(model.device)
batch_outputs = model.generate(
**batch_inputs,
max_new_tokens=100,
do_sample=False
)
for i, output in enumerate(batch_outputs):
inp_len = processor(text=prompts[i], images=image_transformers, return_tensors="pt", padding=True, truncation=True)["input_ids"].shape[-1]
decoded = processor.decode(output[inp_len:], skip_special_tokens=True)
print(f"Batch Output {i+1}: {decoded}")
結果
answer タスクの結果
prompt = "answer en where is the car standing?\n"
Answer Output: street
detect タスクの結果
Detect Prompt: detect car
segment タスクの結果
説明の結果(日本に指定可能)
「黄色い 建物 の 前 に 駐車 し た 青い 車 。」
こういう感じで動きます
PaliGemma 2 Mix モデルで画像解析を実践してみた
— SYUN@気は長く、勤めは堅く、色うすく、食細うして、心広かれ (@syun88AI) February 22, 2025
「黄色い 建物 の 前 に 駐車 し た 青い 車 。」とbbox,segmentです! https://t.co/aDH7swmX1G pic.twitter.com/bXJ3Vh9VUK
まとめ
この記事では、PaliGemma 2 Mix を用いて画像から質問応答、物体検出、セグメンテーションタスクを実行する方法を紹介しました。
公式ユーティリティ関数および Big Vision リポジトリの再構築関数を活用することで、出力結果のパースや描画が容易になり、複数タスクをシームレスに扱うことができます。
ぜひ、このコード例を参考にして、あなた自身の画像解析プロジェクトに PaliGemma 2 Mix を組み込んでみてください!
参考リンク
- Introducing PaliGemma 2 mix – Google Developers Blog
- Inference with Keras – PaliGemma Documentation
- Hugging Face Collection: PaliGemma 2 mix
- Google AI Devs on X
以上、PaliGemma 2 Mix を用いた画像解析の実践例でした。
次回の記事もお楽しみに!
最後まで見てくれてありがとうございます。