1
Help us understand the problem. What are the problem?

More than 1 year has passed since last update.

PyTorchのRustバインディングを使ってみる

はじめに

PyTorchには非公式ながらRustバインディングが存在しているものの情報も少ない状態です。
そこで、本記事では動作環境をDocker上に構築して公式のexamplesを試してみます。

また、比較のために「TensorFlowのRustバインディングを使ってみる」という記事も書いています。Rustでディープラーニングをしてみたいと考えている方は是非そちらもご参照ください。

環境構築手順

プロジェクトの作成

任意のディレクトリにcargoでプロジェクトを作成します。

cargo new tch-example
cd tch-example

依存関係を追加する

Cargo.tomlのdependenciesに下記を追加します。
anyhowはtchの動作には必要ないのですが、サンプルで使用しているものがあるのでここで追加しておきます。

Cargo.toml
[dependencies]
tch = "0.3.1"
anyhow = "1.0.38"

Dockerfileの作成

tch-example下にDockerfileを作成し、下記をコピーします。

Dockerfile
FROM rust:1.50

WORKDIR /home

# ビルド
COPY Cargo.toml Cargo.toml
COPY ./src ./src
RUN cargo build --release

Dockerイメージのビルド

前の手順で作成したDockerfileをビルドします。

docker build -t tch-example .

サンプル1:basics.rs

環境が正しく構築されたかを確認するサンプルです。
examplesに含まれるbasics.rsの内容をコピーしてtch-example/src/main.rsに上書き、下記のコマンドを実行します。

docker run --rm -v "$PWD"/src:/home/src tch-example cargo run --release

以下のような結果が出力されれば正しく動いているかと思います。

   Compiling tch-example v0.1.0 (/home)
    Finished release [optimized] target(s) in 0.82s
     Running `target/release/tch-example`
 3
 1
 4
 1
 5
[ CPUIntType{5} ]
 0.4325  1.6180  0.7033 -1.9058
-0.8730  0.5980  1.1362 -1.0001
-0.9286 -0.1000  0.8674 -0.8939
-0.7715 -0.9395  0.8709 -0.3828
-0.4363  1.2445  0.8405  0.9313
[ CPUFloatType{5,4} ]
 1.9325  3.1180  2.2033 -0.4058
 0.6270  2.0980  2.6362  0.4999
 0.5714  1.4000  2.3674  0.6061
 0.7285  0.5605  2.3709  1.1172
 1.0637  2.7445  2.3405  2.4313
[ CPUFloatType{5,4} ]
 2.9325  4.1180  3.2033  0.5942
 1.6270  3.0980  3.6362  1.4999
 1.5714  2.4000  3.3674  1.6061
 1.7285  1.5605  3.3709  2.1172
 2.0637  3.7445  3.3405  3.4313
[ CPUFloatType{5,4} ]
 43.1000
 44.1000
 45.1000
[ CPUFloatType{3} ]
[3] 44.099998474121094
42
5
Cuda available: false
Cudnn available: false

サンプル2:pretrained-models

事前学習済みモデルを使用するサンプルです。
動作のためには、ソースコードのほかに事前学習済みのモデルをダウンロードする必要があります。

ソースコードの準備

examplesに含まれるpretrained-models/main.rsの内容をコピーしてtch-example/src/main.rsに上書きします。

事前学習済みモデルのダウンロード

事前学習済みモデルはOCamelバインディングのリポジトリからダウンロードします。
tch-example下で下記のコマンドを実行します。

wget https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/resnet18.ot

任意の画像の準備

任意の画像をモデルと同じ場所に、image.jpgという名前で保存します。
この時点でディレクトリ構成は以下のようになっています。

tch-example
├── Cargo.toml
├── Dockerfile
├── image.jpg
├── resnet18.ot
└── src
    └── main.rs

実行方法

下記のコマンドを実行します。

docker run --rm \
  -v "$PWD"/src:/home/src \
  -v "$PWD"/resnet18.ot:/home/resnet18.ot \
  -v "$PWD"/image.jpg:/home/image.jpg \
  tch-example cargo run --release resnet18.ot image.jpg

与えた画像によって推論結果は異なりますが、以下のような結果が出力されれば正しく動いているかと思います。

   Compiling tch-example v0.1.0 (/home)
    Finished release [optimized] target(s) in 1.12s
     Running `target/release/tch-example resnet18.ot image.jpg`
Egyptian cat                                       23.01%
tabby, tabby cat                                   10.06%
lynx, catamount                                     7.58%
tiger cat                                           4.51%
hamper                                              4.23%

サンプル3:MNIST

最後に、チュートリアルの定番となっているMNISTの学習を行うサンプルを試します。

MNISTのデータの準備

MNISTのデータセットをダウンロードします。

mkdir data
cd data
wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
gzip -d *.gz

ソースコードの準備

examples/mnistにある4ファイルをtch-example/src/下に配置します。

この時点でディレクトリ構成は以下のようになっています。
(サンプル2で使用したモデルなども含んでいます。)

tch-example/
├── Cargo.toml
├── Dockerfile
├── data
│   ├── t10k-images-idx3-ubyte
│   ├── t10k-labels-idx1-ubyte
│   ├── train-images-idx3-ubyte
│   └── train-labels-idx1-ubyte
├── image.jpg
├── resnet18.ot
└── src
    ├── main.rs
    ├── mnist_conv.rs
    ├── mnist_linear.rs
    └── mnist_nn.rs

実行方法

下記のコマンドを実行します。
コマンドライン引数にlinearconvを付けると別の構成のモデルで学習を試すことが出来ます。

docker run --rm \
  -v "$PWD"/src:/home/src \
  -v "$PWD"/data:/home/data \
  tch-example cargo run --release

以下のような結果が出力されれば正しく動いているかと思います。

   Compiling tch-example v0.1.0 (/home)
    Finished release [optimized] target(s) in 1.05s
     Running `target/release/tch-example`
epoch:    1 train loss:  2.29681 test acc: 28.58%
epoch:    2 train loss:  2.22892 test acc: 44.74%
epoch:    3 train loss:  2.16134 test acc: 55.58%
epoch:    4 train loss:  2.09158 test acc: 61.93%
epoch:    5 train loss:  2.01911 test acc: 65.85%

~~ 中略 ~~

epoch:  195 train loss:  0.19896 test acc: 94.23%
epoch:  196 train loss:  0.19821 test acc: 94.24%
epoch:  197 train loss:  0.19746 test acc: 94.25%
epoch:  198 train loss:  0.19672 test acc: 94.25%
epoch:  199 train loss:  0.19598 test acc: 94.29%

まとめ

examplesが充実しており、TensorFlowのRustバインディングに比べて簡単に使用できるためRustだけで完結させることも可能かもしれません。

一方で、Pythonで学習したモデルの重みを使用した推論やGPUの使用方法については本記事では調査しておらず、検証する必要があります。
また、公式のRustバインディングではないので将来的にメンテナンスされるかも不安ではあります。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Sign upLogin
1
Help us understand the problem. What are the problem?