基本トピック
PyTorch3Dの概要
PyTorch3D
はPyTorch
に基づく3D Computer Visionに用いられるライブラリです。三角メッシュの操作や微分可能レンダリングをPyTorch
の行列形式を用いて行うことが可能です。
PyTorch3Dのインストール
PyTorch3D
は下記のようにconda install
を実行することでインストールすることができます。
$ conda install pytorch3d
PyTorch3Dのサンプルコード
targetとsourceの用意
メッシュファイル(.obj)のロードとメッシュの構築(target)
まずは検証に用いるメッシュファイル(dolphin.obj
)の入手を行います。下記のコマンドを実行することでメッシュファイルを入手することができます。
$ wget https://dl.fbaipublicfiles.com/pytorch3d/data/dolphin/dolphin.obj
メッシュファイル(.obj
)のロードは下記のようにpytorch3d.io.load_obj
を用いることで行うことができます。
from pytorch3d.io import load_obj
verts, faces, aux = load_obj('dolphin.obj')
faces_idx = faces.verts_idx
print(type(verts), verts.shape)
print(type(faces), type(faces_idx), faces_idx.shape)
・実行結果
<class 'torch.Tensor'> torch.Size([2562, 3])
<class 'pytorch3d.io.obj_io.Faces'> <class 'torch.Tensor'> torch.Size([5120, 3])
メッシュファイルはノードと3つのノードで構成される三角形によって構成されます。上記ではノードの位置を保持する行列がverts
、三角メッシュ(Triangular Mesh)を構成する3つのノードのインデックスがfaces_idx
に対応しています。
このように得たverts
とfaces_idx
を元に下記を実行することでターゲットのメッシュのtrg_mesh
を構築します(実際には原点を中心とする半径1の球に対応するように正規化・中心化なども行いますが当記事では省略します)。
from pytorch3d.structures import Meshes
...
trg_mesh = Meshes(verts=[verts], faces=[faces_idx])
print(trg_mesh)
・実行結果
<pytorch3d.structures.meshes.Meshes object at 0x7f11799f1f70>
icosphere(ICO球)のメッシュ構築(source)
ICO球(icosphere)はUV球(uvsphere)と同様に球をメッシュ化する際に用いられる手法です。UV球が地球の経度と緯度のようにノードを得るのに対し、ICO球は正二十面体(Regular Icosahedron)を拡張することでノードを得ます。
UV球(左)やICO球(右) Blenderに関連して出てくることが多い
from pytorch3d.utils import ico_sphere
for i in range(6):
src_mesh = ico_sphere(level=i)
print(src_mesh.verts_packed().shape)
・実行結果
torch.Size([12, 3])
torch.Size([42, 3])
torch.Size([162, 3])
torch.Size([642, 3])
torch.Size([2562, 3])
torch.Size([10242, 3])
上記のノード数は$10 \cdot (2^i)^2 + 2$のように計算できることも合わせて抑えておくと良いと思います。
メッシュの変形とメッシュからの点群のサンプリング
メッシュの変形
メッシュを変形するにあたってはメッシュを構成する点(ノード)に対応する補正項(offset)を用いて点の位置に補正項を加算することで変形を行います。
deform_verts = torch.full(src_mesh.verts_packed().shape, 0.0, device=device, requires_grad=True)
new_src_mesh = src_mesh.offset_verts(deform_verts)
たとえば上記ではdeform_verts
が補正項であり、src_mesh.offset_verts(deform_verts)
を実行することでsrc_mesh
の各点にdeform_verts
による変換が反映されます。PyTorch3D
ではこの補正項のdeform_verts
をパラメータと見なし学習を行うことができます。
メッシュからの点群のサンプリング
メッシュからの点群のサンプリングは下記のようにpytorch3d.ops.sample_points_from_meshes
を実行すれば良いです。
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes
src_mesh = ico_sphere(level=0)
sample_src = sample_points_from_meshes(new_src_mesh, 5000)
print(sample_src)
・実行結果
torch.Size([1, 5000, 3])
実行結果より5,000
個の点がサンプリングできていることが確認できます。また、点の可視化を行うと下記のような図が得られます。
最適化
lossの定義
最適化にあたってはDeepLearningの評価を行うにあたってloss
を定義する必要があります。点群の取り扱いにあたってはloss
にChamfer Distance(CD)が用いられることが多いです。
Chamfer Distanceは下記のようにpytorch3d.loss.chamfer_distance
を用いることで計算することができます。
from pytorch3d.utils import ico_sphere
from pytorch3d.io import load_obj
from pytorch3d.structures import Meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import chamfer_distance
sphere_mesh = ico_sphere(level=3)
verts, faces, _ = load_obj("dolphin.obj")
test_mesh = Meshes(verts=[verts], faces=[faces.verts_idx])
sample_sphere = sample_points_from_meshes(sphere_mesh, 5000)
sample_test = sample_points_from_meshes(test_mesh, 5000)
loss_chamfer, _ = chamfer_distance(sample_sphere, sample_test)
print(loss_chamfer.item())
・実行結果
1.0574...
pytorch3d.loss.chamfer_distance
の実行にあたっては筆者の環境では点が100,000
で1min程度かかり、概ね点の数の2乗だけ処理時間がかかったので点の数が多い場合の処理に関しては別途検討する必要があります。
パラメータのアップデート
パラメータの学習についてはtorch.optim.SGD
などのPyTorch
の機能を使うことで実装できます。PyTorch3D
はPyTorch
のネットワーク構築にあたっての途中処理をPyTorch
のテンソルの形式を用いながら代行してくれるのように理解すると良いのではないかと思います。