2
0

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?

がちもとさんAdvent Calendar 2023

Day 25

SAMでマスクした領域をDALLE2でインペイントすーる(Python)

Last updated at Posted at 2023-12-24

はじめに

がちもとさんアドベントカレンダー25日目の記事です。
昨日は、SAM(Segment Anything Model)を用いてセグメンテーションを行いました。
今日は、SAMでマスクした領域に対して、DALLE2を用いてインペイントします。

SAM(Segment Anything Model)やーる(Windows 11、CPU)

OpenAIのDALLE2を用いて透明に塗りつぶした領域を画像生成でインペイントすーる(Python)

開発環境

  • Windows 11 PC
  • Python 3.11

導入

1.OpenAIのAPIキーを取得

2.プログラムを作成

inpaint.py
import os
import warnings

import cv2
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime
import requests
import torch
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic
from openai import OpenAI
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel
# Prefer the OPENAI_API_KEY environment variable so the key never has to be
# committed with the script; the placeholder keeps the old edit-in-place flow.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "<INSERT-YOUR-API-KEY>"))

checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

sam = sam_model_registry[model_type](checkpoint=checkpoint)


def _export_sam_onnx(model, path):
    """Export SAM's prompt-encoder/mask-decoder path to an ONNX file.

    Uses ``SamOnnxModel(model, return_single_mask=True)`` traced with dummy
    inputs; the number of prompt points is left dynamic.

    Args:
        model: the loaded SAM model.
        path: destination ``.onnx`` file (overwritten).
    """
    onnx_model = SamOnnxModel(model, return_single_mask=True)

    # Allow any number of prompt points at inference time.
    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = model.prompt_encoder.embed_dim
    embed_size = model.prompt_encoder.image_embedding_size
    mask_input_size = [4 * x for x in embed_size]
    # Values are placeholders for tracing only; shapes are what matter.
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }
    output_names = ["masks", "iou_predictions", "low_res_masks"]

    # Tracing emits many harmless TracerWarning/UserWarning messages.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(path, "wb") as f:
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=17,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )


def _quantize_onnx(src_path, dst_path):
    """Dynamically quantize the ONNX model at *src_path* to QUInt8 weights at *dst_path*."""
    quantize_dynamic(
        model_input=src_path,
        model_output=dst_path,
        per_channel=False,
        reduce_range=False,
        weight_type=QuantType.QUInt8,
    )


onnx_model_path = "sam_onnx_example.onnx"
onnx_model_quantized_path = "sam_onnx_quantized_example.onnx"

# Exporting and quantizing vit_h is slow; reuse the model from a previous run
# when it exists. Delete the file to force a fresh export (e.g. after changing
# the checkpoint).
if not os.path.exists(onnx_model_quantized_path):
    _export_sam_onnx(sam, onnx_model_path)
    _quantize_onnx(onnx_model_path, onnx_model_quantized_path)
onnx_model_path = onnx_model_quantized_path

# Explicit provider: GPU builds of onnxruntime require the providers argument.
ort_session = onnxruntime.InferenceSession(
    onnx_model_path, providers=["CPUExecutionProvider"]
)

sam.to(device="cpu")  # use device="cuda" to run the image encoder on a GPU
predictor = SamPredictor(sam)

def _build_dalle_mask(mask, height, width):
    """Convert a boolean SAM mask into an RGBA edit mask for DALL-E 2.

    The OpenAI images.edit endpoint repaints the fully transparent
    (alpha == 0) pixels of the mask. So the segmented object gets all four
    channels set to 0, and everything else is opaque white (255).

    Args:
        mask: 2-D boolean array, True where the object was segmented.
        height: number of rows the mask must have (image height).
        width: number of columns the mask must have (image width).

    Returns:
        ``uint8`` array of shape ``(height, width, 4)``.
    """
    mask_u8 = mask.astype(np.uint8)
    if mask_u8.shape != (height, width):
        # Nearest-neighbour keeps the mask strictly binary.
        mask_u8 = cv2.resize(mask_u8, (width, height), interpolation=cv2.INTER_NEAREST)
    rgba = np.full((height, width, 4), 255, dtype=np.uint8)
    rgba[mask_u8.astype(bool)] = 0
    return rgba


# Mouse-event callback function.
def draw_circle(event, x, y, flags, param):
    """OpenCV mouse callback that inpaints the clicked object with DALL-E 2.

    On a left-button click at (x, y) it:
      1. runs the SAM ONNX decoder with that point as a single foreground
         prompt to get a mask of the clicked object,
      2. turns the mask into an RGBA edit mask (object fully transparent),
      3. sends the image and mask to ``client.images.edit`` (dall-e-2),
      4. downloads, saves and shows the returned image.

    It reads the module globals ``image`` (BGR array from ``cv2.imread``),
    ``image_embedding``, ``predictor``, ``ort_session`` and ``client``.
    ``flags`` and ``param`` belong to the OpenCV callback signature and are
    unused.
    """
    if event != cv2.EVENT_LBUTTONDOWN:
        return

    # One positive click, plus the (0, 0) padding point with label -1 that
    # SAM's ONNX example appends when no box prompt is given.
    input_point = np.array([[x, y]])
    input_label = np.array([1])
    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
    # Map pixel coordinates into the resized frame the embedding was computed in.
    onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)

    ort_inputs = {
        "image_embeddings": image_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        # No previous low-resolution mask is fed back in.
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
        "has_mask_input": np.zeros(1, dtype=np.float32),
        "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
    }
    masks, _, _ = ort_session.run(None, ort_inputs)
    object_mask = masks.reshape(masks.shape[-2:]) > predictor.model.mask_threshold

    # Build an 8-bit RGBA mask. A float mask would be truncated by the PNG
    # encoder to 0/1 instead of 0/255.
    height, width = image.shape[:2]
    mask_rgba = _build_dalle_mask(object_mask, height, width)
    cv2.imshow("mask", mask_rgba)
    cv2.waitKey(1)

    # Encode the images as PNG bytes. imread gives BGR; adding alpha yields
    # BGRA, which imencode writes as a correctly ordered RGBA PNG.
    _, buffer = cv2.imencode(".png", cv2.cvtColor(image, cv2.COLOR_BGR2BGRA))
    image_bytes = buffer.tobytes()
    _, buffer = cv2.imencode(".png", mask_rgba)
    mask_image_bytes = buffer.tobytes()

    edit = client.images.edit(
        model="dall-e-2",
        image=image_bytes,
        mask=mask_image_bytes,
        # Other prompts tried: "Do Inpaint", "Inpaint the missing area in the image."
        prompt="Inpaint the object and replace it with the background.",
        n=1,
        size="1024x1024",
    )
    image_url = edit.data[0].url
    print(image_url)

    # Save under the last path segment of the URL (query string dropped).
    file_name = image_url.split("?")[0].split("/")[-1]

    download = requests.get(image_url, timeout=60)
    if download.status_code != 200:
        print(f"Failed to download edited image: HTTP {download.status_code}")
        return
    with open(file_name, "wb") as file:
        file.write(download.content)

    edited_image = cv2.imdecode(np.frombuffer(download.content, np.uint8), cv2.IMREAD_COLOR)
    if edited_image is None:
        print("Failed to decode edited image")
        return

    # Show the result.
    cv2.imshow("Edited Image", edited_image)

# Load the image to edit; replace this path with your own image.
image_path = 'notebooks/images/dog.jpg'
image = cv2.imread(image_path)
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
# cv2.imread returns BGR; tell SAM so it reorders channels to the RGB order
# its encoder expects (the default image_format is "RGB").
predictor.set_image(image, image_format="BGR")
image_embedding = predictor.get_image_embedding().cpu().numpy()

# Bind the click handler to the OpenCV window.
cv2.namedWindow('image')
cv2.setMouseCallback('image', draw_circle)

while True:
    # Redraw the image; clicks are handled by draw_circle.
    cv2.imshow('image', image)

    # Exit when ESC is pressed.
    if cv2.waitKey(20) & 0xFF == 27:
        break

# Close all OpenCV windows.
cv2.destroyAllWindows()

実行結果

INPUT OUTPUT
20231224-184759-12d4057b.png 20231224-184818-5328b5be.png

ブルドッグをクリックすると、その領域をインペイントしてくれました。
お疲れさまでした。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?