事前準備
ColabをPython3とGPUを使うように設定します。
編集-ノートブックの設定-ランタイムのタイプ-Python3
編集-ノートブックの設定-ハードウェア アクセラレータ-GPU
というふうに設定します。
データ管理をやりやすくするために、ルートディレクトリをHOME定数にします。
import os
HOME = os.getcwd()
SAMと依頼をインストールする
!pip install -q 'git+https://github.com/facebookresearch/segment-anything.git'
!pip install -q jupyter_bbox_widget dataclasses-json supervision
重みをダウンロードする
!mkdir -p {HOME}/weights
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P {HOME}/weights
import os
CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
print(CHECKPOINT_PATH, "; exist:", os.path.isfile(CHECKPOINT_PATH))
データを用意する
!mkdir -p {HOME}/data
Webから画像ダウンロード
!wget -q [webからの画像リンクhttps://xxx/xxx.jpg]
ローカルから画像アップロード
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive
ローカルの画像をGoogle Driveにアップロードして、cp
コマンドで/content/data
にコピーして、次のステップに入ります。
SAMをロードする
import torch
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
MODEL_TYPE = "vit_h"
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
自動的マスク生成
マスク自動生成するには、SAMモデルに SamAutomaticMaskGenerator
クラスを提供する。
SAMチェックポイントを以下のパスにする。
CUDAでデフォルトモデルを実行するのが推奨される。
mask_generator = SamAutomaticMaskGenerator(sam)
import os
IMAGE_NAME = "test3.png"
IMAGE_PATH = os.path.join(HOME, "data", IMAGE_NAME)
SAMでマスクを生成する
import cv2
import supervision as sv
image_bgr = cv2.imread(IMAGE_PATH)
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
sam_result = mask_generator.generate(image_rgb)
出力結果を可視化する
mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
detections = sv.Detections.from_sam(sam_result=sam_result)
annotated_image = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections)
sv.plot_images_grid(
images=[image_bgr, annotated_image],
grid_size=(1, 2),
titles=['source image', 'segmented image']
)
セグメント結果出力
masks = [
mask['segmentation']
for mask
in sorted(sam_result, key=lambda x: x['area'], reverse=True)
]
sv.plot_images_grid(
images=masks,
grid_size=(8, 40),
size=(16, 16)
)