Edited at

PyTorchで高階偏微分係数


はじめに

この記事ではPyTorchを使って高階の偏微分を求める方法を説明しています。


高階ではない微分係数

例えば、関数$f(x)=x^3$について、$\left. \frac{df(x)}{dx} \right|_{x=4}$を求めたいとします。

>>> import torch

>>> x = torch.tensor([4.], requires_grad=True)
>>> f = x ** 3
>>> f.backward()
>>> print(x.grad)
tensor([48.])

backward()を呼べばいいだけで、これは皆さん普通に実行していることです。

(以下、import torchは省略します。)


1つの変数だけを扱う場合


2階の微分係数

先ほどの関数$f(x)=x^3$について、$x=4$における2階の微分係数$\left. \frac{d^2f(x)}{dx^2} \right|_{x=4}$を求めたいとします。

>>> x = torch.tensor([4.], requires_grad=True)

>>> f = x ** 3
>>> g = torch.autograd.grad(f, x, create_graph=True)
>>> g
(tensor([48.], grad_fn=<ThMulBackward>),)
>>> g.backward()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: 'tuple' object has no attribute 'backward'
>>> g[0].backward()
>>> x.grad
tensor([24.])

torch.autograd.grad()を呼ぶときにcreate_graph=Trueとしているのがポイントです。

こうすると、微分係数(上の場合は48)だけでなく、

$f$の$x$に関する微分について計算グラフを作って、それも返してくれます。

すると、その計算グラフを使うことで、2階の微分係数が計算できるようになります。

上の例で、gと入力したとき

(tensor([48.], grad_fn=<ThMulBackward>),)

と表示されています。

48は1階の微分係数です。create_graph=Trueと設定しなかったら、これしか返してきません。実際、

>>> x = torch.tensor([4.], requires_grad=True)

>>> f = x ** 3
>>> g = torch.autograd.grad(f, x)
>>> g
(tensor([48.]),)

となります。先ほどの例にあるgrad_fn=<ThMulBackward>というあやしい(?)表示が、

計算グラフも一緒に作ってくれている証拠です。

また、わざとエラーを出してみました:laughing:

今は、$x$というひとつの変数についてしか微分していないのですが、

torch.autograd.grad()では、複数の変数の各々で偏微分する状況を扱うのが基本です。

そのため、返って来るのがいつでもtupleなんですね。

ひとつの変数しか考慮していなくても、要素数がひとつだけのtupleが帰ってきます。

そのため、

>>> g[0].backward()

と、tupleの最初の要素を使う旨、添え字0で指定しないと、エラーになります。


3階の微分係数

先ほどの関数$f(x)=x^3$について、$x=4$における3階の微分係数$\left. \frac{d^3f(x)}{dx^3} \right|_{x=4}$を求めたいとします。

>>> x = torch.tensor([4.], requires_grad=True)

>>> f = x ** 3
>>> g = torch.autograd.grad(f, x, create_graph=True)
>>> h = torch.autograd.grad(g, x, create_graph=True)
>>> h
(tensor([24.], grad_fn=<ThMulBackward>),)
>>> h[0].backward()
>>> x.grad
tensor([6.])

最初にtorch.autograd.grad()を呼んで戻ってきた計算グラフをそのまま、

次のtorch.autograd.grad()の呼び出しで使っています。

こうすると、この2回目の呼び出しは、2階の微分の計算グラフを返してきます。

それについてbackward()すれば、3階の微分係数を計算できます。

答えは6.と表示されています。これは$x^3$を3回$x$で微分すると定数$6$になるためです。


2個以上の変数を扱う場合


2階の偏微分係数

関数$f(x,y)=(x+2w)^3$について、まず、$\left.\frac{\partial^2f(x,y)}{\partial x^2}\right|_{x=4,y=3}$と

$\left. \frac{\partial^2 f(x,y)}{\partial x \partial y} \right|_{x=4, y=3}$を求めてみます。

>>> x = torch.tensor([4.], requires_grad=True)

>>> y = torch.tensor([3.], requires_grad=True)
>>> f = (x + 2 * y) ** 3
>>> g = torch.autograd.grad(f, x, create_graph=True)
>>> g
(tensor([300.], grad_fn=<ThMulBackward>),)
>>> g[0].backward()
>>> x.grad
tensor([60.])
>>> y.grad
tensor([120.])

