3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

libtorchでMNIST

Last updated at Posted at 2022-03-22

概要

前回、Visual Studio 2019でのlibtorchの環境構築を行ったのでその続き。
libtorchでMNISTを動かしてみる。本家にあったExampleを動かしてみて動かなかったので同じ思いをしている人のための記事。

Exampleの実行

とりあえず、本家のソースをそのままコピーしてきてコンパイルを通す。コンパイルは無事通るはず。

example-app.cpp
#include <torch/torch.h>

// Define a new Module.
struct Net : torch::nn::Module {
	Net() {
		// Construct and register two Linear submodules.
		fc1 = register_module( "fc1", torch::nn::Linear( 784, 64 ) );
		fc2 = register_module( "fc2", torch::nn::Linear( 64, 32 ) );
		fc3 = register_module( "fc3", torch::nn::Linear( 32, 10 ) );
	}

	// Implement the Net's algorithm.
	torch::Tensor forward( torch::Tensor x ) {
		// Use one of many tensor manipulation functions.
		x = torch::relu( fc1->forward( x.reshape( { x.size( 0 ), 784 } ) ) );
		x = torch::dropout( x, /*p=*/0.5, /*train=*/is_training() );
		x = torch::relu( fc2->forward( x ) );
		x = torch::log_softmax( fc3->forward( x ), /*dim=*/1 );
		return x;
	}

	// Use one of many "standard library" modules.
	torch::nn::Linear fc1{ nullptr }, fc2{ nullptr }, fc3{ nullptr };
};

int main() {
	// Create a new Net.
	auto net = std::make_shared<Net>();

	// Create a multi-threaded data loader for the MNIST dataset.
	auto data_loader = torch::data::make_data_loader(
		torch::data::datasets::MNIST( "./data" ).map(
			torch::data::transforms::Stack<>() ),
		/*batch_size=*/64 );

	// Instantiate an SGD optimization algorithm to update our Net's parameters.
	torch::optim::SGD optimizer( net->parameters(), /*lr=*/0.01 );

	for( size_t epoch = 1; epoch <= 10; ++epoch ) {
		size_t batch_index = 0;
		// Iterate the data loader to yield batches from the dataset.
		for( auto& batch : *data_loader ) {
			// Reset gradients.
			optimizer.zero_grad();
			// Execute the model on the input data.
			torch::Tensor prediction = net->forward( batch.data );
			// Compute a loss value to judge the prediction of our model.
			torch::Tensor loss = torch::nll_loss( prediction, batch.target );
			// Compute gradients of the loss w.r.t. the parameters of our model.
			loss.backward();
			// Update the parameters based on the calculated gradients.
			optimizer.step();
			// Output the loss and checkpoint every 100 batches.
			if( ++batch_index % 100 == 0 ) {
				std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
					<< " | Loss: " << loss.item<float>() << std::endl;
				// Serialize your model periodically as a checkpoint.
				torch::save( net, "net.pt" );
			}
		}
	}
}

何も考えずにそのまま動かすとabortが出る。

問題点の修正1 - MNISTファイルのダウンロード

Exampleのソースコードだけ見ていても原因はわからないので、libtorchのソースコードを調べてみると原因がわかった。
PyTorchだと自動的にMNISTのデータをダウンロードしてくれるみたいだが、libtorchでは自動的にダウンロードする機能はない模様。したがって、実行する前にデータをダウンロードする必要がある。
MNIST DatabaseのWebページで配布されているのでダウンロードする。ダウンロードする必要があるのは以下の4点。

  • train-images-idx3-ubyte.gz : トレーニング用のイメージ
  • train-labels-idx1-ubyte.gz : トレーニング用の答え
  • t10k-images-idx3-ubyte.gz : テスト用のイメージ
  • t10k-labels-idx1-ubyte.gz : テスト用の答え

これらを解凍すると次のようなファイルができる。(7zipの場合)

解凍前のファイル名 解凍後のファイル名
train-images-idx3-ubyte.gz train-images.idx3-ubyte
train-labels-idx1-ubyte.gz train-labels.idx1-ubyte
t10k-images-idx3-ubyte.gz t10k-images.idx3-ubyte
t10k-labels-idx1-ubyte.gz t10k-labels.idx1-ubyte

dataフォルダにいれて実行する。dataフォルダは実行フォルダに作成する。
再度実行すると・・・ 動きません

問題点の修正2 - MNISTファイル名の変更

libtorchのソースコードを見るとMNISTファイル名と解凍したあとのファイル名が異なる。

ソースに書かれているファイル名 解凍後のファイル名
train-images-idx3-ubyte train-images.idx3-ubyte
train-labels-idx1-ubyte train-labels.idx1-ubyte
t10k-images-idx3-ubyte t10k-images.idx3-ubyte
t10k-labels-idx1-ubyte t10k-labels.idx1-ubyte

なぜか、ハイフンのところがドットに置き換わっているので、修正する必要がある。
これで実行するとようやく無事動作する。

テストデータでの実行

このExampleソースコードを見ていると気がつくと思うが、学習はしているもののテストデータでテストしていないという問題点がある。テストデータで実行するサンプルソースを探してみたが見つからなかった。仕方ないので自分で記述する。

example-app.cpp
#include <torch/torch.h>
#include <iostream>

struct Net : torch::nn::Module {
	Net() {
		// Construct and register two Linear submodules.
		fc1 = register_module( "fc1", torch::nn::Linear( 784, 64 ) );
		fc2 = register_module( "fc2", torch::nn::Linear( 64, 32 ) );
		fc3 = register_module( "fc3", torch::nn::Linear( 32, 10 ) );
	}

	// Implement the Net's algorithm.
	torch::Tensor forward( torch::Tensor x ) {
		// Use one of many tensor manipulation functions.
		x = torch::relu( fc1->forward( x.reshape( { x.size( 0 ), 784 } ) ) );
		x = torch::dropout( x, /*p=*/0.5, /*train=*/is_training() );
		x = torch::relu( fc2->forward( x ) );
		x = torch::log_softmax( fc3->forward( x ), /*dim=*/1 );
		return x;
	}

	// Use one of many "standard library" modules.
	torch::nn::Linear fc1{ nullptr }, fc2{ nullptr }, fc3{ nullptr };
};

int main() {
	auto net = std::make_shared<Net>();
	torch::load( net, "../net.pt" );

	auto data_test = torch::data::datasets::MNIST( "../data", torch::data::datasets::MNIST::Mode::kTest )
		.map( torch::data::transforms::Stack<>() );
	auto data_loader_test = torch::data::make_data_loader( data_test );

	float loss = 0;
	int correct = 0;
	int size = 0;
	for( auto& batch : *data_loader_test ) {
		torch::Tensor prediction = net->forward( batch.data );
		loss += torch::nll_loss( prediction, batch.target, {}, torch::Reduction::Sum ).item<float>();
		correct += prediction.argmax( 1 ).eq( batch.target ).sum().item<int>();
		size++;
	}
	std::cout << "Size: " << size << " Average loss: " << loss / (double)size << " Accuracy: " << (double)correct / (double)size << std::endl;
}

実行してみたところ、Accuracy(正答率)は0.8965だった。
net.ptとdataの場所が先程のソースと合っていないが環境に合わせて修正してほしい。
気に食わないのが、データ構造を記したNetをトレーニング側とテスト側の両方に記載しなければならないところ。テスト側に記載しないで良い方法が見つかればまた記事に載せる。

3
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?