LoginSignup
11

More than 3 years have passed since last update.

OpenPoseで姿勢推定(PyTorch)

Last updated at Posted at 2019-10-22

目標

  1. OpenPoseの学習済みモデルをロードできるようになる
  2. OpenPoseの推論をできる

注意 本稿では扱わないこと

  1. 画像データセットから学習済みモデルを作成する
  2. パラメータをチューニングする

環境構築

conda 4.3.14
macos mojave

エラー

cv2がないエラー

conda install -c conda-forge opencv

torchvisionがないエラー

conda install torchvision -c pytorch

公式サイト
https://github.com/pytorch/vision
(できなくてpip installしてしまった。改善策を知りたい)

コード

パッケージインポート

from PIL import Image
import cv2
import numpy as np
from matplotlib import cm
import matplotlib.pyplot as plt
%matplotlib inline

import torch

学習済みモデルをロード

学習済みモデルは公開されているので、別途ダウンロードしておく。

from utils.openpose_net import OpenPoseNet

# 学習済みモデルと本章のモデルでネットワークの層の名前が違うので、対応させてロードする
# モデルの定義
net = OpenPoseNet()

# 学習済みパラメータをロードする
net_weights = torch.load(
    './weights/pose_model_scratch.pth', map_location={'cuda:0': 'cpu'})
keys = list(net_weights.keys())

weights_load = {}

# ロードした内容を、本書で構築したモデルの
# パラメータ名net.state_dict().keys()にコピーする
for i in range(len(keys)):
    weights_load[list(net.state_dict().keys())[i]
                 ] = net_weights[list(keys)[i]]

# コピーした内容をモデルに与える
state = net.state_dict()
state.update(weights_load)
net.load_state_dict(state)

print('ネットワーク設定完了:学習済みの重みをロードしました')

  • ここで、重みがオンラインでロードされるために多少時間がかかる。

画像の前処理

test_image = './data/04-drunk-train.jpg'

oriImg = cv2.imread(test_image)  # B,G,Rの順番

# BGRをRGBにして表示
oriImg = cv2.cvtColor(oriImg, cv2.COLOR_BGR2RGB)
plt.imshow(oriImg)
plt.show()

# 画像のリサイズ
size = (368, 368)
img = cv2.resize(oriImg, size, interpolation=cv2.INTER_CUBIC)

# 画像の前処理
img = img.astype(np.float32) / 255.

# 色情報の標準化
color_mean = [0.485, 0.456, 0.406]
color_std = [0.229, 0.224, 0.225]

preprocessed_img = img.copy()[:, :, ::-1]  # BGR→RGB

for i in range(3):
    preprocessed_img[:, :, i] = preprocessed_img[:, :, i] - color_mean[i]
    preprocessed_img[:, :, i] = preprocessed_img[:, :, i] / color_std[i]

# (高さ、幅、色)→(色、高さ、幅)
img = preprocessed_img.transpose((2, 0, 1)).astype(np.float32)

# 画像をTensorに
img = torch.from_numpy(img)

# ミニバッチ化:torch.Size([1, 3, 368, 368])
x = img.unsqueeze(0)

姿勢推定部分

# OpenPoseでheatmapsとPAFsを求めます
net.eval()
predicted_outputs, _ = net(x)

# 画像をテンソルからNumPyに変化し、サイズを戻します
pafs = predicted_outputs[0][0].detach().numpy().transpose(1, 2, 0)
heatmaps = predicted_outputs[1][0].detach().numpy().transpose(1, 2, 0)

pafs = cv2.resize(pafs, size, interpolation=cv2.INTER_CUBIC)
heatmaps = cv2.resize(heatmaps, size, interpolation=cv2.INTER_CUBIC)

pafs = cv2.resize(
    pafs, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
heatmaps = cv2.resize(
    heatmaps, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)

heatmapの表示

heatmapは要は推定される体の部位を示すもの。

"""
// Result for BODY_25 (25 body parts consisting of COCO + foot)
// const std::map<unsigned int, std::string> POSE_BODY_25_BODY_PARTS {
//     {0,  "Nose"},
//     {1,  "Neck"},
//     {2,  "RShoulder"},
//     {3,  "RElbow"},
//     {4,  "RWrist"},
//     {5,  "LShoulder"},
//     {6,  "LElbow"},
//     {7,  "LWrist"},
//     {8,  "MidHip"},
//     {9,  "RHip"},
//     {10, "RKnee"},
//     {11, "RAnkle"},
//     {12, "LHip"},
//     {13, "LKnee"},
//     {14, "LAnkle"},
//     {15, "REye"},
//     {16, "LEye"},
//     {17, "REar"},
//     {18, "LEar"},
//     {19, "LBigToe"},
//     {20, "LSmallToe"},
//     {21, "LHeel"},
//     {22, "RBigToe"},
//     {23, "RSmallToe"},
//     {24, "RHeel"},
//     {25, "Background"}
// };
"""
# 25の全て
part = 0
while part <= 25:
    heat_map = heatmaps[:, :, part]  # 6は左肘
    heat_map = Image.fromarray(np.uint8(cm.jet(heat_map)*255))
    heat_map = np.asarray(heat_map.convert('RGB'))
# 合成して表示
    blend_img = cv2.addWeighted(oriImg, 0.5, heat_map, 0.5, 0)
    plt.imshow(blend_img)
    plt.show()
    part += 1

体の部位ごとの推定の出力結果


output_7_2.png

output_8_0.png
右肩
output_8_1.png
右肘
output_8_2.png
右手首
output_8_3.png
左肩
output_8_4.png

左肘
output_8_5.png

左手首
output_8_6.png

腰の中心
output_8_7.png

右膝
output_8_8.png
右足首
output_8_9.png
左の腰
output_8_10.png
左膝
output_8_11.png
左足首
output_8_12.png
output_8_13.png
output_8_14.png
output_8_15.png
output_8_16.png
output_8_17.png

姿勢推定結果の表示

from utils.decode_pose import decode_pose
_, result_img, _, _ = decode_pose(oriImg, heatmaps, pafs)

# 結果を描画
plt.imshow(oriImg)
plt.show()

plt.imshow(result_img)
plt.show()

output_10_0.png
output_10_1.png

出力結果の評価

推定結果は間違っている部分がある。

部位ごとのヒートマップを見ることで、推定結果がどのように誤解している結果、そうなったのか理解できる。

参考文献

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11