1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

torchvisionの事前トレーニング済みモデルを使う 画像分類

Posted at

torchvisionのトレーニング済みモデルをダウンロードして使う方法です

トレーニングするのは面倒

トレーニング済みの画像モデルが簡単に使える

いろんな種類のモデルが使えます。わりと新しいモデルもあります。

方法

事前トレーニング済みモデルをダウンロードしてインスタンス化。

import torchvision.models as models

regnet_y_400mf = models.regnet_y_400mf(pretrained=True)

これだけでモデルができます。

入力画像をモデルの入力に合わせて前処理します。

224*224にリサイズ、ImageNetデータセットに合わせて正規化します。

import torch
from torchvision import transforms
from PIL import Image

transform = transforms.Compose([
 transforms.Resize(224),
 transforms.ToTensor(),
 transforms.Normalize(
 mean=[0.485, 0.456, 0.406],
 std=[0.229, 0.224, 0.225]
 )])

img = Image.open("tabby.jpg")
img_t = transform(img)
batch_t = torch.unsqueeze(img_t, 0)

画像をモデルのフォワードプロセスにかけます。

out = regnet_y_400mf(batch_t)

出力は(1,1000)の数値です。ImageNet1000クラスに対応する信頼度です。
ソフトマックスで%値にまとめます。
クラスラベルをダウンロードし、トップを表示します。

_, index = torch.max(out, 1)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100

import urllib
label_url = 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
class_labels = urllib.request.urlopen(label_url).read().splitlines()
class_labels = class_labels[1:] # remove the first class which is background

print(class_labels[index[0]], percentage[index[0]].item())

'Egyptian cat' 66.66095733642578

使えるモデルは以下で確認できます。

🐣


フリーランスエンジニアです。
お仕事のご相談こちらまで
rockyshikoku@gmail.com

Core MLやARKitを使ったアプリを作っています。
機械学習/AR関連の情報を発信しています。

Twitter
Medium

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?