$\left.\frac{\partial^2f(x,y)}{\partial x^2}\right|_{x=4,y=3} = 60$、そして

$\left. \frac{\partial^2 f(x,y)}{\partial x \partial y} \right|_{x=4, y=3} = 120$です。

では、次に、$\left.\frac{\partial^2f(x,y)}{\partial y \partial x}\right|_{x=4,y=3}$と

$\left. \frac{\partial^2 f(x,y)}{\partial y^2} \right|_{x=4, y=3}$を求めてみます。

>>> x = torch.tensor([4.], requires_grad=True)

>>> y = torch.tensor([3.], requires_grad=True)
>>> f = (x + 2 * y) ** 3
>>> g = torch.autograd.grad(f, y, create_graph=True)
>>> g
(tensor([600.], grad_fn=<MulBackward>),)
>>> g[0].backward()
>>> x.grad
tensor([120.])
>>> y.grad
tensor([240.])

当然ですが、$\left. \frac{\partial^2 f(x,y)}{\partial x \partial y} \right|_{x=4, y=3}$と

$\left.\frac{\partial^2f(x,y)}{\partial y \partial x}\right|_{x=4,y=3}$は、同じ値120になります。

上のふたつの作業をまとめて実行しようとすると・・・

>>> x = torch.tensor([4.], requires_grad=True)

>>> y = torch.tensor([3.], requires_grad=True)
>>> f = (x + 2 * y) ** 3
>>> g = torch.autograd.grad(f, (x, y), create_graph=True)
>>> g
(tensor([300.], grad_fn=<ThMulBackward>), tensor([600.], grad_fn=<MulBackward>))
>>> g[0].backward()
>>> x.grad
tensor([60.])
>>> y.grad
tensor([120.])
>>> g[1].backward()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/data10/masada/anaconda3/lib/python3.6/site-packages/torch/tensor.py", line 93, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/data10/masada/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
>>> g = torch.autograd.grad(f, (x, y), create_graph=True)
>>> g[1].backward()
>>> x.grad
tensor([180.])
>>> y.grad
tensor([360.])

と、エラーが出てしまいました。

それに関して偏微分を求めたい2つの変数を、(x,y)とtupleにして

torch.autograd.grad()に渡すと、

>>> g

(tensor([300.], grad_fn=<ThMulBackward>), tensor([600.], grad_fn=<MulBackward>))

という箇所にあるように、それぞれの変数に関する1階の微分係数をtupleにして返してくれます。

しかし、g[0].backward()のあとに続けてg[1].backward()を実行したため、エラーになりました。

上の例では

>>> g = torch.autograd.grad(f, (x, y), create_graph=True)

と、もう一度torch.autograd.grad()を呼んでいます。

(注:x.grad.zero_()そしてy.grad.zero_()と、ゼロに初期化する処理は書いていないため、

x.grady.gradも、g[0].backward()g[1].backward()とでの値を

足したものになってしまっています。)

しかし、上のように、torch.autograd.grad()が返したきた勾配の

別々の要素をそれぞれ微分するといった使い方はしないでしょう。

>>> x = torch.tensor([4.], requires_grad=True)

>>> y = torch.tensor([3.], requires_grad=True)
>>> f = (x + 2 * y) ** 3
>>> g = torch.autograd.grad(f, (x, y), create_graph=True)
>>> h = g[0] + g[1]
>>> h.backward()
>>> x.grad
tensor([180.])
>>> y.grad
tensor([360.])

このように、g[0]g[1]を含む計算グラフを作って、それについてbackward()を呼べば、

特に問題はありません。普通はこういう使い方をするでしょう。

torch.autograd.grad()は、複数の変数に関して偏微分をとった結果、

つまり勾配をtupleとして返してきます。

このtupleの要素を組み合わせて計算グラフをつくり、

それについてまた微分をとる、という使い方が普通だと思います。

例えば、WGANで使う勾配のL2ノルムの場合も、

torch.autograd.grad()が返してきたtupleの要素を組み合わせて、

計算グラフを作っていることになります。


おわりに

PyTorchの自動微分、便利です。