PyTorch Hubにおいて、学習済みの最先端モデルが簡単にダウンロードできます。
これを初学者でも簡単に使えるようにコードを書きました。GitHubにも公開しています。
コード内の画像URL部分を好きに変えて実行してください!
(URLを貼り付けること。URLの最後が.png
等の拡張子で終わっていること。)
開発環境
- Python :
3.9.0
- PyTorch :
1.7.1
画像分類
画像に写っているものは何なのかを分類してくれる技術です。
ResNeXtのモデルを使用します。
import os
from PIL import Image
import urllib.request
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from utils.imagenet_labels import idx2label
# モデル読み込み
model = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x8d_wsl')
model.eval()
# 画像
url = "https://github.com/pytorch/hub/raw/master/images/dog.jpg" # <- 自由に変更してね
save_dir = "../result/resnext"
os.makedirs(save_dir, exist_ok=True)
filename = f"{save_dir}/{os.path.basename(url)}"
try: urllib.URLopener().retrieve(url, filename)
except: urllib.request.urlretrieve(url, filename)
input_image = Image.open(filename)
#前処理
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)
#推論
if torch.cuda.is_available():
input_batch = input_batch.to('cuda')
model.to('cuda')
with torch.no_grad():
output = model(input_batch)
#可視化
idx = torch.nn.functional.softmax(output[0], dim=0).argmax().item()
print(idx2label[idx])
plt.imshow(input_image)
plt.title(idx2label[idx])
plt.savefig(filename)
物体検知
画像内にうつる物体の場所と分類をしてくれる技術です。
YOLOv5sを使います。
import os
from copy import deepcopy
import torch
# モデル読み込み
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
# 画像
imgs = ['https://ultralytics.com/images/zidane.jpg'] # <- 自由に変更してね!
img_paths = deepcopy(imgs)
# 推論
results = model(imgs)
# 可視化
save_dir = "../result/yolov5"
os.makedirs(save_dir, exist_ok=True)
results.save(save_dir)
xys = results.pandas().xyxy
for xy, img in zip(xys, img_paths):
basename = os.path.splitext(os.path.basename(img))[0]
xy.to_csv(f"{save_dir}/{basename}.csv")
print(xy)