はじめに
がちもとさんアドベントカレンダー25日目の記事です。
昨日は、SAM(Segment Anything Model)を用いてセグメンテーションを行いました。
今日は、SAMでマスクした領域に対して、DALLE2を用いてインペイントします。
SAM(Segment Anything Model)やーる(Windows 11、CPU)
OpenAIのDALLE2を用いて透明に塗りつぶした領域を画像生成でインペイントすーる(Python)
開発環境
- Windows 11 PC
- Python 3.11
導入
1.OpenAIのAPIキーを取得
2.プログラムを作成
inpaint.py
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel
import onnxruntime
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic
from openai import OpenAI
import requests
client = OpenAI(api_key="<INSERT-YOUR-API-KEY>")
checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
onnx_model_path = None # Set to use an already exported model, then skip to the next section.
import warnings
onnx_model_path = "sam_onnx_example.onnx"
onnx_model = SamOnnxModel(sam, return_single_mask=True)
dynamic_axes = {
"point_coords": {1: "num_points"},
"point_labels": {1: "num_points"},
}
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
"image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
"point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
"point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
"mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
"has_mask_input": torch.tensor([1], dtype=torch.float),
"orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}
output_names = ["masks", "iou_predictions", "low_res_masks"]
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)
with open(onnx_model_path, "wb") as f:
torch.onnx.export(
onnx_model,
tuple(dummy_inputs.values()),
f,
export_params=True,
verbose=False,
opset_version=17,
do_constant_folding=True,
input_names=list(dummy_inputs.keys()),
output_names=output_names,
dynamic_axes=dynamic_axes,
)
onnx_model_quantized_path = "sam_onnx_quantized_example.onnx"
quantize_dynamic(
model_input=onnx_model_path,
model_output=onnx_model_quantized_path,
# optimize_model=True,
per_channel=False,
reduce_range=False,
weight_type=QuantType.QUInt8,
)
onnx_model_path = onnx_model_quantized_path
ort_session = onnxruntime.InferenceSession(onnx_model_path)
sam.to(device='cpu') # sam.to(device='cuda')
predictor = SamPredictor(sam)
# マウスイベントのコールバック関数
def draw_circle(event, x, y, flags, param):
if event == cv2.EVENT_LBUTTONDOWN:
input_point = np.array([[x, y]])
input_label = np.array([1])
onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)
ort_inputs = {
"image_embeddings": image_embedding,
"point_coords": onnx_coord,
"point_labels": onnx_label,
"mask_input": onnx_mask_input,
"has_mask_input": onnx_has_mask_input,
"orig_im_size": np.array(image.shape[:2], dtype=np.float32)
}
masks, _, low_res_logits = ort_session.run(None, ort_inputs)
masks = masks > predictor.model.mask_threshold
# マスク画像を生成
color = np.array([255/255, 255/255, 255/255, 255/255])
h, w = masks.shape[-2:]
mask_image = masks.reshape(h, w, 1) * color.reshape(1, 1, -1)
# マスク画像を元の画像のサイズにリサイズ
mask_image = cv2.resize(mask_image, (image.shape[1], image.shape[0]))
mask_image = 1 - mask_image
cv2.imshow("mask", mask_image)
cv2.waitKey(1)
# 画像をバイトデータにエンコード
_, buffer = cv2.imencode(".png", cv2.cvtColor(image, cv2.COLOR_RGB2RGBA))
image_bytes = buffer.tobytes()
_, buffer = cv2.imencode(".png", mask_image)
mask_image_bytes = buffer.tobytes()
response = client.images.edit(
model="dall-e-2",
image=image_bytes,
mask=mask_image_bytes,
prompt="Inpaint the object and replace it with the background.", # "Do Inpaint", "Inpaint the missing area in the image.",
n=1,
size="1024x1024"
)
image_url = response.data[0].url
print(image_url)
url_parts = image_url.split('?')
file_name = url_parts[0].split('/')[-1]
response = requests.get(image_url)
if response.status_code == 200:
with open(file_name, 'wb') as file:
file.write(response.content)
image_data = np.frombuffer(response.content, np.uint8)
edited_image = cv2.imdecode(image_data, cv2.IMREAD_COLOR)
# 画像を表示
cv2.imshow('Edited Image', edited_image)
# 画像の読み込み
image = cv2.imread('notebooks/images/dog.jpg') # 'your_image.jpg'を表示したい画像のパスに置き換えてください。
predictor.set_image(image)
image_embedding = predictor.get_image_embedding().cpu().numpy()
# OpenCVのウィンドウにコールバック関数をバインド
cv2.namedWindow('image')
cv2.setMouseCallback('image', draw_circle)
while True:
# 画像を表示
cv2.imshow('image', image)
# 'ESC'キーが押されたら終了
if cv2.waitKey(20) & 0xFF == 27:
break
# すべてのウィンドウを閉じる
cv2.destroyAllWindows()
実行結果
INPUT | OUTPUT |
---|---|
ブルドックをクリックすると、その領域をインペイントしてくれました。
お疲れさまでした。