LoginSignup
3
6

More than 1 year has passed since last update.

【お絵かき初心者必見】SPINで写真から人間の3Dを取得する

Last updated at Posted at 2021-07-28

https://qiita.com/akaiteto/items/6a59a4b644fdf6d7bc4d
https://qiita.com/akaiteto/items/b5c8c3d5eb5ca3849c5d

何故かシリーズ化している本シリーズ。
要するに、人間の姿勢推定、骨格推定の話です。
とりあえず最低限の形にはなったのでこの記事で人間3D化は一旦終了です。

この記事ではSPINというアルゴリズムを試します。ちゃんと動くソースもできたし割と満足。
SPINは単一の画像から骨格を検出するアルゴリズムです。
もともと画像単体に対して実行するものですが、MP4の動画を読み込んだ結果はこんな感じ。
af8d3643-f626-47dd-9679-ddc29680eb7c.gif
TCMRとSPINのソースをガチャガチャ組み合わせてうまくいきました
テレビを撮影したような動画みたいですけど、うまいこと推定されてますね。すげー
image.png

ついでにこちらが画像じゃなく3D化したもの。なんのポーズかはしらない。
それから、これが果たしてお絵かきに役立つかもわからない。

SPIN

この記事では、アルゴリズムの方にはあまり触れません。(TCMRよむのに精魂尽きた)
でも別の記事でSPINのソースを基にSMPL周りの計算式を整理した記事はあげるとおもう。

TCMRの余談(よまなくてもいい)

https://qiita.com/akaiteto/items/6a59a4b644fdf6d7bc4d
前回TCMRを動かした時、デモソースが上手く動かず、TCMRのRegressorの部分でエラーが出ていました。
向こうの記事にはあえて反映させませんが、エラーの原因は結論から言えばSMPLのバージョンです。

私が入力としてあたえているものはSMPLのモデルだけなので、十中八九SMPLのモデルが原因。
ソースを見た感じ、TCMRはSPINのソースをベースにして作られているようなので、
SPINのSMPLの説明をみてみると、TCMRとSPINとでpklのモデルファイルのバージョンが違いました。

もしやSMPLのモデルのバージョンが違う…?
試しにモデルを変えて実行したらうまくいきましたやったー。
SMPLをなにかのOSSで使う時はバージョン要注意です。

前提条件

TCMRのソースをベースにします。

環境構築等は前回参照。
https://qiita.com/akaiteto/items/6a59a4b644fdf6d7bc4d
前回記事読まなくても、google colabで以下を実行すればいけるかも?

#環境構築
!git clone https://github.com/hongsukchoi/TCMR_RELEASE
!pip install numpy==1.17.5 torch==1.4.0 torchvision==0.5.0
!pip install git+https://github.com/giacaglia/pytube.git --upgrade
!pip install -r TCMR_RELEASE/requirements.txt
!pip install open3D

#出力先設定
! mkdir output ; cd output ; mkdir demo_output
! mkdir data ; cd data ; mkdir base_data
# ※googledriveにSMPLのモデル配置前提
! cp drive/MyDrive/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl data/base_data/SMPL_NEUTRAL.pkl

#デモ動画/学習済みモデルダウンロード
!source TCMR_RELEASE/scripts/get_base_data.sh

余談(よまなくていい)

https://qiita.com/akaiteto/items/6a59a4b644fdf6d7bc4d
前回の記事を見ている方ならご存知だと思いますが、
SPINはTCMRのデモソースで採用されていた技術の1つです。

一応TCMRのおさらいですが、
TCMRは動画に特化して人の動きを検出する技術です。

動画の各フレームごとに画像単一の骨格特徴の抽出を行い、
全フレームの骨格特徴を現在過去未来とデータを分けて学習することで、
動画全体の骨格の流れを滑らかに調整します。

SPINは、この一連の処理の中で、はじめに画像単一で骨格を抽出する際に使われていました。

