画像内のオブジェクトを綺麗に切り抜く
画像を入力するだけで、画像内の目立つオブジェクトを綺麗に切り抜いて、セグメンテーションしてくれるモデルがあります。
型抜きのように、隅々まで綺麗に分離できれば、画像合成などさまざまな用途に使えます。
最新の高精度モデルIS-Net
過去には、U2-netという、かなり高精度のセマンティックセグメンテーションモデルがあり、さまざまなアプリケーションに使われています。
同じ作者による高精度モデルが2022年7月に発表されました。
IS-Netです。
(公式リポジトリより引用)
使い方
Python
推論Colab
以下のColabノートブックを実行すると、結果がresultsフォルダに保存されます。
もしくは、リポジトリをクローンして、事前トレーニング済みモデルをダウンロード後、以下のスクリプトで推論できます。
import os
import numpy as np
from skimage import io
import glob
import torch, gc
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize #GOSDatasetCache,
from basics import f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch,
from models import *
net = ISNetDIS()
net.load_state_dict(torch.load("isnet.pth",map_location="cpu"))
net.eval()
net.cuda()
im = io.imread(image_file)/255
w,h,_ = im.shape
if len(im.shape) < 3:
im = im[:, :, np.newaxis]
if im.shape[2] == 1:
im = np.repeat(im, 3, axis=2)
im_tensor = torch.tensor(im, dtype=torch.float32).cuda()
im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1,1,1])
im_tensor = torch.unsqueeze(im_tensor,0)
im_tensor = F.interpolate(im_tensor, size=(1024,1024))
ds_val = net(im_tensor)[0]
im_pred = F.interpolate(ds_val[0], size=(w,h))
im_pred = torch.squeeze(im_pred)
ma = torch.max(im_pred)
mi = torch.min(im_pred)
im_pred = (im_pred-mi)/(ma-mi)
im_result = im_pred.to('cpu').detach().numpy().copy()
io.imsave(os.path.join(result_folder,'output.jpg'), im_result)
Swift
iOSとMacOSではCoreMLに変換したモデルが使えます。
推論リクエストを実行します。
guard let model = try? DIS_1024_2048(configuration: MLModelConfiguration()).model,
let vnModel = try? VNCoreMLModel(for: model) else {
fatalError()
}
let request = VNCoreMLRequest(model: vnModel,completionHandler: mlCompletionHandler(request:error:))
request.imageCropAndScaleOption = .scaleFill
let handler = VNImageRequestHandler(ciImage: ciImage)
do {
try handler.perform([request])
guard let result = request?.results?.first as? VNPixelBufferObservation else {
fatalError()
}
let pixelBuffer = result.pixelBuffer
} catch let error {
fatalError(error.localizedDescription)
}
🐣
フリーランスエンジニアです。
お仕事のご相談こちらまで
rockyshikoku@gmail.com
Core MLやARKitを使ったアプリを作っています。
機械学習/AR関連の情報を発信しています。