LoginSignup
1
2

More than 1 year has passed since last update.

Pythonが嫌いなのでC++版のPytorchで画像認識をやってみる 11日目 PytochとLibtorchは違う・・・ torch::catにバグ

Last updated at Posted at 2022-09-10

PytochとLibtorchは違う

とりあえずexeを吐けるようになったのでトライ&エラーを繰り返している。
ある程度予測はしていたが、PytochとLibtorchは違うのがだんだんわかってきた。
ここ何週間か行ったことをメモる。

ModuleはTORCH_MODULE()マクロを使う

公式を見ればわかるのだがモジュールを継承する場合、class ~Implで定義し、TORCH_MODULE()マクロでラップ(?)する。
https://github.com/nuka137/pytorch-cpp-example/tree/6d82b0240af6cb33af015e30b01ee3f0fc3deec2/resnet/cpp
この通りやらなくても動くことは動くのだが、Libtorchのコーディングの癖みたいのが分ってきて、こういうのはきちんとやらないとワケの分からないエラーが出そうなのでやっておく。Libtorchはデータをとにかくラッピングしている。おそらくGPUとCPUをシームレスに使うには必要なテクニックなのだろう。

#誤り
class Mymodule :public torch::nn::Module
{
    torch::nn::Conv2d conv1 = nullptr;
    torch::nn::BatchNorm2d bn1 = nullptr;;
    torch::nn::ReLU relu = nullptr;
    torch::nn::Conv2d conv2 = nullptr;
    torch::nn::BatchNorm2d bn2 = nullptr;
    torch::nn::Sequential shortcut = nullptr;
public:
    Mymodule(int64_t inp, int64_t oup, int64_t stride);
    ~Mymodule();
    torch::Tensor forward(torch::Tensor x);
};

#正しい
class MymoduleImpl :public torch::nn::Module
{
    torch::nn::Conv2d conv1 = nullptr;
    torch::nn::BatchNorm2d bn1 = nullptr;;
    torch::nn::ReLU relu = nullptr;
    torch::nn::Conv2d conv2 = nullptr;
    torch::nn::BatchNorm2d bn2 = nullptr;
    torch::nn::Sequential shortcut = nullptr;
public:
    MymoduleImpl(int64_t inp, int64_t oup, int64_t stride);
    ~MymoduleImpl();
    torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(Mymodule);

torch::cat()がランタイムエラー

結論から言うと対処方法が無かった。Libtorchではtorch::cat()が使えない。未定義の例外が発生する。検索しまくって出てきた有用な情報は↓これだけ。

これによるとWindows版Libtorchでは直しようが無いとのこと。GPUドライバが原因と書いてあるが、当方の環境ではCPUモードでもメモリ例外が発生した。なんでこんなシンプルな機能の致命的エラーが放置されているのか分からないし、事態の深刻さにしては情報が少なすぎる。ソースからライブラリをmakeできる人には大した問題ではないのだろうか。とにかく対処方法が無いので自分で関数を作った。デカいライブラリをデバックするより作った方が早い。

catは行列(tensor)を横や縦に連結する関数。自分が欲しいのは行列を横に連結する機能なので、これに特化した関数を作った。
↓こんな機能である。

a={ 1, 2,
    3, 4 };
b={ 5, 6,
    7, 8 };
c=my_cat(a,b);
cの中身
{ 1, 2, 5, 6,
  3, 4, 7, 8 };

これをc++で↓。2次元のマトリクスにだけ対応しているのであしからず。

template <typename TTT >
vector<TTT > _catbase(vector<TTT> _v0, vector<TTT> _v1, int k0, int k1)
{
	register int i;
	int n = _v0.size();
	int m = _v0.size() + _v1.size();

	register int j0 = 0;
	register int j1 = 0;

	vector<TTT> _v2;
	_v2.resize(m);

	for (i = 0; i < m; )
	{
		int j0max = j0 + k0;
		//for (; j0 < j0max; )
		while( j0 < j0max)
		{
			_v2[i] = _v0[j0];
			i++;
			j0++;
		}

		int j1max = j1 + k1;
		//for (; j1 < j1max; )
		while ( j1 < j1max )
		{
			_v2[i] = _v1[j1];
			i++;
			j1++;
		}
	}
	return _v2;
}

template <typename TTT >
torch::Tensor _cat(torch::Tensor _t1, torch::Tensor _t2, int k1, int k2)
{
	torch::Tensor _ret;

	//一次元ベクトルにする
	_t1 = _t1.reshape({ -1 });
	_t2 = _t2.reshape({ -1 });

	//vectorクラスに格納
    vector<TTT> _v1(_t1.data_ptr<TTT>(), _t1.data_ptr<TTT>() + _t1.numel());
	vector<TTT> _v2(_t2.data_ptr<TTT>(), _t2.data_ptr<TTT>() + _t2.numel());
	vector<TTT> _v3;

    //結合
	_v3 = _catbase(_v1, _v2, k1, k2);

	//2次元の行列に戻す
    _ret = torch::tensor(torch::ArrayRef<TTT>(_v3)).reshape({ -1, k1 + k2 });
	return _ret;
}

//使い方 型指定忘れぬよう
torch::Tensor t1 = torch::rand({ 5000000, 2 }, torch::TensorOptions().dtype(torch::kFloat32));
torch::Tensor t2 = torch::rand({ 5000000, 5 }, torch::TensorOptions().dtype(torch::kFloat32));
torch::Tensor t3 = _cat1<float>( t1, t2 , 2, 5);

_t1と_t2を横につなげる。k1,k2はそれぞれ_t1、_t2の横幅(列の数)。関数の中でTensorのメンバ関数から引っ張り出せばk1とk2は省略できるのだけど、そんなに動的に変化するパラメータではないので引数にした。classのメンバ関数は便利だが、たかが引数を省略するくらいで、メカニズムの良く分からないメソッドを毎回呼び出すこともあるまいと思う昭和のおじさんである。ポインタで処理した方が早そうだけど、型が変わったときにいろいろ面倒そうなのでvectorクラスを使っている。

C++でtorch::catがただ使えないというのは何だか敗北感があったので、高速化で付加価値付けるべくカウンターにregisterを指定してみた。経験上同時に使えるregister変数は2つくらいまでだが、とりあえず3個のカウンタ変数をregister指定。registerは現行のC++では非推奨とのことだったが、intelのCPUでは効果があり、指定しない場合よりも指定した方が倍ほど速くなった。ryzenでは効果が無し。調子に名乗ってwhile文の中をインラインアセンブラにしてみたりしたがx64のコンパイラはインラインアセンブラに非対応ということで残念ながら不採用。

それにしてもこんなベーシックな機能にエラーが出るとすると、今後詰む可能性があるなあ。

Libtorchにはload_state_dictが無い _| ̄|○

無いのである。Torchにはニューラルネットワークのデータを保存する方法が複数あるが、Libtorchではload_state_dictとsave_state_dictがサポートされていない。手元にある学習済みデータは全てstate_dict形式なのでこれで詰んだか思ったが、対処方法があった。参考サイトは↓

pythonで作ったニューラルネットワークモデルで読み込みんで、旧形式で保存すればいいとのこと。要するにpythonで変換するわけだ。

変換は結構大変だった

参照先のコードはこうである。余計な事が書いておらず、かつ分かりやすいコード。こうありたいものである。

model = modelName(input_channel, output_channel).cuda()
model.load_state_dict(torch.load('./model.pth')) #  pytorchモデルを読む
model.eval()
example = torch.rand(1, 3, 256, 256)  #入力用に定義
device = torch.device("cuda")
mode = "cuda"
model = torch.jit.trace(model, example.to(device))
model.save("Net_h{}_w{}_{}.pt".format(input_height, input_width, mode))

まあ書いてある通りなのだが一度データを入れてネットワークを最適化する必要がある。最適化するコードがこの部分↓。
model = torch.jit.trace(model, example.to(device))

torch.jit.traceは最適化したいモデルのインスタンスとサンプル入力を引数に取る。引数が複数ある場合はタプルで指定と書いてあるがこれがうまくいかなかった。タプルでコードを書いても1つ目の要素しか認識ない。海外の人も困っているようで↓、名前付きタプルを使うとかいくつか解決策が出ているが、当方の環境ではうまくいかなかった。Python最低である。

結局引数を渡すのはあきらめて、変換したいネットワークデータのモデルのforward関数の中にサンプルデータを埋め込む方法で解決。こんな感じ↓。

