25
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【PyTorch】自作関数の勾配計算式(backward関数)の書き方①

Last updated at Posted at 2020-07-05

はじめに

PyTorchのbackward関数における勾配の計算式の書き方について解説していきます.

勾配計算まで自分で定義することは少ないですが,例えばC++でモデルを書き直したい,という時には自分で勾配計算式を定義しなくてはならないときがあります.
また,かなり特殊な層(例えばNeuralODEのODE blockなど)を,標準搭載の関数では実装出来ない,というときにもPyTorchの自動微分機能を使って自分で新たな関数を作る必要があります.

この記事は,自作関数を作る時に勾配計算の書き方で困らないように,その基礎をまとめた記事です.

1. 目標のカスタムtanh関数

PyTorchではtorch.tanh()が予め容易されていますが,ここでは以下のようなカスタムtanh関数を作りたいとします.
$$ y = 4 \tanh(\frac{x}{4}) $$

graph_image.png

もちろん,y = 4 * torch.tanh(x / 4)でこの関数は実装できますが,自動微分を実装したいので知らなかったことにします.

2.実装

先に完成形を見た方が見通しが良いと思うので,以下に実装を載せます.

custom_tanh.py
import torch

class custom_tanh(torch.autograd.Function ):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward( x )
        h = x / 4.0
        y = 4 * h.tanh()
        return y

    @staticmethod
    def backward(ctx, dL_dy):  # dL_dy = dL/dy
        x, = ctx.saved_tensors
        h = x / 4.0
        dy_dx = d_tanh( h )
        dL_dx = dL_dy * dy_dx

        return dL_dx

def d_tanh(x):
    return 1 / (x.cosh() ** 2)

3.ポイント

関数の名前はcustom_tanhとしました.以下に,勾配の計算方法以外で大事な点をまとめます.

  • 自作関数はtorch.autograd.Functionクラスを継承する必要があります.
  • メンバ関数にはforward()backward()を用意します.ここで大事なのはそれぞれの関数の第一引数ctxでなければならない,ということです.ctxはコンテキストのことであり,勾配計算に必要な順伝播時の情報を保持します.

4. backward()の書き方

今回の自作$y=custom \ tanh(x)$関数を微分して$\frac{dy}{dx}$を求めると以下になり,コードではd_tanh(x / 4)に対応します.
$$ \Bigl(4\tanh(\frac{x}{4})\Bigr)' = \frac{1}{\cosh^2{\frac{x}{4}}}$$

  • backward()関数のゴール(返り値)は入力の勾配です.よりわかりやすく言うと,この勾配とは$\frac{dL}{dx}$になります.

  • backward()関数は引数として出力の勾配$\frac{dL}{dy}$を受け取っています.ただし,ここでは$L$を目的関数,$y$を順伝播時の出力としています.よって,backward()内で行うべき処理は連鎖律を用いて,以下のように受け取った出力の勾配$\frac{dL}{dy}$に,関数出力を入力で微分した値$\frac{dy}{dx}$をかけて$\frac{dL}{dx}$を求めることになります.
    $$ \frac{dL}{dx} = \frac{dL}{dy} \frac{dy}{dx} $$

  • もしforward()関数の出力が複数でreturn y1, y2となっていたら,backward()関数の引数は以下のように,出力されたのと同じ順番で受け取ります.

class custom_tanh( torch.autograd.Function ):
    @staticmethod
    def forward(ctx, x):
        # (略)
        return y1, y2

    @staticmethod
    def backward(ctx, dL_dy1, dL_dy2):  # dL_dy = dL/dy
        ...
        return dL_dx

4.関数として使えることを確認

以下のようにすることで関数として使えます.注意すべきは(当然ですが)そのままでは使えず,一度custom_tanh.applyを実行する必要があることです.

if __name__=="__main__":
    # 入力
    X = torch.Tensor([0,4]).requires_grad_()
    # 関数化
    my_func = custom_tanh.apply
    # 計算を実行
    y = my_func(X)
    L = torch.sum(y*y)
    L.backward()
    print(X.grad)  # tensor([0.0000, 2.5588])

5. 発展:引数にテンソル以外の型の変数を受け取りたい

複雑な自作関数を作るとき,引数は常にテンソルとは限りません.

例えば,以下の例のように,forward処理時にInt型の変数image_sizeであったり,str型の変数activation_functionを受け取るかもしれません.これらはテンソルでないため,当然勾配を持ちませんが,backward関数の返り値はforward関数の引数の勾配を返さなくてはならないのでこれは問題です.

class custom_func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, image_size=28, activation_function='tanh'):
        # (順伝播処理)
        return out

    @staticmethod
    def backward(ctx, grad_out):  # 引数はctxとforward関数の出力の勾配
        # (逆伝播処理)
        return ...  # 返り値はforwad関数の引数の勾配...?!

この場合,backward関数の返り値にNoneをあてることでエラーを回避できます.

@staticmethod
def backward(ctx, grad_out):  # 引数はctxとforward関数の出力の勾配
    # (逆伝播処理)
    return grad_tensor, None, None  # 返り値はforwad関数の引数の勾配,なければNone

参考

AUTOMATIC DIFFERENTIATION PACKAGE - TORCH.AUTOGRAD
実装したレポジトリ

25
11
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
25
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?