LoginSignup
1
0

More than 1 year has passed since last update.

Pythonが嫌いなのでC++版のPytorchで画像認識をやってみる 7日目

Last updated at Posted at 2022-08-16

PythonのConv2dの定義をC++に置き換える

 前回より日が経ってしまった。別件でC++を思い出すのを兼ねて監視カメラの録画アプリを残業して作ってたのもある。
 閑話休題、今日から畳み込みのニューラルネットワークを作っていく。実は作るNNには元になるpythonのコードがあるのだが頂きものなので残念ながらまるまる載せるわけにはいかないのでご容赦。つまずきポイントをメモする感じで書き残す。

Conv2d

Pythonだと↓こんな感じで書くやつである。

nn.Conv2d(inp, oup, 1, stride, 0, bias = False),

何をするものなのかと書こうかと思ったけど↓ここに素晴らしい解説があるので割愛。

kerasの定義だとこんな感じ。

keras.layers.Conv2D
keras.layers.Conv2D(
  filters,
  kernel_size,
  strides=(1, 1),
  padding='valid',
  data_format=None,
  dilation_rate=(1, 1),
  activation=None,
  use_bias=True,
  kernel_initializer='glorot_uniform',
  bias_initializer='zeros',
  kernel_regularizer=None,
  bias_regularizer=None,
  activity_regularizer=None,
  kernel_constraint=None,
  bias_constraint=None
)

pytorchの定義だと↓こんな感じ。

CLASS torch.nn.Conv2d
CLASS torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, 
padding=0, dilation=1, groups=1, bias=True, 
padding_mode='zeros', device=None, dtype=None)

どちらにしても引数がいっぱいある。この引数を組み合わせて必要なテンソルを作っていくわけだ。
さてC++のLibtorchだとどうなるかというと…

https://pytorch.org/cppdocs/api/classtorch_1_1nn_1_1_conv2d.html#classtorch_1_1nn_1_1_conv2d
CLASS torch::nn::Conv2d : PUBLIC torch::nn::ModuleHolder<Conv2dImpl>

何だか分からんがな。torch::nn::ModuleHolderというテンプレートクラスにConv2dImplという型がはめ込んである。インスタンス生成にどんなパラメータが入るのか見たいのでConv2dImplの公式解説を見てみる。

Conv2dImpl

Conv2dImplコンストラクタ
Conv2dImpl(int64_t input_channels, int64_t output_channels, ExpandingArray<2> kernel_size);

コンストラクタの引数が3つしかないやんけ。実はコンストラクタはもう一つある↓。

Conv2dImplコンストラクタ
Conv2dImpl(Conv2dOptions options_);

Conv2dOptions

Conv2dOptionsというクラスを利用して引数を渡せということらしい。公式にはこんなサンプルが出ている。

Conv2d model(Conv2dOptions(3, 2, 3).stride(1).bias(false));

メンバ関数を追加して必要な引数を追加していく仕組みらしい。どういうテクニック(テンプレートの回帰処理??)でこんな機構を作っているのか分からないが、これをpythonと比較すると今感じになるようだ。

Python
model=Conv2d (3, 2, 3,stride=1, bias=false)
↓
C++
Conv2d model(Conv2dOptions(3, 2, 3).stride(1).bias(false));

まあ似ていると言えば似ているけど、こんなに似せないといけないものだろうか。引数が多くなるのでこんな仕組みにしたのだろうけどC++らしくないような。引数はパブリックなメンバ変数にしてinit()で初期化でいいんじゃないの? 1行でテンソルを定義できるのは良いけどデバッグが面倒臭そう。昭和のおじさんはトリッキーなコードが嫌いなのだ。が仕方がない、

リンク

ここらで参照しているリンクをまとめてみる。

Libtorch公式サイト

Pytorch公式サイト

公式サンプルコード?

日本語Libtorch解説

PyTorchで学習したモデルをC++から使う

今日はここまで

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0