はじめに
この記事で得られる知識
PyTorchモデルをC++を使って高速化させる方法を紹介します.さらに発展的な処理をC++で行うために知っておきたい知識(C++内でのテンソルの生成や型変換)も紹介します.
PyTorchは深層学習だけでなく,その自動微分機能を用いてある種の最適化問題を解く時にも使うことができます.しかし,複雑なカスタムモデルを作ったとき,そのままでは非常に遅くなってしまうことがあります.例えば,モデル内部の実装にfor文が二重三重にネストしてしまっているときなどです.マルチプロセスを使う・そもそもの実装を工夫するなどの方法もありますが,ここでは実行言語を一部C++で置き換えることでモデルの高速化する方法を紹介します.
PyTorchでC++を使う,とは?
PyTorchは基本的にPython用のライブラリですが,C++でも使えるPyTorch C++ API(LibTorchと呼ばれてます)も配布されています.これにより,C++のみで訓練・推論スクリプトを書けるほか,Pythonコードの一部を高速なC++で置き換えたりJIT実行したりすることができます.
PyTorch C++ APIの公式ドキュメントではざっくり以下の5つに分類できるとされています.
- ATen: 以下のすべての基礎になっているPyTorch C++ APIライブラリ
- Autograd: 自動微分機能のあるATenの強化版
- C++ Frontend: 簡単にC++で訓練・推論するためのもの
- TorchScript: JIT方式compilerとインタプリタを使うためのインターフェース
- C++ Extensions: Pythonスクリプトの一部をC++/CUDAで置き換えるための方法 ← 今回のターゲット
C++ エクステンションが素晴らしいのは,簡単なPythonでモデルの大部分を書きつつ,処理の遅い一部分のみを高速なC++で書く,といった両言語のいいとこ取りができる点です.
準備:PyTorchでの自作関数の作り方
置き換えたい重い処理を一つの自作関数と見なすことが第一歩となります.その関数をC++で置き換えることでモデルを高速化していきます.
しかし,モデルをC++で書き替える前に,そもそも自作関数を作るためにはどうすればよいのかを知る必要があります.EXTENDING PYTORCHにて,自作関数の書き方が説明されています.非常に大雑把に書いてしまえば,以下のようにtorch.autograd.Function
を継承したクラスを作ることになります.
class custom_func(torch.autograd.Function):
@staticmethod
def forward(...):
# (順伝播処理)
return ...
@staticmethod
def backward(...):
# (逆伝播処理)
return ...
forward関数のみならずbackward関数も自分で書く必要がありますので,作りたい関数の微分時の勾配計算を手計算で求められる必要があります.backward関数での勾配計算の仕方は
この(順伝播処理),(逆伝播処理)部分はPythonでそのまま記述することも可能ですが,ここでC++関数を呼び出すことで処理を高速化するのがC++エクステンションになります.
本題:C++ エクステンションの利用方法
それでは実際にC++エクステンションの使い方を見ていきましょう.C++エクステンションの公式ドキュメントでLSTMならぬLLTMというモデルがC++エクステンションを用いて実装されていますので,まずはそちらを確認するのがよいでしょう.本記事では最小限度の構成が分かるように,極めて簡潔な「3つの値を足す」モデルを例として扱って解説していきます.環境はUbuntu 20.04 LTSです.
C++で関数を書く
3つのテンソルを受け取ってその和を返す処理をC++で書くと,以下のようになります.
#include <torch/extension.h>
#include <iostream>
#include <vector>
// FORWARD
std::vector<at::Tensor> add3numbers_forward(
torch::Tensor x,
torch::Tensor y,
torch::Tensor z){
auto sum = x+y+z;
return {sum}; // リストに入れて返してる
}
// BACKWARD
std::vector<torch::Tensor> add3numbers_backward(
torch::Tensor grad_sum /* = dL/dsum */ ){
/*
return dL/dx, dL/dy, dL/dz をすれば良い
dsum/dx = d(x + y + z) / dx = 1.0
よって,dL/dx = dL/dsum * dsum/dx = dL/dsum
つまり受け取ったgrad_sumをそのまま返せば良い
y, zに関しても同様
*/
return {grad_sum, grad_sum, grad_sum};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &add3numbers_forward, "add3numbers forward");
m.def("backward", &add3numbers_backward, "add3numbers backward");
}
このコードはforward処理定義部分,backward定義部分,そして末尾のPython言語へのバインディング処理部分から構成されます.
forward関数の書き方について,テンソルの型はC++ではtorch::Tensor
です.後は実現したい処理をC++の記法に則って書くのみです.
backward関数は普段書く機会がないので少し難しいですが,backward関数の引数はforward関数の返り値の勾配,backward関数の返り値はforward関数の引数の勾配だと思うと良いでしょう(厳密には,なにが引数になるのか,何を返り値として返せば良いのか等は後述のPythonからの呼び出し方に依ります).
最後のPythonへのbinding記述部分はおまじないのようなものなのでadd3numbers
という部分以外はいじらないようにしましょう.
C++のコードをビルドする
以下のスクリプトをadd3numbers.cpp
と同じディレクトリにおいてください.
from setuptools import setup, Extension
from torch.utils import cpp_extension
setup(name='add3numbers_cpp',
ext_modules=[cpp_extension.CppExtension('add3numbers_cpp', ['add3numbers.cpp'])],
cmdclass={'build_ext': cpp_extension.BuildExtension})
こちらもadd3numbers
という文字部分以外はおまじないとして考えるくらいでいいでしょう.
ディレクトリ構造は以下のようになっているはずです.
add3numbers/
add3numbers.cpp
setup.py
add3numbers
ディレクトリで以下のコマンドを実行します
python setup.py install
長いビルドメッセージが流れた後,Finished processing dependencies for ...
と表示されたら完了です.これでPythonからC++で記述したadd3numbers_cpp
が呼び出せるようになりました.
In [1]: import add3numbers_cpp
In [2]: add3numbers_cpp.forward
Out[2]: <function add3numbers_cpp.PyCapsule.forward>
Pythonの自作関数からC++関数を呼び出す
前述の自作関数をmain.py
に記述し,ビルドしたadd3numbers_cpp
を呼び出します.
import torch
import torch.nn as nn
import add3numbers_cpp
class Add3NumbersFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y, z):
output = add3numbers_cpp.forward(x, y, z) # ここでC++のforward関数を呼び出す
print('output',output)
ctx.save_for_backward(*output)
return output[0]
@staticmethod
def backward(ctx, grad_sum):
print('grad_sum',grad_sum)
d_output = add3numbers_cpp.backward(grad_sum) # ここでC++のbackward関数を呼び出す
dx, dy, dz = d_output
return dx, dy, dz
class Add3Numbers(nn.Module):
def __init__(self):
super(Add3Numbers, self).__init__()
pass
def forward(self, x, y, z):
return Add3NumbersFunction.apply(x, y, z)
class Add3Numbers(nn.Module):
def __init__(self):
super(Add3Numbers, self).__init__()
pass
def forward(self, x, y, z):
return Add3NumbersFunction.apply(x, y, z)
if __name__=='__main__':
model = Add3Numbers()
x = torch.Tensor([1]).requires_grad_()
y = torch.Tensor([2])
z = torch.Tensor([3])
print("x.grad: ", x.grad) # None
# FORWARDが正しくできるか確認
out = model(x, y, z)
print('model(x, y, z): ', out)
# BACKWARDが正しくできるか確認
loss = torch.sum(out)
loss.backward()
print(x.grad) # None --> tensor([1.])に変わっている!!
ビルドしたC++モジュールはAdd3NumbersFunction
クラスのforward関数,backward関数内で呼び出されています.Add3NumbersFunction
クラスの詳しい書き方はここでは割愛しますので,再掲になりますが,
- EXTENDING PYTORCH
- 【PyTorch】自作関数の勾配計算式(backward関数)の書き方①,
-
【PyTorch】自作関数の勾配計算式(backward関数)の書き方② 〜多変数出力の場合〜
を参考にしてください.
以上で最小構成のC++エクステンションモジュールの実装は終了です.
発展:より複雑なモジュールを作りたい時
今まででC++エクステンションの最低限の使い方を見てきましたが,実際にC++エクステンションの利用を考えている人はif文やfor文,さらにはPyTorchで提供されているような便利関数などを多用された複雑なカスタムモデルを作りたい人が多いのではないでしょうか?ここでは普段のPython likeな実装をC++ APIでするために知っておきたいことを紹介していきます.
テンソルの生成
C++でのテンソル生成の公式ドキュメントはこちら(Tensor Creation API)になります.
ここではよく使う関数のPythonでの記法とPyTorchでの記法の対応を以下の表にまとめました.
Python | C++ |
---|---|
torch.Tensor([0.]) |
torch::tensor({0.}) |
torch.zeros([2,3]) |
torch::zeros({2,3}) |
torch.zeros_like(x) |
torch::zeros_like(x) |
torch.ones([2,3]) |
torch::ones({2,3}) |
torch.ones_like(x) |
torch::ones_like({x}) |
torch.randn([3,4]) |
torch::randn({3,4}) |
torch.arange(1,9) |
torch::arange(1,9) |
基本的にPythonとほぼ同じ書き方ができます.地味に注意すべき点として,配列からテンソルを生成するtorch::tensor
(全部小文字)とテンソルの型を表すtorch::Tensor
(TensorのTが大文字)とを混同しないように注意しましょう.
また,一般にテンソルデータがCPUもしくはGPUにあるのなら,それをからテンソルを作ることもできます.
float data[] = { 1, 2, 3,
4, 5, 6 };
torch::Tensor f = torch::from_blob(data, {2, 3});
形のreshapeは以下です.
x = x.reshape({n, m})
テンソルの形を取りたいときは.sizes()
を使います.
c10::ArrayRef<long int> shape;
shape = x.sizes();
std::cout << shape << std::endl; // [n, m]
テンソルサイズのi番目の配列要素にアクセスしたいときはx.sizes()[i]
かx.size(i)
とします.
その他,提供されているAPIライブラリ
C++でもtorch::sigmoid()
やtorch::elu()
, torch::tanh()
のような数学関数が用意されています.
C++で使えるPyTorchの便利関数は
-
Namespace torch - Functions ←
torch::
名前空間で利用できる関数一覧 -
Library API - Functions ←
at::
名前空間で利用できる関数一覧
で一覧を確認できます.
ただし,これらの公式ページはtorch::sigmoid()
がまだ載っていないなど,完全にはまだ整備されていない(?)ようです.なので,使いたい関数が本当にあるのか,手元で簡単にいじれるc++のスクリプトを書いて実行・デバッグをして確かめると良いでしょう.INSTALLING C++ DISTRIBUTION OF PYTORCHが参考になります.ややこしく見えますが,手順通り進めればmake
コマンドを実行するだけでPyTorch C++スクリプトをビルド・実行できます.
ちなみにtorch::
名前空間もat::
名前空間も似た関数を提供していますが,こちらで述べられているように基本的には自動微分機能があるtorch::
を使うのが推奨されているようです.
Indexing
公式の説明ページはTensor Indexing APIです.
テンソルの一部の要素にアクセスするためのindexingはC++でもPythonと同様に可能で,None
/...
/整数値/ブーリアン/スライスを使って任意の要素にアクセスできます.
Pythonではindexingを[]
で行うのに対し,C++では
-
torch::Tensor::index()
(値の取り出し時) -
torch:Tensor::index_put_()
(値の代入時)
でIndexingを行うことができます(簡単な操作なら[]
も使える).
以下,簡単な具体例を表にします.ただし,C++の方ではusing namespace torch::indexing
が必要です.
Python | C++ |
---|---|
tensor[1,2] |
tensor.index({1, 2}) or tensor[1][2]
|
tensor[True, False] |
tensor.index({true, false}) |
tensor[1::2] |
tensor.index({Slice(1, None, 2)}) , Sliceの引数はstart, stop, step |
tensor[1,2] = 1 |
tensor.index_put_({1, 2}, 1) or tensor[1][2] = 1
|
tensor[True, False]=1 |
tensor.index_put_({true, false}, 1) |
tensor[1::2] = 1 |
tensor.index_put_({Slice(1, None, 2)}, 1) |
変数の型変換
テンソルの型変換方法はこちらで紹介されています.テンソルにはfloat型,int型などあり,それらに変換するには.to()
メソッドを用います.
#include <torch/torch.h>
#include <iostream>
int main() {
torch::Tensor tensor = torch::rand({2, 3});
auto x = tensor.to(torch::kInt); // ここで整数型テンソルに変換
std::cout << tensor << std::endl; // [ CPUFloatType{2,3} ]とprintされる
std::cout << x << std::endl; // [ CPUIntType{2,3} ]とprintされる
}
また,不等号演算子(>
, ==
, <
)を用いてテンソル同士の比較を取れば,ブーリアン型のテンソルである[ CPUBoolType{} ]
のテンソルが返されます.さらに,比較を取る2つのテンソルの大きさが異なるとき,ブロードキャスト可能であればブロードキャストされます.
torch::Tensor a = torch::tensor(0); // shape : {}
torch::Tensor b = torch::tensor({-1, 0, 1}); // shape : {3}
std::cout << (a > b) << std::endl; // 1 0 0 [ CPUBoolType{3} ]とprintされる
ただし注意が必要なのは,ここで生成されるブーリアン型テンソルの変数はif文などの条件式で使うブーリアン変数としては使えないということです.
torch::Tensor a = torch::tensor(0); // shape : {}
torch::Tensor b = torch::tensor(1); // shape : {}
auto a_gt_b = (a > b) // a is greater than b の略
if (a_gt_b){ // a_gt_bというテンソルをbool型に変換できないとエラーがでる
...
}
出力の結果によって処理を変えたいことは多々あるので,比較の結果からブーリアンを取れないのは不便です.少し力技になりますが,こちらで紹介されている通りBoolTensor型(値は0/1のみ)をfloat型に変換することでif文の条件式として使用可能にすることができます.C++では値が1なら真として扱われ,0なら偽となります.
auto a_lt_b = (a < b); // a is less than b の略
if (a_lt_b.item().toBFloat16()){ // float型の1.に変換される
std::cout << (a_lt_b.item().toBFloat16()) << std::endl; // 1 がprintされる
}
自作モジュールの引数にテンソル以外の型の変数を受け取りたい
これは正確にはPythonのautogradに関するものですが.一応紹介します.複雑な自作関数を作るとき,引数は常にテンソルとは限りません.
例えば,以下の例のように,forward処理時にInt型の変数image_size
であったり,str型の変数activation_function
を受け取るかもしれません.これらはテンソルでないため,当然勾配を持ちませんが,backward関数の返り値はforward関数の引数の勾配を返さなくてはならないのでこれは問題です.
class custom_func(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, image_size=28, activation_function='tanh'):
# (順伝播処理)
return out
@staticmethod
def backward(ctx, grad_out): # 引数はctxとforward関数の出力の勾配
# (逆伝播処理)
return ... # 返り値はforwad関数の引数の勾配...?!
この場合,backward関数の返り値にNone
をあてることでエラーを回避できます.
@staticmethod
def backward(ctx, grad_out): # 引数はctxとforward関数の出力の勾配
# (逆伝播処理)
return grad_tensor, None, None # 返り値はforwad関数の引数の勾配,なければNone
終わりに
PyTorch C++エクステンションをある程度自由に使いこなせるようになるために知っておくべきことをまとめました.
注意点として,複雑なモデルになればなるほど,backward関数を書くのは難しくなってきますので,実際にC++エクステンションを使う前に,backward関数が手計算で求められるか確認するとよいでしょう.
また,C++エクステンションを使うからといって必ずしも速度が大幅に上昇するとは限りません.有名な話ですが,結局PyTorchの行列演算は内部ではC++でビルドされたモジュールが動いているので,単純な処理をする限りはPythonで書いても速度はほとんど変わりません.ただ,もしもif文やネストしたfor文などPythonが苦手とする処理が多くあるのなら,C++エクステンションを用いることで大幅な速度向上が見込まれます.