SMPLの用意

https://smplify.is.tue.mpg.de/
まず、SMPLのモデルをダウンロードします。
アカウント登録をして、メインメニューのDownloadタブからダウンロードしてください。

前回のTCMRの記事を読んでいる人はここがめちゃくちゃ重要です。
実はSMPLには複数のダウンロードサイトがあります。
そして、サイトに寄ってバージョンが異なる場合があるようです。(ここ重要)

上記サイトからダウンロードしたモデルは、2021年7月時点では
「basicModel_neutral_lbs_10_207_0_v1.0.0.pkl」というファイルです。
OSSによって、このpklに対応したバージョンじゃないと動かない場合があるので
もしもSMPLのモデルの読み込みが関わる処理でエラーが出た時は、
モデルのバージョンを変えてみましょう。(なお、SMPL公式ではバージョン管理していない模様…)

ソース

TCMR内にあるSPINは出力された特徴を変換したりしているのでガチャガチャ改造します。
以下のようにdemo.pyをわりとガッツリ書き換えます。
ファイルを分けたりなどやってないど、くっっっっっそながいです。

demo.py

import os
import os.path as osp
from lib.core.config import BASE_DATA_DIR
from lib.models.smpl import SMPL, SMPL_MODEL_DIR
os.environ['PYOPENGL_PLATFORM'] = 'egl'

import cv2
import time
import torch
import joblib
import shutil
import colorsys
import argparse
import random
import numpy as np
from pathlib import Path
from tqdm import tqdm
from multi_person_tracker import MPT
from torch.utils.data import DataLoader

from lib.models.tcmr import TCMR
from lib.utils.renderer import Renderer
from lib.dataset._dataset_demo import CropDataset, FeatureDataset
from lib.utils.demo_utils import (
    download_youtube_clip,
    convert_crop_cam_to_orig_img,
    prepare_rendering_results,
    video_to_images,
    images_to_video,
)
from PIL import Image
MIN_NUM_FRAMES = 25
random.seed(1)
torch.manual_seed(1)
np.random.seed(1)

import os
import torch
from torchvision.utils import make_grid
import numpy as np
import pyrender
import trimesh


import os
import cv2
import numpy as np
import os.path as osp
import torch
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor

from lib.utils.smooth_bbox import get_all_bbox_params
from lib.data_utils._img_utils import get_single_image_crop_demo



import torch
import torch.nn as nn
import torchvision.models.resnet as resnet
import numpy as np
import math

from torch.nn import functional as F

def rot6d_to_rotmat(x):
    """Convert 6D rotation representation to 3x3 rotation matrix.
    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
    Input:
        (B,6) Batch of 6-D rotation representations
    Output:
        (B,3,3) Batch of corresponding rotation matrices
    """
    x = x.view(-1,3,2)
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    b1 = F.normalize(a1)
    b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
    b3 = torch.cross(b1, b2)
    return torch.stack((b1, b2, b3), dim=-1)

class Bottleneck(nn.Module):
    """ Redefinition of Bottleneck residual block
        Adapted from the official PyTorch implementation
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class HMR_SPIN(nn.Module):
    """ SMPL Iterative Regressor with ResNet50 backbone
    """

    def __init__(self, block, layers, smpl_mean_params):
        self.inplanes = 64
        super(HMR_SPIN, self).__init__()
        npose = 24 * 6
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc1 = nn.Linear(512 * block.expansion + npose + 13, 1024)
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout()
        self.decpose = nn.Linear(1024, npose)
        self.decshape = nn.Linear(1024, 10)
        self.deccam = nn.Linear(1024, 3)
        nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
        nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        mean_params = np.load(smpl_mean_params)
        init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
        init_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.register_buffer('init_cam', init_cam)


    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)


    def forward(self, x, init_pose=None, init_shape=None, init_cam=None, n_iter=3):

        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        xf = self.avgpool(x4)
        xf = xf.view(xf.size(0), -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for i in range(n_iter):
            xc = torch.cat([xf, pred_pose, pred_shape, pred_cam],1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)

        return pred_rotmat, pred_shape, pred_cam

def hmr_SPIN(smpl_mean_params, pretrained=True, **kwargs):
    """ Constructs an HMR model with ResNet50 backbone.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = HMR_SPIN(Bottleneck, [3, 4, 6, 3],  smpl_mean_params, **kwargs)
    if pretrained:
        resnet_imagenet = resnet.resnet50(pretrained=True)
        model.load_state_dict(resnet_imagenet.state_dict(),strict=False)
    return model

