Help us understand the problem. What is going on with this article?

OpenPoseで姿勢推定(PyTorch)

目標

  1. OpenPoseの学習済みモデルをロードできるようになる
  2. OpenPoseの推論をできる

注意 本稿では扱わないこと

  1. 画像データセットから学習済みモデルを作成する
  2. パラメータをチューニングする

環境構築

conda 4.3.14
macos mojave

エラー

cv2がないエラー

conda install -c conda-forge opencv

torchvisionがないエラー

conda install torchvision -c pytorch

公式サイト
https://github.com/pytorch/vision
(できなくてpip installしてしまった。改善策を知りたい)

コード

パッケージインポート

from PIL import Image
import cv2
import numpy as np
from matplotlib import cm
import matplotlib.pyplot as plt
%matplotlib inline

import torch

学習済みモデルをロード

学習済みモデルは公開されているので、別途ダウンロードしておく。

from utils.openpose_net import OpenPoseNet

# 学習済みモデルと本章のモデルでネットワークの層の名前が違うので、対応させてロードする
# モデルの定義
net = OpenPoseNet()

# 学習済みパラメータをロードする
net_weights = torch.load(
    './weights/pose_model_scratch.pth', map_location={'cuda:0': 'cpu'})
keys = list(net_weights.keys())

weights_load = {}

# ロードした内容を、本書で構築したモデルの
# パラメータ名net.state_dict().keys()にコピーする
for i in range(len(keys)):
    weights_load[list(net.state_dict().keys())[i]
                 ] = net_weights[list(keys)[i]]

# コピーした内容をモデルに与える
state = net.state_dict()
state.update(weights_load)
net.load_state_dict(state)

print('ネットワーク設定完了:学習済みの重みをロードしました')

  • ここで、重みがオンラインでロードされるために多少時間がかかる。

画像の前処理

test_image = './data/04-drunk-train.jpg'

oriImg = cv2.imread(test_image)  # B,G,Rの順番

# BGRをRGBにして表示
oriImg = cv2.cvtColor(oriImg, cv2.COLOR_BGR2RGB)
plt.imshow(oriImg)
plt.show()

# 画像のリサイズ
size = (368, 368)
img = cv2.resize(oriImg, size, interpolation=cv2.INTER_CUBIC)

# 画像の前処理
img = img.astype(np.float32) / 255.

# 色情報の標準化
color_mean = [0.485, 0.456, 0.406]
color_std = [0.229, 0.224, 0.225]

preprocessed_img = img.copy()[:, :, ::-1]  # BGR→RGB

for i in range(3):
    preprocessed_img[:, :, i] = preprocessed_img[:, :, i] - color_mean[i]
    preprocessed_img[:, :, i] = preprocessed_img[:, :, i] / color_std[i]

# (高さ、幅、色)→(色、高さ、幅)
img = preprocessed_img.transpose((2, 0, 1)).astype(np.float32)

# 画像をTensorに
img = torch.from_numpy(img)

# ミニバッチ化:torch.Size([1, 3, 368, 368])
x = img.unsqueeze(0)

姿勢推定部分

# OpenPoseでheatmapsとPAFsを求めます
net.eval()
predicted_outputs, _ = net(x)

# 画像をテンソルからNumPyに変化し、サイズを戻します
pafs = predicted_outputs[0][0].detach().numpy().transpose(1, 2, 0)
heatmaps = predicted_outputs[1][0].detach().numpy().transpose(1, 2, 0)

pafs = cv2.resize(pafs, size, interpolation=cv2.INTER_CUBIC)
heatmaps = cv2.resize(heatmaps, size, interpolation=cv2.INTER_CUBIC)

pafs = cv2.resize(
    pafs, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
heatmaps = cv2.resize(
    heatmaps, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)

heatmapの表示

heatmapは要は推定される体の部位を示すもの。

"""
// Result for BODY_25 (25 body parts consisting of COCO + foot)
// const std::map<unsigned int, std::string> POSE_BODY_25_BODY_PARTS {
//     {0,  "Nose"},
//     {1,  "Neck"},
//     {2,  "RShoulder"},
//     {3,  "RElbow"},
//     {4,  "RWrist"},
//     {5,  "LShoulder"},
//     {6,  "LElbow"},
//     {7,  "LWrist"},
//     {8,  "MidHip"},
//     {9,  "RHip"},
//     {10, "RKnee"},
//     {11, "RAnkle"},
//     {12, "LHip"},
//     {13, "LKnee"},
//     {14, "LAnkle"},
//     {15, "REye"},
//     {16, "LEye"},
//     {17, "REar"},
//     {18, "LEar"},
//     {19, "LBigToe"},
//     {20, "LSmallToe"},
//     {21, "LHeel"},
//     {22, "RBigToe"},
//     {23, "RSmallToe"},
//     {24, "RHeel"},
//     {25, "Background"}
// };
"""
# 25の全て
part = 0
while part <= 25:
    heat_map = heatmaps[:, :, part]  # 6は左肘
    heat_map = Image.fromarray(np.uint8(cm.jet(heat_map)*255))
    heat_map = np.asarray(heat_map.convert('RGB'))
# 合成して表示
    blend_img = cv2.addWeighted(oriImg, 0.5, heat_map, 0.5, 0)
    plt.imshow(blend_img)
    plt.show()
    part += 1

体の部位ごとの推定の出力結果


output_7_2.png

output_8_0.png
右肩
output_8_1.png
右肘
output_8_2.png
右手首
output_8_3.png
左肩
output_8_4.png

左肘
output_8_5.png

左手首
output_8_6.png

腰の中心
output_8_7.png

右膝
output_8_8.png
右足首
output_8_9.png
左の腰
output_8_10.png
左膝
output_8_11.png
左足首
output_8_12.png
output_8_13.png
output_8_14.png
output_8_15.png
output_8_16.png
output_8_17.png

姿勢推定結果の表示

from utils.decode_pose import decode_pose
_, result_img, _, _ = decode_pose(oriImg, heatmaps, pafs)

# 結果を描画
plt.imshow(oriImg)
plt.show()

plt.imshow(result_img)
plt.show()

output_10_0.png
output_10_1.png

出力結果の評価

推定結果は間違っている部分がある。

部位ごとのヒートマップを見ることで、推定結果がどのように誤解している結果、そうなったのか理解できる。

参考文献

https://github.com/YutaroOgawa/pytorch_advanced

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした