2
1

More than 1 year has passed since last update.

【IS-Net】隅々まできわだつ輪郭、最新版切り抜きセグメンテーションモデル【推論Colab、CoreMLモデルつき】

Last updated at Posted at 2022-07-20

画像内のオブジェクトを綺麗に切り抜く

画像を入力するだけで、画像内の目立つオブジェクトを綺麗に切り抜いて、セグメンテーションしてくれるモデルがあります。
型抜きのように、隅々まで綺麗に分離できれば、画像合成などさまざまな用途に使えます。

最新の高精度モデルIS-Net

過去には、U2-netという、かなり高精度のセマンティックセグメンテーションモデルがあり、さまざまなアプリケーションに使われています。
同じ作者による高精度モデルが2022年7月に発表されました。
IS-Netです。

bg-removal.gif

(公式リポジトリより引用)

使い方

Python

推論Colab

以下のColabノートブックを実行すると、結果がresultsフォルダに保存されます。

もしくは、リポジトリをクローンして、事前トレーニング済みモデルをダウンロード後、以下のスクリプトで推論できます。

import os
import numpy as np
from skimage import io

import glob

import torch, gc
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms.functional import normalize

from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize #GOSDatasetCache,
from basics import  f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch,
from models import *

net = ISNetDIS() 
net.load_state_dict(torch.load("isnet.pth",map_location="cpu"))
net.eval()
net.cuda()

im = io.imread(image_file)/255
w,h,_ = im.shape
if len(im.shape) < 3:
    im = im[:, :, np.newaxis]
if im.shape[2] == 1:
    im = np.repeat(im, 3, axis=2)
im_tensor = torch.tensor(im, dtype=torch.float32).cuda()
im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1,1,1])
im_tensor = torch.unsqueeze(im_tensor,0)
im_tensor = F.interpolate(im_tensor, size=(1024,1024))

ds_val = net(im_tensor)[0]
im_pred = F.interpolate(ds_val[0], size=(w,h))

im_pred = torch.squeeze(im_pred)
ma = torch.max(im_pred)
mi = torch.min(im_pred)
im_pred = (im_pred-mi)/(ma-mi)
im_result = im_pred.to('cpu').detach().numpy().copy()
io.imsave(os.path.join(result_folder,'output.jpg'), im_result)

Swift

iOSとMacOSではCoreMLに変換したモデルが使えます。

推論リクエストを実行します。

guard let model = try? DIS_1024_2048(configuration: MLModelConfiguration()).model,
      let vnModel = try? VNCoreMLModel(for: model) else {
   fatalError()
}
let request = VNCoreMLRequest(model: vnModel,completionHandler: mlCompletionHandler(request:error:))
request.imageCropAndScaleOption = .scaleFill
 
let handler = VNImageRequestHandler(ciImage: ciImage)
do {
    try handler.perform([request])
    guard let result = request?.results?.first as? VNPixelBufferObservation else {
        fatalError()
    }
    let pixelBuffer = result.pixelBuffer
} catch let error {
    fatalError(error.localizedDescription)
}

🐣


フリーランスエンジニアです。
お仕事のご相談こちらまで
rockyshikoku@gmail.com

Core MLやARKitを使ったアプリを作っています。
機械学習/AR関連の情報を発信しています。

Twitter
Medium
GitHub

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1