class CropDataset_SPIN(Dataset):
    def __init__(self, image_folder, frames, bboxes=None, joints2d=None, scale=1.0, crop_size=224):
        self.image_file_names = [
            osp.join(image_folder, x)
            for x in os.listdir(image_folder)
            if x.endswith('.png') or x.endswith('.jpg')
        ]
        self.image_file_names = sorted(self.image_file_names)
        self.image_file_names = np.array(self.image_file_names)[frames]
        self.bboxes = bboxes
        self.joints2d = joints2d
        self.scale = scale
        self.crop_size = crop_size
        self.frames = frames
        self.has_keypoints = True if joints2d is not None else False

        self.norm_joints2d = np.zeros_like(self.joints2d)

        if self.has_keypoints:
            bboxes, time_pt1, time_pt2 = get_all_bbox_params(joints2d, vis_thresh=0.3)
            bboxes[:, 2:] = 150. / bboxes[:, 2:]
            self.bboxes = np.stack([bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 2]]).T

            self.image_file_names = self.image_file_names[time_pt1:time_pt2]
            self.joints2d = joints2d[time_pt1:time_pt2]
            self.frames = frames[time_pt1:time_pt2]

    def __len__(self):
        return len(self.image_file_names)

    def __getitem__(self, idx):
        img = cv2.cvtColor(cv2.imread(self.image_file_names[idx]), cv2.COLOR_BGR2RGB)

        bbox = self.bboxes[idx]

        j2d = self.joints2d[idx] if self.has_keypoints else None

        norm_img, raw_img, kp_2d = get_single_image_crop_demo(
            img,
            bbox,
            kp_2d=j2d,
            scale=self.scale,
            crop_size=self.crop_size)
        if self.has_keypoints:
            return norm_img, kp_2d,raw_img
        else:
            return norm_img,raw_img


class RendererSPIN:
    """
    Renderer used for visualizing the SMPL model
    Code adapted from https://github.com/vchoutas/smplify-x
    """
    def __init__(self, focal_length=5000, img_res=224, faces=None):
        self.renderer = pyrender.OffscreenRenderer(viewport_width=img_res,
                                       viewport_height=img_res,
                                       point_size=1.0)
        self.focal_length = focal_length
        self.camera_center = [img_res // 2, img_res // 2]
        self.faces = faces

    def visualize_tb(self, vertices, camera_translation, images):
        vertices = vertices.cpu().numpy()
        camera_translation = camera_translation.cpu().numpy()
        images = images.cpu()
        images_np = np.transpose(images.numpy(), (0,2,3,1))
        rend_imgs = []
        for i in range(vertices.shape[0]):
            rend_img = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i]), (2,0,1))).float()
            rend_imgs.append(images[i])
            rend_imgs.append(rend_img)
        rend_imgs = make_grid(rend_imgs, nrow=2)
        return rend_imgs

    def __call__(self, vertices, camera_translation, image):
        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.2,
            alphaMode='OPAQUE',
            baseColorFactor=(0.8, 0.3, 0.3, 1.0))

        camera_translation[0] *= -1.

        mesh = trimesh.Trimesh(vertices, self.faces)
        rot = trimesh.transformations.rotation_matrix(
            np.radians(180), [1, 0, 0])
        mesh.apply_transform(rot)
        mesh = pyrender.Mesh.from_trimesh(mesh, material=material)

        scene = pyrender.Scene(ambient_light=(0.5, 0.5, 0.5))
        scene.add(mesh, 'mesh')

        camera_pose = np.eye(4)
        camera_pose[:3, 3] = camera_translation
        camera = pyrender.IntrinsicsCamera(fx=self.focal_length, fy=self.focal_length,
                                           cx=self.camera_center[0], cy=self.camera_center[1])
        scene.add(camera, pose=camera_pose)


        light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=1)
        light_pose = np.eye(4)

        light_pose[:3, 3] = np.array([0, -1, 1])
        scene.add(light, pose=light_pose)

        light_pose[:3, 3] = np.array([0, 1, 1])
        scene.add(light, pose=light_pose)

        light_pose[:3, 3] = np.array([1, 1, 2])
        scene.add(light, pose=light_pose)

        color, rend_depth = self.renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
        color = color.astype(np.float32) / 255.0
        valid_mask = (rend_depth > 0)[:,:,None]
        output_img = (color[:, :, :3] * valid_mask +
                  (1 - valid_mask) * image)

        return output_img

def main(args):
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    """ Prepare input video (images) """
    video_file = args.vid_file
    if video_file.startswith('https://www.youtube.com'):
        print(f"Donwloading YouTube video \'{video_file}\'")
        video_file = download_youtube_clip(video_file, '/tmp')
        if video_file is None:
            exit('Youtube url is not valid!')
        print(f"YouTube Video has been downloaded to {video_file}...")

    if not os.path.isfile(video_file):
        exit(f"Input video \'{video_file}\' does not exist!")

    output_path = osp.join('./output/demo_output', os.path.basename(video_file).replace('.mp4', ''))
    Path(output_path).mkdir(parents=True, exist_ok=True)
    image_folder, num_frames, img_shape = video_to_images(video_file, return_info=True)

    print(f"Input video number of frames {num_frames}\n")
    orig_height, orig_width = img_shape[:2]

    """ Run tracking """
    total_time = time.time()
    bbox_scale = 1.2
    # 動画から物体検出
    # 全フレームに対して写っている人を識別してその人のbboxを取得する。
    # https://github.com/mkocabas/multi-person-tracker.git
    mot = MPT(
        device=device,
        batch_size=args.tracker_batch_size,
        display=args.display,
        detector_type=args.detector,
        output_format='dict',
        yolo_img_size=args.yolo_img_size,
    )
    tracking_results = mot(image_folder)
    # 動画にあまり写っていない人(=フレーム数が少ない人)は削除
    for person_id in list(tracking_results.keys()):
        if tracking_results[person_id]['frames'].shape[0] < MIN_NUM_FRAMES:
            del tracking_results[person_id]

    # SPINモデルの準備 -> 学習済みモデルの読み込み
    # from lib.models.spin import hmr
    SMPL_MEAN_PARAMS = osp.join(BASE_DATA_DIR, 'smpl_mean_params.npz')
    hmr = hmr_SPIN(SMPL_MEAN_PARAMS).to(device)
    checkpoint = torch.load(osp.join(BASE_DATA_DIR, 'spin_model_checkpoint.pth.tar'))
    hmr.load_state_dict(checkpoint['model'], strict=False)
    hmr.eval()

    # SMPLモデルを操作するための準備
    smpl = SMPL(SMPL_MODEL_DIR,
                batch_size=1,
                create_transl=False).to(device)
    FOCAL_LENGTH = 5000
    IMG_RES = 224
    renderer_spin = RendererSPIN(focal_length=FOCAL_LENGTH, img_res=IMG_RES, faces=smpl.faces)

    # 検出した人ごとにループ
    dir_output = "output/person"
    os.makedirs(dir_output, exist_ok=True)
    framecnt = 0
    for person_id in tqdm(list(tracking_results.keys())):
        outimg_files = []
        framecnt += 1

        # 出力先の設定
        os.makedirs(dir_output + "/" + str(person_id), exist_ok=True)
        outputfile_mp4 = dir_output + "/" + str(person_id) + ".mp4"
        outputfile_gif = dir_output + "/" + str(person_id) + ".gif"

        # 実行した物体検出の結果を準備
        # tracking_resultsにはその人の全フレームデータが全て入ってる。
        bboxes = joints2d = None
        bboxes = tracking_results[person_id]['bbox']
        frames = tracking_results[person_id]['frames']
        # bboxに基づくトリミング+ネットワークに合わせた形式への変換
        dataset = CropDataset_SPIN(
            image_folder=image_folder,
            frames=frames,
            bboxes=bboxes,
            joints2d=joints2d,
            scale=bbox_scale,
        )
        bboxes = dataset.bboxes
        frames = dataset.frames
        has_keypoints = True if joints2d is not None else False
        crop_dataloader = DataLoader(dataset, batch_size=256, num_workers=16)

        for i, batch_src in enumerate(crop_dataloader):
          if has_keypoints:
            batch_src, nj2d , img_raw_all_frame= batch_src
            norm_joints2d.append(nj2d.numpy().reshape(-1, 21, 3))
          else:
            batch_src , img_raw_all_frame= batch_src

          # レンダリング
          batch_size = batch_src.size()
          for idx in range(img_raw_all_frame.size()[0]) :
            # batchの次元        [その人が出現するフレーム数,RGB=3,横  ,縦]
            # img_raw_all_frameの次元  [その人が出現するフレーム数,横  ,縦  ,RGB=3]
            batch = batch_src[idx].view(1,batch_size[1],batch_size[2],batch_size[3])
            with torch.no_grad():
              batch = batch.to(device)
              pred_rotmat, pred_betas, pred_camera = hmr(batch)
              pred_output = smpl(betas=pred_betas, body_pose=pred_rotmat[:,1:], global_orient=pred_rotmat[:,0].unsqueeze(1), pose2rot=False)
              pred_vertices = pred_output.vertices

            # 出力値からSMPL向けの値を取得
            camera_translation = torch.stack([pred_camera[:,1], pred_camera[:,2], 2*FOCAL_LENGTH/(IMG_RES * pred_camera[:,0] +1e-9)],dim=-1)
            camera_translation = camera_translation[0].cpu().numpy()
            pred_vertices = pred_vertices[0].cpu().numpy()

            # SMPLのモデルの画像生成
            img_raw_frame = img_raw_all_frame[idx].cpu().numpy()
            img_smpl = renderer_spin(pred_vertices, camera_translation, img_raw_frame)
            img_raw_frame = 255 * img_raw_frame[:,:,::-1]
            img_smpl = 255 * img_smpl[:,:,::-1]

            # 保存(SMPLを通常画像の上に合成するために無駄にごちゃごちゃしたことをしている。)
            outputfile_smpl = dir_output + "/smpl_" + str(idx) + ".png"
            cv2.imwrite(outputfile_smpl, img_smpl)
            img_smpl = cv2.imread(outputfile_smpl, 1)
            img2_gray = cv2.cvtColor(img_smpl, cv2.COLOR_BGR2GRAY)
            img_maskg = cv2.threshold(img2_gray, 220, 255, cv2.THRESH_BINARY_INV)[1]
            img_mask = cv2.merge((img_maskg,img_maskg, img_maskg))
            img_smpl = cv2.bitwise_and(img_smpl, img_mask)
            img_maskn = cv2.bitwise_not(img_mask)
            img_src1m = cv2.bitwise_and(img_raw_frame, img_maskn)
            img_dst = cv2.bitwise_or(img_src1m, img_smpl)
            cv2.imwrite(outputfile_smpl, img_dst)

            outimg_files.append(outputfile_smpl)

          # 動画作成(gif)
          images_PIL = []
          for img_file in outimg_files:
            im = Image.open(img_file)
            images_PIL.append(im)
          images_PIL[0].save(outputfile_gif, save_all=True, append_images=images_PIL[1:], loop=0, duration=30)

          # 動画作成(mp4)
          fourcc = cv2.VideoWriter_fourcc('m','p','4', 'v')
          video  = cv2.VideoWriter(outputfile_mp4, fourcc, 20.0, (IMG_RES, IMG_RES))
          for img_file in outimg_files:
            img = cv2.imread(img_file)
            video.write(img)
          exit()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--vid_file', type=str, default='sample_video.mp4', help='input video path or youtube link')

    parser.add_argument('--model', type=str, default='./data/base_data/tcmr_demo_model.pth.tar', help='path to pretrained model weight')

    parser.add_argument('--detector', type=str, default='yolo', choices=['yolo', 'maskrcnn'],
                        help='object detector to be used for bbox tracking')

    parser.add_argument('--yolo_img_size', type=int, default=416,
                        help='input image size for yolo detector')

    parser.add_argument('--tracker_batch_size', type=int, default=12,
                        help='batch size of object detector used for bbox tracking')

    parser.add_argument('--display', action='store_true',
                        help='visualize the results of each step during demo')

    parser.add_argument('--save_pkl', action='store_true',
                        help='save results to a pkl file')

    parser.add_argument('--save_obj', action='store_true',
                        help='save results as .obj files.')

    parser.add_argument('--gender', type=str, default='neutral',
                        help='set gender of people from (neutral, male, female)')

    parser.add_argument('--wireframe', action='store_true',
                        help='render all meshes as wireframes.')

    parser.add_argument('--sideview', action='store_true',
                        help='render meshes from alternate viewpoint.')

    parser.add_argument('--render_plain', action='store_true',
                        help='render meshes on plain background')

    parser.add_argument('--gpu', type=int, default='1', help='gpu num')

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    main(args)

