0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

SAMを用いた画像セグメンテーションの試み

Posted at

今回はFacebook AI Researchが開発した「Segment Anything Model (SAM)」を用いて、画像から自動的にセグメントを生成する方法を共有します。

SAMは、任意の画像内のオブジェクトを検出しその領域をマスクとして抽出する高性能なモデルです。
https://segment-anything.com/

使用したコードの概要

以下では、画像セグメンテーションを行うためのコードを簡単に解説します。

必要なライブラリ
まず、以下のライブラリをインストールして準備します:

qiita.rb
pip install torch matplotlib opencv-python pillow segment-anything

モデルと設定の準備

SAMのモデルを読み込み、自動的にマスクを生成する設定を行います。

qiita.rb
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

sam_checkpoint = "/path/to/sam_vit_h_4b8939.pth"  # モデルのチェックポイント
model_type = "vit_h"  # モデルのタイプ

# モデルをCPUにロード
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device="cpu")

# マスク生成器の設定
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.98,
    stability_score_thresh=0.96,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=500,
)

画像の読み込みとマスク生成

対象の画像を読み込み、マスクを生成します。

qiita.rb
import cv2
import matplotlib.pyplot as plt

image_path = "/path/to/image.jpg"
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

masks = mask_generator.generate(image)
print(f"Generated {len(masks)} masks.")

マスクの数や各マスクの詳細な情報を出力することができます。

qiita.rb
for idx, mask in enumerate(masks):
    print(f"**** Mask Index: {idx} ****")
    print(f"- Area            : {mask['area']} pixels")
    print(f"- Bounding Box    : {mask['bbox']} [x, y, width, height]")
    print(f"- Predicted IoU   : {mask['predicted_iou']:.3f}")
    print(f"- Stability Score : {mask['stability_score']:.3f}")

マスクの可視化

生成されたマスクを可視化して、画像上で確認します。

qiita.rb
def visualize_masks(image, masks, num_masks=5):
    import numpy as np
    sorted_masks = sorted(masks, key=lambda x: x['area'], reverse=True)
    num_masks = min(num_masks, len(sorted_masks))

    plt.figure(figsize=(18, 12))
    for i in range(num_masks):
        plt.subplot(1, num_masks, i + 1)
        plt.imshow(image)
        mask = sorted_masks[i]['segmentation']
        plt.imshow(mask, alpha=0.5, cmap='jet')
        plt.title(f'Mask {i+1}\nArea: {sorted_masks[i]["area"]}')
        plt.axis('off')
    plt.show()

visualize_masks(image, masks, num_masks=5)

実行結果

画像のオリジナルとマスクされた領域を比較し、どのようにセグメントが生成されたか確認できます。

調整可能なパラメータ

1. points_per_side: 探索点の密度を変更します。値を増やすと精度が向上しますが、計算時間も長くなります。
2. pred_iou_thresh: マスク選択の精度を設定します。値を高くすると高品質なマスクのみを選択します。
3. stability_score_thresh: マスクの安定性を調整します。値を低くすると検出領域が広がります。
4. min_mask_region_area: 最小マスクサイズを設定します。小さい値にすると小さなオブジェクトも検出可能です。

SAMを使用することで簡単に画像からオブジェクトを検出しセグメント化することが可能です!みなさんもぜひ異なる画像やパラメータで試してみてください!

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?