はじめに
PyTorchのbackward関数における勾配の計算式の書き方について解説していきます.
勾配計算まで自分で定義することは少ないですが,例えばC++でモデルを書き直したい,という時には自分で勾配計算式を定義しなくてはならないときがあります.
また,かなり特殊な層(例えばNeuralODEのODE blockなど)を,標準搭載の関数では実装出来ない,というときにもPyTorchの自動微分機能を使って自分で新たな関数を作る必要があります.
この記事は,自作関数を作る時に勾配計算の書き方で困らないように,その基礎をまとめた記事です.
1. 目標のカスタムtanh関数
PyTorchではtorch.tanh()
が予め容易されていますが,ここでは以下のようなカスタムtanh関数を作りたいとします.
$$ y = 4 \tanh(\frac{x}{4}) $$
もちろん,y = 4 * torch.tanh(x / 4)
でこの関数は実装できますが,自動微分を実装したいので知らなかったことにします.
2.実装
先に完成形を見た方が見通しが良いと思うので,以下に実装を載せます.
import torch
class custom_tanh(torch.autograd.Function ):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward( x )
h = x / 4.0
y = 4 * h.tanh()
return y
@staticmethod
def backward(ctx, dL_dy): # dL_dy = dL/dy
x, = ctx.saved_tensors
h = x / 4.0
dy_dx = d_tanh( h )
dL_dx = dL_dy * dy_dx
return dL_dx
def d_tanh(x):
return 1 / (x.cosh() ** 2)
3.ポイント
関数の名前はcustom_tanh
としました.以下に,勾配の計算方法以外で大事な点をまとめます.
- 自作関数は
torch.autograd.Function
クラスを継承する必要があります. - メンバ関数には
forward()
とbackward()
を用意します.ここで大事なのはそれぞれの関数の第一引数ctx
でなければならない,ということです.ctx
はコンテキストのことであり,勾配計算に必要な順伝播時の情報を保持します.
4. backward()の書き方
今回の自作$y=custom \ tanh(x)$関数を微分して$\frac{dy}{dx}$を求めると以下になり,コードではd_tanh(x / 4)
に対応します.
$$ \Bigl(4\tanh(\frac{x}{4})\Bigr)' = \frac{1}{\cosh^2{\frac{x}{4}}}$$
-
backward()
関数のゴール(返り値)は入力の勾配です.よりわかりやすく言うと,この勾配とは$\frac{dL}{dx}$になります. -
backward()
関数は引数として出力の勾配$\frac{dL}{dy}$を受け取っています.ただし,ここでは$L$を目的関数,$y$を順伝播時の出力としています.よって,backward()
内で行うべき処理は連鎖律を用いて,以下のように受け取った出力の勾配$\frac{dL}{dy}$に,関数出力を入力で微分した値$\frac{dy}{dx}$をかけて$\frac{dL}{dx}$を求めることになります.
$$ \frac{dL}{dx} = \frac{dL}{dy} \frac{dy}{dx} $$ -
もし
forward()
関数の出力が複数でreturn y1, y2
となっていたら,backward()
関数の引数は以下のように,出力されたのと同じ順番で受け取ります.
class custom_tanh( torch.autograd.Function ):
@staticmethod
def forward(ctx, x):
# (略)
return y1, y2
@staticmethod
def backward(ctx, dL_dy1, dL_dy2): # dL_dy = dL/dy
...
return dL_dx
4.関数として使えることを確認
以下のようにすることで関数として使えます.注意すべきは(当然ですが)そのままでは使えず,一度custom_tanh.apply
を実行する必要があることです.
if __name__=="__main__":
# 入力
X = torch.Tensor([0,4]).requires_grad_()
# 関数化
my_func = custom_tanh.apply
# 計算を実行
y = my_func(X)
L = torch.sum(y*y)
L.backward()
print(X.grad) # tensor([0.0000, 2.5588])
5. 発展:引数にテンソル以外の型の変数を受け取りたい
複雑な自作関数を作るとき,引数は常にテンソルとは限りません.
例えば,以下の例のように,forward処理時にInt型の変数image_size
であったり,str型の変数activation_function
を受け取るかもしれません.これらはテンソルでないため,当然勾配を持ちませんが,backward関数の返り値はforward関数の引数の勾配を返さなくてはならないのでこれは問題です.
class custom_func(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, image_size=28, activation_function='tanh'):
# (順伝播処理)
return out
@staticmethod
def backward(ctx, grad_out): # 引数はctxとforward関数の出力の勾配
# (逆伝播処理)
return ... # 返り値はforwad関数の引数の勾配...?!
この場合,backward関数の返り値にNone
をあてることでエラーを回避できます.
@staticmethod
def backward(ctx, grad_out): # 引数はctxとforward関数の出力の勾配
# (逆伝播処理)
return grad_tensor, None, None # 返り値はforwad関数の引数の勾配,なければNone
参考
AUTOMATIC DIFFERENTIATION PACKAGE - TORCH.AUTOGRAD
実装したレポジトリ