結果

af8d3643-f626-47dd-9679-ddc29680eb7c.gif
わーい。うまくいきましたー。

まとめ

TCMR、SPINのおかげで、SMPLを前提にした操作や変換がわかるようになりました。
(そのうち記事にしたい。)
SPINでも普通に精度良さげなので、大満足の結果です。

test.py
            # 3D出力
            o3d_vertices = pred_output.vertices.detach().cpu().numpy().squeeze()
            import open3d as o3d
            mesh = o3d.geometry.TriangleMesh()
            mesh.vertices = o3d.utility.Vector3dVector(o3d_vertices)

            smpl = SMPL(SMPL_MODEL_DIR, batch_size=1, create_transl=False)
            mesh.triangles = o3d.utility.Vector3iVector(smpl.faces)
            mesh.compute_vertex_normals()
            mesh.paint_uniform_color([0.3, 0.3, 0.3])
            o3d.visualization.draw_geometries([mesh])
            o3d.io.write_triangle_mesh('smplx_torch_neutral.obj', mesh)

ついでに、3Dモデルも出力するようにしたので、
一番手に入れたかった3Dデータも手に入りましたわーい。
image.png

余談

今回の記事は本当に色々勉強になりました。
TCMRでつまったところを解消できたのがとても大きい。
以下、今回から学んだ教訓。
1.OSSを試そうとした時に上手く動かない時は、
  そのOSSがテンプレートとして使ってるOSSを調べてみる
  ベースが有る時はソースに "http://github~~" と参照ついてる場合が多いので、
  そこから探りましょう。

2.OSSの使い方がわからない・OSSの下準備が面倒なときは、
  githubからそのOSSを探して、下準備もろもろやってくれてるやつを探す。

今後

とまぁ、人を3D化したわけですが、
正直人の人体だけ3Dにできても役に立たないので、いずれは背景込みで3D化したいですね。
それではー

3
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
6