67
59

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchの関数本体ってどこで定義されてるの?逆伝搬は?~ATenの世界へのみちしるべ~

Last updated at Posted at 2020-07-12

はじめに

本稿はPyTorch中級者以上向けのニッチな内容ですが,どうしても計算方法の定義が見たいという方にはおすすめです.内容は主に3つです.

  • PythonからC++で書かれたATenライブラリへの橋渡し
  • (conv2dを例に挙げ,)ATen内で順伝搬関数が扱われているかを見る
  • 逆伝搬関数がいかに順伝搬関数と紐づいているかを見る

隠された実装

PyTorchの公式ドキュメントを見ながらプログラミングをし,時には[Source]ボタンを押してgithubの実装を確認したりしますよね.でも[Source]ボタンがないときありますよね?例えばtorch.nn.Conv2dには[Source]ボタンがありますが,torch.nn.functional.conv2dにはありません.  

・・・なんで同じConv2dなのに実装が見れないの?  

と思いつつ,PyTorchのgithubを確認してみます.torch.nn.Conv2dはtorch.nn.functional.conv2d を利用しているので,後者の定義があるはずの torch/nn/functional.py を覗いてみると以下の定義がありました.

conv2d = _add_docstr(torch.conv2d, r"""
# 略
"""
)

はぁ・・・なんとなく[Source]ボタンがない理由がわかりました.見ても意味がないからです.ちなみにやっていることは conv2d = torch.conv2d とほぼ一緒です.
じゃあtorch.conv2dの定義を見ればいいのかと思い探します.

で,どこにtorch.conv2dの定義あるの・・・?

Pythonにおける実装を探そう

実はPyTorchの本質的機能の多くはPython内で完結していません.そう簡単にtorch.conv2dの定義を探すことは容易ではありません.と言ってもPythonのSyntaxを破ることはできないのでtorch.conv2dも一応Pythonの中で定義されています. torch/__init__.py ファイル内に以下で定義されています.

for name in dir(_C._VariableFunctions): 
    if name.startswith('__'): 
        continue 
    globals()[name] = getattr(_C._VariableFunctions, name) 
    __all__.append(name) 

コードを見せられても意味が良くわかりませんね.conv2dという文字列も見当たらないし.少しだけ説明すると,_Cとはtorch/_C.soファイルのことであり,C++などの低級言語からコンパイルされたモジュールです.(コンパイル前のgithub上では_C.soは確認できないので注意してください.)その_Cの_VariableFunctionsの中にconv2dが含まれており,for文内部でそれ以外の関数も含めてtorch.~として使用できるようにしています.

要約すると,ここを見ていても仕方ないです.なぜならconv2dの本体はPythonではなくC++で実装されているからです.コンパイルされる前のソースがどこにあるのかを見つけるのが建設的というものでしょう.

ATenの世界

ATenとはA Tensor libraryのことであり,githubだとここにあります.ATenは基本的にC++とCUDAによって実装されており,PyTorchの心臓部とも言える重要な実装が行われています.torch.conv2dのような関数はPyTorch内部でnative functionと呼ばれ,ATen/native内で実装されています.ATen/native内を探検するみちしるべとしてnative_functions.yamlを見ることが有効です.このyamlファイルに記載されている関数は_C._VariableFunctionsに格納されますし,ATen/native内のいずれかのファイルに実装されているからです.実際に検索してみるとconv2dがこのyamlファイルの中に記載されています.

- func: conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor

ということはATen/nativeのどこかに実装されてるはずです.それらしい名前のファイルをあたってみると,ATen/native/Convolution.cpp内にconv2dの定義を見つけることができます.

at::Tensor conv2d(
    const Tensor& input, const Tensor& weight, const Tensor& bias, 
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
  return at::convolution(input, weight, bias, stride, padding, dilation,                    
                         false, {{0, 0}}, groups); 
}

ここまででPython〜ATen(コンパイル前のソース)への橋渡しができました.

ATenの中で何が行われているのか?

じゃあconv2dで呼び出しているat::convolutionはどうなっているの?という疑問は当然湧いてきます.もう少し深入りしてみましょう.コードを辿っていって,GPUの並列計算を起動する命令(CUDA Launch)かcuDNNの関数が見れたら嬉しいですね.
at::convolutionは同じファイル内で定義されており,at::_convolutionを呼び出しています.at::_convolutionでは詳細な畳み込み手法の選択が行われているようです.もしGPUを使用していたらcuDNNを使用していると思うので,at::cudnn_convolutionをさらに見て行くこととしましょう.少し苦労しますが,ATen/native/cudnn/Conv.cppに見つけることが出来ます.芋づる式にcudnn_convolution→cudnn_convolution_forward→raw_cudnn_convolution_forward_out→raw_cudnn_convolution_forward_out_32bitと辿ることができます.最終的には以下のcuDNN関数までコードを辿ることができます.


   AlgoIterator<cudnnConvolutionFwdAlgoPerf_t>(args, benchmark).try_all(
     [&](const cudnnConvolutionFwdAlgoPerf_t &fwdAlgPerf){ 
       Tensor workspace = allocate_workspace(fwdAlgPerf.memory, input); 
  
       // update convDesc mathType since cudnn 7.4+ now requires both algo + mathType to figure out 
       // whether to use Tensor core kernels or not 
       // See Note [behavior of cudnnFind and cudnnGet] 
       AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), fwdAlgPerf.mathType)); 
  
       Constant one(dataType, 1); 
       Constant zero(dataType, 0); 
  
       AT_CUDNN_CHECK(cudnnConvolutionForward( 
         args.handle, 
         &one, args.idesc.desc(), input.data_ptr(), 
         args.wdesc.desc(), weight.data_ptr(), 
         args.cdesc.desc(), fwdAlgPerf.algo, workspace.data_ptr(), fwdAlgPerf.memory, 
         &zero, args.odesc.desc(), output.data_ptr())); 
       } 
   ); 

