5
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

簡単!学習済みモデルで物体検出を試す!

Posted at

こんにちは、kamikawaです
今回はpytorchの学習済みモデルで物体検出をする方法を解説します

実行環境はGoogle Colaboratory (Colab)です。

この記事の対象読者

  • 画像AIを試してみたい!
  • 物体検出を試してみたい!

物体検出

画像認識分野(CV分野)において注目を集める物体検出を試していきます。
物体検出は画像の  ”どこに”   ”何が”  あるかを調べるタスクと言えます。
よくAIを紹介しているニュースなどで、映っている車や人を四角い枠で囲んでいる画像を見たことがあると思います。そのイメージです。

今回使用するモデル

今回は物体検出の代表的なネットワークである

  • Faster R-CNN ・・・高速なモデル
  • Mask R-CNN ・・・正確なモデル(セグメンテーションもできます)

を使用します

これらのモデルを自分で学習させると時間がかかりますが、pytorchのtorchvisionで学習済みモデルが使用できるので今回はそれを用いて物体検知を試していきます。

torchvisionが提供している訓練済みのモデルは公式サイトから確認できます。

コード

個人的にcv2をよく使っているのでcv2を使っています(主に描画の処理の部分)。
そのためやや冗長なコードになっています。実は、PILだけで完結することもできます。


#画像中の物体検出
#必要なライブラリのimport
#Colabには標準でインストールされているのでinstallの必要はなし

import torch #深層学習用ライブラリの一つであるpytorchをimport
import PIL #画像処理用ライブラリの一つPILをimport
from PIL import Image #PILからImageをimport
import torchvision #pytorchの画像処理用ライブラリをimport
from torchvision import transforms #画像処理用ライブラリからtransformsをimport
import cv2 #画像処理用ライブラリ
import matplotlib.pyplot as plt #グラフや画像を表示するためのライブラリをimport

#画像を読み込む

frame_raw = cv2.imread('画像のpath')


#PILの画像配列にすればモデルに入力しやすいのでPILの配列に変換
#cv2はBGRの順番だがPILはRGBの順番

frame = cv2.cvtColor(frame_raw,cv2.COLOR_BGR2RGB)#BGRの順番になっているのをRGBに並べ替える

#numpy→PIL
image = Image.fromarray(frame) #numpyのarrayから変換

#モデルをダウンロード

#Faster R-CNNはこちら
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

#Mask R-CNNを試したい場合は、一行下のコードのコメントアウト(#)を外す
#model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

#いよいよ推論!!
with torch.no_grad(): #pytorchの自動微分機能を切る
  #推論に使うデバイスを選択(GPUを使用する場合は torch.device('cuda'))
  device = torch.device('cpu') 
  transform = transforms.Compose([transforms.ToTensor()]) #PILをtensorに変換にするためのインスタンスを用意
  inputs = transform(image) #PIL→tensor
  inputs = inputs.unsqueeze(0).to(device)#デバイスに入力
  model.eval() #モデルを推論モードに切り替え
  outputs = model(inputs) #モデルに推論させて結果を受け取る
  for i in range(len(outputs[0]['boxes'])):#推論結果からバウンディングボックス(BB)の座標を取り出す
    x0 = outputs[0]['boxes'][i][0] #BBの左上の点のx座標
    y0 = outputs[0]['boxes'][i][1] #BBの左上の点のy座標
    x1 = outputs[0]['boxes'][i][2] #BBの右下の点のx座標
    y1 = outputs[0]['boxes'][i][3] #BBの右下の点のy座標
    #confidence(モデルがその推論にどのくらい自信があるか)が0.7以上だったら
    if outputs[0]['scores'][i]>=0.7:
      bbox = cv2.rectangle(frame_raw,(x0,y0),(x1,y1),(0,0,300),3,4)#BBを表示
  #結果を表示
  plt.figure(figsize = (12,9)) #表示する画像のサイズを決定
  plt.imshow(cv2.cvtColor(bbox, cv2.COLOR_BGR2RGB) )#cv2はBGRだがpltはRGB
  plt.axis("off")#グラフの目盛りが入るのを防ぐ
  plt.show() #結果の画像表示させる

結果

実際に実行した結果を紹介します
Unknown-12.png

船や車などが四角で囲まれているのがわかると思います

また、先頭にクラス名のリストを追加し

COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
#以下略

線で囲まれた部分のコードを追加すると

