module.hは後回し
またまた時間が経ってしまった。毎日更新とはいかないな。今日は大雨で大変だった。家の前の道路が冠水しとる。それはともかくforward関数がどうやって呼び出されるのかソースでは俄かに分かりそうにない。exeが走るようになったらデバッガで調べてやろう。ということで「どうやって呼び出されるか問題」は後回し。
forward関数とは何ぞや
"must override"とmodule.hに書いてあるのでコンストラクタに匹敵する重要な関数なのだろう。ニューラルネットのクラスなので大体想像つくが調べることにする。公式ドキュメントにはこう書いてある。
「すべての呼び出しで実行される計算を定義します。すべてのサブクラスでオーバーライドする必要があります。
ノート:フォワード パスのレシピはこの関数内で定義する必要がありますが、後でこれの代わりに Moduleインスタンスを呼び出す必要があります。前者は登録されたフックの実行を処理し、後者は静かにそれらを無視するためです。」
分かりにくい英語だが、名前から想像つくようにニューラルネットのforward処理を定義する関数ということだろう。昭和のおじさんはこれでも大学でAIを習ったのだ。30年前だが。当時はパーセプトロンと言ってせいぜい入力層、中間層、出力層の3つの層しかなかったけどね。AIの説明はここが分かりやすかったので載せておく。https://www.hpc.co.jp/library/commentary/aboutdeeplearning/
torch::Tensor forward(torch::Tensor input);
引数のinputは入力層に入れるテンソルで間違いなさそうだが、torch::Tensorがポインタ(ポインタ定義をオーバーライドした型)なのかインスタンスなのか気になるなあ。公式ドキュメントにはこう書いてあるので、インスタンス型でいいようだ。名前空間が違うのが気になるがいくら検索してもtorch::Tensorというのは出てこないので何か謎の仕組みでオーバーライドされているのかもしれない。これもexeができたらデバッガで調べよう。
謎の構文
modele.hに戻って続きを読んでいくと、謎の構文が現れる。なんとclass定義の中でメンバ変数が初期化されている。そんなんアリか・・・
private:
Conv2d conv1 = nullptr;
BatchNorm2d bn1 = nullptr;
まあ意味は分かるが、nullptrはヌルポか。とするとConv2dというのはポインタ型? けどmodele.cppのコンストラクタの中では普通に代入処理が行われている。何がどうなっているのか。
conv1 = Conv2d(Conv2dOptions(in_channels, width, {1, 1}).stride(1).bias(false));
C++のドキュメントによるとnullptrはまさしくヌルポらしい。サンプルコードもまさしくポインタである。
int* p = nullptr;
しかしconv1がポインタなら↓こんな文はあり得ない。
Conv2d conv1 = nullptr;
conv1 = Conv2d(Conv2dOptions(in_channels, width, {1, 1}).stride(1).bias(false));
conv1がポインタなら↓こう書かないとダメなはずだ。
Conv2d conv1 = nullptr;
conv1 = new Conv2d(Conv2dOptions(in_channels, width, {1, 1}).stride(1).bias(false));
明日調べよう。疲れたので今日はここまで。
つづく