7
10

More than 3 years have passed since last update.

好きな画像を実際に物体検知させてみよう!

Posted at

PyTorch Hubにおいて、学習済みの最先端モデルが簡単にダウンロードできます。
これを初学者でも簡単に使えるようにコードを書きました。GitHubにも公開しています。
コード内の画像URL部分を好きに変えて実行してください!
(URLを貼り付けること。URLの最後が.png等の拡張子で終わっていること。)

開発環境

  • Python : 3.9.0
  • PyTorch : 1.7.1

画像分類

画像に写っているものは何なのかを分類してくれる技術です。
ResNeXtのモデルを使用します。

import os
from PIL import Image
import urllib.request

import matplotlib.pyplot as plt
import torch
from torchvision import transforms

from utils.imagenet_labels import idx2label 

# モデル読み込み
model = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x8d_wsl')
model.eval()

# 画像
url = "https://github.com/pytorch/hub/raw/master/images/dog.jpg" # <- 自由に変更してね
save_dir = "../result/resnext"
os.makedirs(save_dir, exist_ok=True)
filename = f"{save_dir}/{os.path.basename(url)}"
try: urllib.URLopener().retrieve(url, filename)
except: urllib.request.urlretrieve(url, filename)
input_image = Image.open(filename)

#前処理
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) 

#推論
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')
with torch.no_grad():
    output = model(input_batch)

#可視化
idx = torch.nn.functional.softmax(output[0], dim=0).argmax().item()
print(idx2label[idx])
plt.imshow(input_image)
plt.title(idx2label[idx])
plt.savefig(filename)

image.png

物体検知

画像内にうつる物体の場所と分類をしてくれる技術です。
YOLOv5sを使います。

import os
from copy import deepcopy

import torch

# モデル読み込み
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

# 画像
imgs = ['https://ultralytics.com/images/zidane.jpg']  # <- 自由に変更してね!
img_paths = deepcopy(imgs)

# 推論
results = model(imgs)

# 可視化
save_dir = "../result/yolov5"
os.makedirs(save_dir, exist_ok=True)
results.save(save_dir) 
xys = results.pandas().xyxy
for xy, img in zip(xys, img_paths):
    basename = os.path.splitext(os.path.basename(img))[0]
    xy.to_csv(f"{save_dir}/{basename}.csv")
    print(xy)

image.png

7
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
10