Giving CenterNet a Quick Try

Posted 2019-12-06

Big and tasty.

Official repo
https://github.com/xingyizhou/CenterNet

Slides in Japanese
https://www.slideshare.net/DeepLearningJP2016/dlobjects-as-points

Setup reference
https://qiita.com/sonodaatom/items/fe148a177b1d1b40c34d

Features of CenterNet

It is reportedly faster and more accurate than YOLOv3 and M2Det.

It uses the Objects as Points approach, which predicts the center point of each object, so it can reportedly skip NMS (Non-Maximum Suppression) entirely.

(screenshot)
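As a rough sketch of why NMS can be skipped: the network predicts a per-class heatmap of center points, and peaks are picked out with a single max-pooling pass. This mirrors the _nms function in the repo's src/lib/models/decode.py; heatmap_peaks is my own name for it.

import torch
import torch.nn.functional as F

def heatmap_peaks(heat, kernel=3):
    # A location on the class heatmap survives only if it equals the
    # maximum of its kernel x kernel neighborhood -- this one pooling
    # pass replaces box-level NMS.
    pad = (kernel - 1) // 2
    hmax = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
    return heat * (hmax == heat).float()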

A GPU is required, so let's use Google Colaboratory

First, mount Google Drive.

from google.colab import drive
drive.mount('/content/gdrive')

I always check which GPU I got, because it's exciting.

from tensorflow.python.client import device_lib
device_lib.list_local_devices()

-> physical_device_desc: "device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0"

The GPUs I've seen on Google Colaboratory are the P100, T4, K80, and P4. Drawing a T4 or a P100 makes me happy.
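If you also want to see memory usage, nvidia-smi works in Colab too:

!nvidia-smi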

Create a directory called centernet under My Drive and work in there.

%cd /content/gdrive/My\ Drive/centernet
!git clone https://github.com/xingyizhou/CenterNet.git
%cd CenterNet

# Install CUDA 9.0 so nvcc matches the torch 0.4.1 build installed below
!wget https://developer.nvidia.com/compute/cuda/9.0/Prod/local_installers/cuda-repo-ubuntu1604-9-0-local_9.0.176-1_amd64-deb
!dpkg --install cuda-repo-ubuntu1604-9-0-local_9.0.176-1_amd64-deb
!sudo apt-key add /var/cuda-repo-9-0-local/7fa2af80.pub
!sudo apt-get update
!sudo apt-get install cuda-9-0
!nvcc --version

# Pin torch to 0.4.1: the DCNv2 code in this repo targets the old 0.4 API
!pip install -U torch==0.4.1 torchvision==0.2.2

!pip install -r requirements.txt

%cd /content/gdrive/My\ Drive/centernet/CenterNet/src/lib/models/networks/DCNv2/src/cuda

# Colab's default gcc is too new for CUDA 9.0's nvcc, so switch to gcc-5
!apt install gcc-5 g++-5 -y

!ln -sf /usr/bin/gcc-5 /usr/bin/gcc
!ln -sf /usr/bin/g++-5 /usr/bin/g++
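You can confirm the switch took effect:

!gcc --version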

# Compile the DCNv2 CUDA kernels
!nvcc -c -o dcn_v2_im2col_cuda.cu.o dcn_v2_im2col_cuda.cu -x cu -Xcompiler -fPIC
!nvcc -c -o dcn_v2_im2col_cuda_double.cu.o dcn_v2_im2col_cuda_double.cu -x cu -Xcompiler -fPIC

!nvcc -c -o dcn_v2_psroi_pooling_cuda.cu.o dcn_v2_psroi_pooling_cuda.cu -x cu -Xcompiler -fPIC
!nvcc -c -o dcn_v2_psroi_pooling_cuda_double.cu.o dcn_v2_psroi_pooling_cuda_double.cu -x cu -Xcompiler -fPIC

%cd /content/gdrive/My Drive/centernet/CenterNet/src/lib/models/networks/DCNv2

# Build the DCNv2 Python extension (float and double variants)
!python build.py
!python build_double.py

%cd /content/gdrive/My Drive/centernet/CenterNet/src/lib/external

# Build the Cython NMS extension used by the evaluation code
!make

%cd ../../

https://github.com/xingyizhou/CenterNet/blob/master/readme/MODEL_ZOO.md
Download the model (ctdet_coco_dla_2x.pth) from the MODEL_ZOO here and place it at:
/content/gdrive/My\ Drive/centernet/CenterNet/models/
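Since the MODEL_ZOO links point at Google Drive, gdown is handy for downloading from inside Colab. A sketch, where FILE_ID is a placeholder for the ID you copy out of the ctdet_coco_dla_2x.pth share link (uploading the file to Drive by hand works just as well):

!pip install gdown
!mkdir -p "/content/gdrive/My Drive/centernet/CenterNet/models"
# FILE_ID is hypothetical -- take the real ID from the MODEL_ZOO share link
!gdown --id FILE_ID -O "/content/gdrive/My Drive/centernet/CenterNet/models/ctdet_coco_dla_2x.pth"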

import sys
sys.path.insert(0, 'lib')  # we are in CenterNet/src, so make src/lib importable
import cv2
import matplotlib.pyplot as plt
from detectors.detector_factory import detector_factory
from opts import opts
%matplotlib inline

MODEL_PATH = "../models/ctdet_coco_dla_2x.pth"
TASK = 'ctdet' # or 'multi_pose' for human pose estimation
opt = opts().init('{} --load_model {}'.format(TASK, MODEL_PATH).split(' '))
detector = detector_factory[opt.task](opt)
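For reference, detector.run returns a dict whose 'results' entry maps 1-based COCO class IDs to an N×5 array of [x1, y1, x2, y2, score] per detection; as far as I can tell from the repo, that is the format the drawing code below relies on.

rets = detector.run("../images/33823288584_1d21cf0a26_k.jpg")['results']
print(len(rets))       # 80 entries, one per COCO class
print(rets[1].shape)   # (N, 5): box corners plus score for class 1 ('person')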

coco_names = ['person', 'bicycle', 'car', 'motorbike', 'aeroplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'sofa', 'pottedplant', 'bed', 'diningtable', 'toilet', 'tvmonitor', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush',] 

https://github.com/mozilla/Fira/blob/master/otf/FiraMono-Medium.otf
Put the font file somewhere like /content/gdrive/My Drive/centernet.
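One way to pull it straight into Drive (assuming GitHub's usual blob-to-raw URL mapping):

%cd /content/gdrive/My\ Drive/centernet
!wget https://github.com/mozilla/Fira/raw/master/otf/FiraMono-Medium.otf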


from PIL import Image, ImageFont, ImageDraw
import numpy as np
import colorsys
from pylab import rcParams

rcParams['figure.figsize'] = 10,10

# Generate colors for drawing bounding boxes.
hsv_tuples = [(x / len(coco_names), 1., 1.)
              for x in range(len(coco_names))]
colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
colors = list(
    map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
        colors))
np.random.seed(10101)  # Fixed seed for consistent colors across runs.
np.random.shuffle(colors) 

def write_rect(image, box, cl, score):
    # Use the font file downloaded above (FiraMono-Medium.otf)
    font = ImageFont.truetype(font='/content/gdrive/My Drive/centernet/FiraMono-Medium.otf',
                              size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))
    thickness = (image.size[0] + image.size[1]) // 300

    label = '{} {:.2f}'.format(coco_names[cl], score)
    draw = ImageDraw.Draw(image)
    label_size = draw.textsize(label, font)

    left, top, right, bottom = box
    top = max(0, np.floor(top + 0.5).astype('int32'))
    left = max(0, np.floor(left + 0.5).astype('int32'))
    bottom = min(image.size[1], np.floor(bottom + 0.5).astype('int32'))
    right = min(image.size[0], np.floor(right + 0.5).astype('int32'))

    if top - label_size[1] >= 0:
        text_origin = np.array([left, top - label_size[1]])
    else:
        text_origin = np.array([left, top + 1])

    # My kingdom for a good redistributable image drawing library.
    for i in range(thickness):
        draw.rectangle(
            [left + i, top + i, right - i, bottom - i],
            outline=colors[cl])
    draw.rectangle(
        [tuple(text_origin), tuple(text_origin + label_size)],
        fill=colors[cl])
    draw.text(text_origin, label, fill=(0, 0, 0), font=font)

Let's try object detection on an image.

img = "/content/gdrive/My Drive/centernet/CenterNet/images/33823288584_1d21cf0a26_k.jpg"
rets = detector.run(img)['results']
img = Image.open(img)
print(img)

for i in range(len(rets)):
    ret = rets.get(i + 1)        # results are keyed by 1-based class ID
    if ret.shape[0] == 0:
        continue
    ret = ret[ret[:, 4] > 0.6]   # keep detections scoring above 0.6
    for box in ret:
        write_rect(img, box[:4], i, box[4])

plt.imshow(img)

(screenshot of the detection result)

I actually wanted to try pose estimation (à la OpenPose), but I couldn't figure it out, so for now let's stick with object detection and write out a video.

def cv2pil(image):
    new_image = image.copy()
    if new_image.ndim == 2:
        pass
    elif new_image.shape[2] == 3:
        new_image = cv2.cvtColor(new_image, cv2.COLOR_BGR2RGB)
    elif new_image.shape[2] == 4:
        new_image = cv2.cvtColor(new_image, cv2.COLOR_BGRA2RGBA)
    new_image = Image.fromarray(new_image)
    return new_image

def pil2cv(image):
    new_image = np.array(image, dtype=np.uint8)
    if new_image.ndim == 2:
        pass
    elif new_image.shape[2] == 3:
        new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
    elif new_image.shape[2] == 4:
        new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGRA)
    return new_image

fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
_video = cv2.VideoWriter('footsal.mp4', fourcc, 10.0, (568, 320))
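One caveat: the (568, 320) passed to VideoWriter has to match the size of the frames you actually write, or the output file comes out broken. A sketch that reads the size from the source video instead of hardcoding it:

cap = cv2.VideoCapture("/content/gdrive/My Drive/centernet/CenterNet/images/IMG_6880.mov")
size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
cap.release()
_video = cv2.VideoWriter('footsal.mp4', fourcc, 10.0, size)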

Upload a suitable video to Google Drive and turn the detection results into a video.


video = "/content/gdrive/My Drive/centernet/CenterNet/images/IMG_6880.mov"
cam = cv2.VideoCapture(video)
counter = 0
while cam.isOpened():
    counter += 1
    ok, img = cam.read()
    if not ok:                 # end of stream
        break
    if counter % 3 != 1:       # process every third frame to save time
        continue
    rets = detector.run(img)['results']
    img = cv2pil(img)

    for i in range(len(rets)):
        ret = rets.get(i + 1)
        if ret.shape[0] == 0:
            continue
        ret = ret[ret[:, 4] > 0.6]
        for box in ret:
            write_rect(img, box[:4], i, box[4])

    plt.imshow(img)
    _video.write(pil2cv(img))
    plt.pause(.01)
cam.release()
_video.release()

The result comes out like this.

Summary

With real-world deployment in mind, I'd like to compare its speed against YOLOv3 on a Jetson Nano.
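As a rough first data point before getting hold of a Jetson, you can time detector.run on a single image right in Colab (a quick sketch; GPU numbers here won't transfer directly to a Jetson Nano):

import time

img = "../images/33823288584_1d21cf0a26_k.jpg"
detector.run(img)  # warm-up
start = time.time()
for _ in range(20):
    detector.run(img)
print('{:.1f} FPS'.format(20 / (time.time() - start)))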
