ニューラルネットワークを移植する
Libtorchのコンパイルとリンキングが上手くいったので、Pytorchのニューラルネットワークを移植してみる。
↓こんなやつ。
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 6, 5)
self.pool = torch.nn.MaxPool2d(2, 2)
self.conv2 = torch.nn.Conv2d(6, 16, 5)
self.fc1 = torch.nn.Linear(16 * 16, 64)
self.fc2 = torch.nn.Linear(64, 10)
def forward(self, x):
x = self.conv1(x)
x = torch.nn.functional.relu(x)
x = self.pool(x)
x = self.conv2(x)
x = self.pool(x)
x = x.reshape(-1, 16 * 16)
x = self.fc1(x)
x = torch.nn.functional.relu(x)
x = self.fc2(x)
return x
なんとなく分かるけど、コピペでコンパイルしてポンというわけには行かなさそう。defが関数宣言なのは昭和のおじさんでも知っている。それにしてもメンバ変数の宣言が無いのはどういうことなのか?
Pythonを読む
とにかく1行ずつ読んでみる。
class Net(torch.nn.Module):
↑クラス宣言のようだ。torch.nn.Moduleは親クラスらしい。
def __init__(self):
↑コンストラクタの宣言か。selfってなに?
super().__init__()
↑superってなんやねん? Google先生によると親クラスのコンストラクタを呼んでいるらしい。引数が無いことろを見るとデフォルトコンストラクタか。
self.conv1 = torch.nn.Conv2d(1, 6, 5)
↑変数宣言が無いのに、いきなり代入式が。なんやねん意味わからん腹立つ。指示器無しで前に割り込みされたような不快さだ。コンストラクタの中で変数宣言しておくと、自動的にクラスのスコープで使える?と勝手に解釈。
torch.nn.Conv2d()はクラスのコンストラクタらしい。ということはconv1の型指定はtorch::nn::Conv2dクラスを指定すればいいと言う事かな。
def forward(self, x):
↑メンバ関数らしい。だからselfってなんやねん。いちいち引数に入れるんか。selfを自クラスを現すポインタみたいなもの? C++のthisみたいなものと勝手に解釈。xって何?
x = self.conv1(x)
x = torch.nn.functional.relu(x)
x = self.pool(x)
x = self.conv2(x)
x = self.pool(x)
x = x.reshape(-1, 16 * 16)
x = self.fc1(x)
x = torch.nn.functional.relu(x)
x = self.fc2(x)
↑xの型が不明だが、テンソルか何かを含む構造体なのだろう。各行で行列演算してxに格納するらしい。
return x
↑まあ、これは分かる。
C++に置き換える
まあC++には変換できそうな感じ。クラスとコンストラクタの部分をC++に置き換える。とりあえず全部public。↓
//C++
class Net : torch::nn::Module
{
public:
torch::nn::Conv2d conv1;
torch::nn::MaxPool2d pool;
torch::nn::Conv2d conv2;
torch::nn::Linear fc1;
torch::nn::Linear fc2;
Net();
};
Net::Net()
{
torch::nn::Module();
conv1 = torch::nn::Conv2d(1, 6, 5);
pool = torch::nn::MaxPool2d(2, 2);
conv2 = torch::nn::Conv2d(6, 16, 5);
fc1 = torch::nn::Linear(16 * 16, 64);
fc2 = torch::nn::Linear(64, 10);
}
文法エラーは出てこない。コンパイルしてみる。
エラー C2338がでる。IDEのリンク先を見ても良く分からない。よくあるマイクロソフトの斜め上な解説。チンプンカンプン。厳密にはエラーではなくワーニングのようだが致命的とある。
気を取り直してIDEの英語のエラーメッセージをよく読む。
↑no default constructorとある。ははーん、おそらく親クラスかメンバ変数の中にテンプレート宣言か何かがあってインスタンス生成時に明示的にコンストラクタを呼ぶ必要があるのだろう。要するに動的メモリ確保すればいいわけだ。
今日はここまで
つづく