0
0

Google ColabでSAM(Segment Anything Model)モデルを使ってみる

Posted at

事前準備

ColabをPython3とGPUを使うように設定します。
編集-ノートブックの設定-ランタイムのタイプ-Python3
編集-ノートブックの設定-ハードウェア アクセラレータ-GPU
というふうに設定します。
データ管理をやりやすくするために、ルートディレクトリをHOME定数にします。

import os
HOME = os.getcwd()

SAMと依頼をインストールする

!pip install -q 'git+https://github.com/facebookresearch/segment-anything.git'
!pip install -q jupyter_bbox_widget dataclasses-json supervision

重みをダウンロードする

!mkdir -p {HOME}/weights

!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P {HOME}/weights
import os
CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
print(CHECKPOINT_PATH, "; exist:", os.path.isfile(CHECKPOINT_PATH))

データを用意する

!mkdir -p {HOME}/data

Webから画像ダウンロード

!wget -q [webからの画像リンクhttps://xxx/xxx.jpg]

ローカルから画像アップロード

from google.colab import drive 
drive.mount('/content/drive') 
%cd /content/drive/MyDrive

ローカルの画像をGoogle Driveにアップロードして、cpコマンドで/content/dataにコピーして、次のステップに入ります。

SAMをロードする

import torch
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
MODEL_TYPE = "vit_h"

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)

自動的マスク生成

マスク自動生成するには、SAMモデルに SamAutomaticMaskGenerator クラスを提供する。
SAMチェックポイントを以下のパスにする。
CUDAでデフォルトモデルを実行するのが推奨される。

mask_generator = SamAutomaticMaskGenerator(sam)

import os
IMAGE_NAME = "test3.png"
IMAGE_PATH = os.path.join(HOME, "data", IMAGE_NAME)

SAMでマスクを生成する

import cv2
import supervision as sv
image_bgr = cv2.imread(IMAGE_PATH)
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
sam_result = mask_generator.generate(image_rgb)

出力結果を可視化する

mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
  
detections = sv.Detections.from_sam(sam_result=sam_result)
  
annotated_image = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections)

sv.plot_images_grid(
    images=[image_bgr, annotated_image],
    grid_size=(1, 2),
    titles=['source image', 'segmented image']
)

セグメント結果出力

masks = [
    mask['segmentation']
    for mask
    in sorted(sam_result, key=lambda x: x['area'], reverse=True)
]

sv.plot_images_grid(
    images=masks,
    grid_size=(8, 40),
    size=(16, 16)
)
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0