#これより上は略
  for i in range(len(outputs[0]['boxes'])):
    x0 = outputs[0]['boxes'][i][0] 
    y0 = outputs[0]['boxes'][i][1]
    x1 = outputs[0]['boxes'][i][2]
    y1 = outputs[0]['boxes'][i][3]
    if outputs[0]['scores'][i]>=0.9:
      bbox = cv2.rectangle(frame_raw,(x0,y0),(x1,y1),(0,0,255),3,4)
#---------------------下のコードを追加--------------------------------------------------------
      class_num = outputs[0]['labels'][i]
      class_name = COCO_INSTANCE_CATEGORY_NAMES[class_num]
      conf = float(outputs[0]['scores'][i])
      conf = round(conf,4)
      label = class_name + str(conf)
      bbox = cv2.putText(bbox,label,(x0,y0),cv2.FONT_HERSHEY_COMPLEX,2.1,(0,128,0),2 ) 
#------------------------------------------------------------------------------------------------
  plt.figure(figsize = (12,9))
  plt.imshow(cv2.cvtColor(bbox, cv2.COLOR_BGR2RGB))
  plt.axis("off")
  plt.show()

検出した物体のクラスとコンフィデンスが表示されるようになります

Unknown-16.png

コード全文

COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']

#画像中の物体検出
#必要なライブラリのimport
#Colabには標準でインストールされているのでinstallの必要はなし

import torch #深層学習用ライブラリの一つであるpytorchをimport
import PIL #画像処理用ライブラリの一つPILをimport
from PIL import Image #PILからImageをimport
import torchvision #pytorchの画像処理用ライブラリをimport
from torchvision import transforms #画像処理用ライブラリからtransformsをimport
import cv2 #画像処理用ライブラリ
import matplotlib.pyplot as plt #グラフや画像を表示するためのライブラリをimport

#画像を読み込む

frame_raw = cv2.imread('画像のpath')


#PILの画像配列にすればモデルに入力しやすいのでPILの配列に変換
#cv2はBGRの順番だがPILはRGBの順番

frame = cv2.cvtColor(frame_raw,cv2.COLOR_BGR2RGB)#BGRの順番になっているのをRGBに並べ替える

#numpy→PIL
image = Image.fromarray(frame) #numpyのarrayから変換

#モデルをダウンロード

#Faster R-CNNはこちら
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

#Mask R-CNNを試したい場合は、一行下のコードのコメントアウト(#)を外す
#model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

#いよいよ推論!!
with torch.no_grad(): #pytorchの自動微分機能を切る
  #推論に使うデバイスを選択(GPUを使用する場合は torch.device('cuda'))
  device = torch.device('cpu') 
  transform = transforms.Compose([transforms.ToTensor()]) #PILをtensorに変換にするためのインスタンスを用意
  inputs = transform(image) #PIL→tensor
  inputs = inputs.unsqueeze(0).to(device)#デバイスに入力
  model.eval() #モデルを推論モードに切り替え
  outputs = model(inputs) #モデルに推論させて結果を受け取る
  for i in range(len(outputs[0]['boxes'])):
    x0 = outputs[0]['boxes'][i][0] 
    y0 = outputs[0]['boxes'][i][1]
    x1 = outputs[0]['boxes'][i][2]
    y1 = outputs[0]['boxes'][i][3]
    if outputs[0]['scores'][i]>=0.9:
      bbox = cv2.rectangle(frame_raw,(x0,y0),(x1,y1),(0,0,255),3,4)
      class_num = outputs[0]['labels'][i]
      class_name = COCO_INSTANCE_CATEGORY_NAMES[class_num]
      conf = float(outputs[0]['scores'][i])
      conf = round(conf,4)
      label = class_name + str(conf)
      bbox = cv2.putText(bbox,label,(x0,y0),cv2.FONT_HERSHEY_COMPLEX,2.1,(0,128,0),2 ) 
  plt.figure(figsize = (12,9))
  plt.imshow(cv2.cvtColor(bbox, cv2.COLOR_BGR2RGB))
  plt.axis("off")
  plt.show()

セッションがクラッシュした場合

ランタイムのタイプをGPUにしてみましょう(画像ファイルはもう一度アップロードし直さなければなりません)
GPUランタイムへの切り替え方はこちらの記事を参考にしてください

コードの

device = torch.device('cpu') 

device = torch.device('cuda')

に変更してください

最後に

最後まで記事をご覧いただきありがとうございました
実行環境はColabratoryですので環境に合わせてコードは変更してください(ファイルのパス等)
Colabの使い方はこちらの記事を参考にしてください

5
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?