    def forward(self,x):
        if(0): #オリジナルコード
            nrm_input1, nrm_input2, nrm_input3, nrm_input4 = x
        if(1): #コンパイル用  torch.jit.traceのタプル渡しが上手くいかないとき
            device = torch.device("cuda")
            mode = "cuda"
            nrm_input1 = torch.rand(4, 256, 32, 40).to(device)
            nrm_input2 = torch.rand(4, 256, 16, 20).to(device)
            nrm_input3 = torch.rand(4, 256,  8, 10).to(device)
            nrm_input4 = torch.rand(4, 256,  4,  5).to(device)
        if(0): #コンパイル用  torch.jit.traceのタプル渡しが上手くいかないとき
            device_cuda = torch.device("cpu")
            mode = "cpu"
            nrm_input1 = torch.rand(4, 256, 32, 40).to(device)
            nrm_input2 = torch.rand(4, 256, 16, 20).to(device)
            nrm_input3 = torch.rand(4, 256,  8, 10).to(device)
            nrm_input4 = torch.rand(4, 256,  4,  5).to(device)

cudaとcpuで別々に変換を行う。モジュールのデバイスモードとサンプルデータのデバイスモードが合致していないとエラーが出るのでいろいろ面倒臭い。if(0)は一時的にコードを無効にするときに、コメントアウトだと間違って削除してしまうことがあるので、小職が良く使う方法である。C++ならコンパイラで飛ばされるコードになるが、Pythonだとそのままメモリに展開されるので注意。

まあ、なんか力技だがpthファイルをptファイルに変換できた。のかなあ・・・

Moduleの派生クラスなどのデータセーブ、ロードでC2679エラー

下のような場合にエラーがコンパイルエラーC2679が出る。

struct example_module:torch::nn::Module 
{
	example_module() 
    {
		first_l = register_module("conv1", torch::nn::Conv2d(3, 3, 1));
	}
	torch::nn::Conv2d first_l = nullptr;
};

int main()
{
    example_module model = example_module();
    torch::save(model, "hoge.pt"); //<-C2679
    return 0;
}

こんなメッセージerror C2679: binary '<<' : no operator found which takes a right-hand operand of type・・・でデータロードや書き込みとどんな関係あるのか直感的に分かりにくい。対処方法はこれだけ↓

//誤
struct example_module:torch::nn::Module 
//正
struct example_module:public torch::nn::Module 

publicの指定が抜けていたので、演算子関数が必要なメンバ関数にアクセスできなかったわけだ。

つづく

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2