はじめに
前回の【PyTorch】自作関数の勾配計算式(backward関数)の書き方①の続きです.前記事では1次元入力1次元出力の関数の勾配計算式(backward()
関数)の書き方をまとめました.本記事ではさらに拡張して多変数入出力の場合に,どのように勾配計算を定義するかについてまとめます.
勾配計算の書き方の復習(1次元入出力の場合)
- 自作関数はtorch.autograd.Functionクラスを継承する必要があります.
- メンバ関数にはforward()とbackward()を用意します.ここで大事なのはそれぞれの関数の第一引数ctxでなければならない,ということです.ctxはコンテキストのことであり,勾配計算に必要な順伝播時の情報を保持します.
-
backward()
関数のゴール(返り値)は入力の勾配です.よりわかりやすく言うと,この勾配とは$\frac{dL}{dx}$になります. -
backward()
関数は引数として出力の勾配$\frac{dL}{dy}$を受け取っています.ただし,ここでは$L$を目的関数,$y$を順伝播時の出力としています.よって,backward()
内で行うべき処理は連鎖律を用いて,以下のように受け取った出力の勾配$\frac{dL}{dy}$に,関数出力を入力で微分した値$\frac{dy}{dx}$をかけて$\frac{dL}{dx}$を求めることになります.
$$ \frac{dL}{dx} = \frac{dL}{dy} \frac{dy}{dx} $$
多変数入出力関数の定義
本記事では以下のような関数を考えます.
def F(x1, x2, x3)
y1 = x1 + x2
y2 = x2 + x3
y3 = x3 + x1
return y1, y2, y3
L = y1 * y2 * y3 # 目的関数
よく見ると,この関数は一つの入力が複数の出力に影響を与えています.例えば,x1
という入力はy1
, y3
両方の出力の計算に用いられています.こうなると,入力の勾配$\frac{dL}{dx}$を計算したくても,1次元の場合のように単純に$\frac{dL}{dy} \frac{dy}{dx}$とするだけではまずそうです.この場合,どう計算するのが正解なのでしょうか?
多変数出力関数における連鎖律(Chain Rule)
数学的に多変数関数の微分は連鎖律を用いて以下のように計算することができます.
$$
f = f(x(u, v), y(u, v))の時, \
\frac{\partial f}{\partial u} = \frac{\partial f}{\partial x}\frac{\partial x}{\partial u} + \frac{\partial f}{\partial y}\frac{\partial y}{\partial u}
$$
これは変数の数がいくつになっても同じように計算でき,上の例だと
$$
\frac{\partial L}{\partial x_1} = \frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial x_1} + \frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial x_1} + \frac{\partial L}{\partial y_3}\frac{\partial y_3}{\partial x_1}
$$
と計算できます.以上に基づいて勾配計算を実装したものが以下になります.
class Func(torch.autograd.Function):
@staticmethod
def forward(ctx, x1, x2, x3):
ctx.save_for_backward(x1, x2, x3) # 今回は不要
y1 = x1 + x2
y2 = x2 + x3
y3 = x3 + x1
return y1, y2, y3
@staticmethod
def backward(ctx, dy1, dy2, dy3):
x1, x2, x3 = ctx.saved_tensors # 今回は不要
print("dy: ",dy1, dy2, dy3)
dx1 = dy1*1 + dy2*0 + dy3*1
dx2 = dy1*1 + dy2*1 + dy3*0
dx3 = dy1*0 + dy2*1 + dy3*1
return dx1, dx2, dx3
確認
手計算で求めた勾配の値と計算値が一致するか確かめましょう.いま目的関数$L$は$L=(x_1 + x_2)(x_2 + x_3)(x_3 + x_1)$です.これを$x_1, x_2, x_3$で微分すると以下のようになります.
$$
\frac{\partial L}{\partial x_1} = (x_2 + x_3)(2x_1 + x_2 + x_3) \
\frac{\partial L}{\partial x_2} = (x_3 + x_1)(2x_2 + x_3 + x_1) \
\frac{\partial L}{\partial x_3} = (x_1 + x_2)(2x_3 + x_1 + x_2)
$$
ここで,$x_1 = 0, x_2=1, x_3 = 2$を代入すると
$$
\frac{\partial L}{\partial x_1} = 9,
\frac{\partial L}{\partial x_2} = 8,
\frac{\partial L}{\partial x_3} = 5
$$
となります.
以下のコードで確かめます.
if __name__=="__main__":
x1 = torch.Tensor([0]).requires_grad_()
x2 = torch.Tensor([1]).requires_grad_()
x3 = torch.Tensor([2]).requires_grad_()
my_func = Func.apply
y1, y2, y3 = my_func(x1, x2, x3)
L = y1 * y2 * y3
print("L: ",L)
L.backward()
print("x1 grad: ", x1.grad) # 9
print("x2 grad: ", x2.grad) # 8
print("x3 grad: ", x3.grad) # 5
めでたく一致しました.