上のコードにおいてcudnnConvolutionForwardがメインで動いている関数です.この関数は非オープンソースのcuDNNの関数で,遡ることのできる限界なのでここまでとしておきましょう.

逆伝搬の設定

ATenまでコードを遡ってみた訳ですが,一つ大事なことを忘れています.

conv2dの逆伝搬ってどう設定してるの・・・?自動微分って事前に順伝搬と逆伝搬の関数を決めておくんだよね?でもconv2dに紐づいた逆伝搬関数をまだみてないぞ?

そう,まだ逆伝搬の定義と順伝搬との紐付けが一切出てきませんでした.PyTorchはいかにしてconv2dの逆伝搬を行うのでしょうか.実はtoolsというところに秘密が隠されています.toolsはPyTorchのBuildを行う上で重要なフォルダであり,tools/autogradでは自動微分に関して重要な関数が定義されています.ここには順伝搬と逆伝搬の関数の紐付けを決めるderivatives.yamlがあります.このyamlファイルを調べてみるとconv2dについては記載がありませんが,代わりに以下の設定が見つかります.

- name: cudnn_convolution(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
  self, weight: "grad.defined() ? cudnn_convolution_backward(self, grad, weight, padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask) : std::tuple<Tensor, Tensor>()"

どうやらconv2dの実装を遡っていく中で出てきたcudnn_convolutionについて逆伝搬が設定されているようです.この逆伝搬関数cudnn_convolution_backwardはATen/native/cudnn/Conv.cpp内で定義されています.ちなみにこの下には以下の逆伝搬の設定が書いてあります.

- name: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[2] output_mask) -> (Tensor, Tensor)
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], Tensor(), grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, benchmark, deterministic, true, grad_input_mask)

逆伝搬関数(1次微分)のさらなる逆伝搬関数を設定することで2次微分を得ることができるみたいですね.

別題:線形方程式の例

最後にconv2d以外の例としてtorch.solveのATen内での実装を探してみましょう.torch.solveは$A,b$を入力として線形方程式$Ax=b$を解き$x$を返します.native_functions.yamlをまず見てみると,ちゃんと以下の記述があります.

- func: solve(Tensor self, Tensor A) -> (Tensor solution, Tensor LU) 
  use_c10_dispatcher: full 
  variants: function, method 

ちょっとわかりにくいですが,ATen/native/BatchLinearAlgebra.cppにsolveの定義があります.どうやらat::_solver_helperという関数を呼び出しています.この関数を見つければいいのですが・・・見つかりません?もう一度yamlを見てみると以下の記述が見つかります.

- func: _solve_helper(Tensor self, Tensor A) -> (Tensor, Tensor)
  use_c10_dispatcher: full 
  variants: function 
  dispatch: 
    CPU: _solve_helper_cpu 
    CUDA: _solve_helper_cuda 

実はこの記述によりCPUとGPUの場合の関数のスイッチングを行えます.at::_solver_helperという関数の実体は存在せず,CPUならば_solver_helper_cpuに,GPUならば_solver_helper_cudaに引き継がれます.CPUの場合を辿っていくと,ATen/native/BatchLinearAlgebra.cppに_solver_helper_cpuがあります.さらにapply_solveという関数が呼ばれ,lapackSolveというLAPACKというライブラリを用いた線形方程式のソルバーに引き継がれていきます.
では次に逆伝搬を調べてみましょう.derivatives.yamlには以下のように逆伝搬を設定してありました.

- name: solve(Tensor self, Tensor A) -> (Tensor solution, Tensor LU) 
  self: solve_backward_self(grad, self, A) 
  A: solve_backward_A(grad, self, A, solution) 

$A,b$の2種類に対する逆伝搬を別々に書くことができるようです.ではsolve_backward_selfとsolve_backward_Aの実装を探してみる・・と・・・?いくら探してもATen/native内に見つけることができません.実はこれらの関数はtools/autograd/templates/Functions.cpp内で以下のように実装されています.

pytorch >= 1.7 では torch/csrc/autograd/FunctionsManual.cppに移動したようです.

Tensor solve_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) {
  return std::get<0>(at::solve(grad, A.transpose(-2, -1))); 
} 

Tensor solve_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) { 
  Tensor grad_self = solve_backward_self(grad, self, A); 
  if (self.ndimension() == 2 && A.ndimension() == 2) { 
    return -at::mm(grad_self, solution.transpose(-2, -1)); 
  } 
  return -at::matmul(grad_self, solution.transpose(-2, -1)); 
} 

どうやら逆伝搬の実装自体はATenだけではなくtools/autograd/templates内でも行うようです.棲み分けとしては,ATenでは複雑な逆伝搬を,templates内では容易な逆伝搬を行っているように見受けられます.なお,線形方程式の逆伝搬がなぜこのような式になっているかはここに解説してあります.

終わり

以上でPyTorchの内部を見てみる記事は終わりです.conv2dとsolveの内容で大概のパターンに対処できると思いますので,実装を確認したい処理があったらATenの中を探してみるのも良いでしょう.

67
59
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
67
59

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?