Python
PyTorch

PyTorchで高階偏微分係数

はじめに

この記事ではPyTorchを使って高階の偏微分を求める方法を説明しています。

高階ではない微分係数

例えば、関数$f(x)=x^3$について、$\left. \frac{df(x)}{dx} \right|_{x=4}$を求めたいとします。

>>> import torch
>>> x = torch.tensor([4.], requires_grad=True)
>>> f = x ** 3
>>> f.backward()
>>> print(x.grad)
tensor([48.])

backward()を呼べばいいだけで、これは皆さん普通に実行していることです。
(以下、import torchは省略します。)

1つの変数だけを扱う場合

2階の微分係数

先ほどの関数$f(x)=x^3$について、$x=4$における2階の微分係数$\left. \frac{d^2f(x)}{dx^2} \right|_{x=4}$を求めたいとします。

>>> x = torch.tensor([4.], requires_grad=True)
>>> f = x ** 3
>>> g = torch.autograd.grad(f, x, create_graph=True)
>>> g
(tensor([48.], grad_fn=<ThMulBackward>),)
>>> g.backward()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'tuple' object has no attribute 'backward'
>>> g[0].backward()
>>> x.grad
tensor([24.])

torch.autograd.grad()を呼ぶときにcreate_graph=Trueとしているのがポイントです。
こうすると、微分係数(上の場合は48)だけでなく、
$f$の$x$に関する微分について計算グラフを作って、それも返してくれます。
すると、その計算グラフを使うことで、2階の微分係数が計算できるようになります。

上の例で、gと入力したとき

(tensor([48.], grad_fn=<ThMulBackward>),)

と表示されています。
48は1階の微分係数です。create_graph=Trueと設定しなかったら、これしか返してきません。実際、

>>> x = torch.tensor([4.], requires_grad=True)
>>> f = x ** 3
>>> g = torch.autograd.grad(f, x)
>>> g
(tensor([48.]),)

となります。先ほどの例にあるgrad_fn=<ThMulBackward>というあやしい(?)表示が、
計算グラフも一緒に作ってくれている証拠です。

また、わざとエラーを出してみました:laughing:
今は、$x$というひとつの変数についてしか微分していないのですが、
torch.autograd.grad()では、複数の変数の各々で偏微分する状況を扱うのが基本です。
そのため、返って来るのがいつでもtupleなんですね。
ひとつの変数しか考慮していなくても、要素数がひとつだけのtupleが帰ってきます。
そのため、

>>> g[0].backward()

と、tupleの最初の要素を使う旨、添え字0で指定しないと、エラーになります。

3階の微分係数

先ほどの関数$f(x)=x^3$について、$x=4$における3階の微分係数$\left. \frac{d^3f(x)}{dx^3} \right|_{x=4}$を求めたいとします。

>>> x = torch.tensor([4.], requires_grad=True)
>>> f = x ** 3
>>> g = torch.autograd.grad(f, x, create_graph=True)
>>> h = torch.autograd.grad(g, x, create_graph=True)
>>> h
(tensor([24.], grad_fn=<ThMulBackward>),)
>>> h[0].backward()
>>> x.grad
tensor([6.])

最初にtorch.autograd.grad()を呼んで戻ってきた計算グラフをそのまま、
次のtorch.autograd.grad()の呼び出しで使っています。
こうすると、この2回目の呼び出しは、2階の微分の計算グラフを返してきます。
それについてbackward()すれば、3階の微分係数を計算できます。
答えは6.と表示されています。これは$x^3$を3回$x$で微分すると定数$6$になるためです。

2個以上の変数を扱う場合

2階の偏微分係数

関数$f(x,y)=(x+2w)^3$について、まず、$\left.\frac{\partial^2f(x,y)}{\partial x^2}\right|_{x=4,y=3}$と

$\left. \frac{\partial^2 f(x,y)}{\partial x \partial y} \right|_{x=4, y=3}$を求めてみます。

>>> x = torch.tensor([4.], requires_grad=True)
>>> y = torch.tensor([3.], requires_grad=True)
>>> f = (x + 2 * y) ** 3
>>> g = torch.autograd.grad(f, x, create_graph=True)
>>> g
(tensor([300.], grad_fn=<ThMulBackward>),)
>>> g[0].backward()
>>> x.grad
tensor([60.])
>>> y.grad
tensor([120.])

$\left.\frac{\partial^2f(x,y)}{\partial x^2}\right|_{x=4,y=3} = 60$、そして

$\left. \frac{\partial^2 f(x,y)}{\partial x \partial y} \right|_{x=4, y=3} = 120$です。

では、次に、$\left.\frac{\partial^2f(x,y)}{\partial y \partial x}\right|_{x=4,y=3}$と

$\left. \frac{\partial^2 f(x,y)}{\partial y^2} \right|_{x=4, y=3}$を求めてみます。

>>> x = torch.tensor([4.], requires_grad=True)
>>> y = torch.tensor([3.], requires_grad=True)
>>> f = (x + 2 * y) ** 3
>>> g = torch.autograd.grad(f, y, create_graph=True)
>>> g
(tensor([600.], grad_fn=<MulBackward>),)
>>> g[0].backward()
>>> x.grad
tensor([120.])
>>> y.grad
tensor([240.])

当然ですが、$\left. \frac{\partial^2 f(x,y)}{\partial x \partial y} \right|_{x=4, y=3}$と

$\left.\frac{\partial^2f(x,y)}{\partial y \partial x}\right|_{x=4,y=3}$は、同じ値120になります。

上のふたつの作業をまとめて実行しようとすると・・・

>>> x = torch.tensor([4.], requires_grad=True)
>>> y = torch.tensor([3.], requires_grad=True)
>>> f = (x + 2 * y) ** 3
>>> g = torch.autograd.grad(f, (x, y), create_graph=True)
>>> g
(tensor([300.], grad_fn=<ThMulBackward>), tensor([600.], grad_fn=<MulBackward>))
>>> g[0].backward()
>>> x.grad
tensor([60.])
>>> y.grad
tensor([120.])
>>> g[1].backward()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/data10/masada/anaconda3/lib/python3.6/site-packages/torch/tensor.py", line 93, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/data10/masada/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
>>> g = torch.autograd.grad(f, (x, y), create_graph=True)
>>> g[1].backward()
>>> x.grad
tensor([180.])
>>> y.grad
tensor([360.])

と、エラーが出てしまいました。
それに関して偏微分を求めたい2つの変数を、(x,y)とtupleにして
torch.autograd.grad()に渡すと、

>>> g
(tensor([300.], grad_fn=<ThMulBackward>), tensor([600.], grad_fn=<MulBackward>))

という箇所にあるように、それぞれの変数に関する1階の微分係数をtupleにして返してくれます。
しかし、g[0].backward()のあとに続けてg[1].backward()を実行したため、エラーになりました。
上の例では

>>> g = torch.autograd.grad(f, (x, y), create_graph=True)

と、もう一度torch.autograd.grad()を呼んでいます。

(注:x.grad.zero_()そしてy.grad.zero_()と、ゼロに初期化する処理は書いていないため、
x.grady.gradも、g[0].backward()g[1].backward()とでの値を
足したものになってしまっています。)

しかし、上のように、torch.autograd.grad()が返したきた勾配の
別々の要素をそれぞれ微分するといった使い方はしないでしょう。

>>> x = torch.tensor([4.], requires_grad=True)
>>> y = torch.tensor([3.], requires_grad=True)
>>> f = (x + 2 * y) ** 3
>>> g = torch.autograd.grad(f, (x, y), create_graph=True)
>>> h = g[0] + g[1]
>>> h.backward()
>>> x.grad
tensor([180.])
>>> y.grad
tensor([360.])

このように、g[0]g[1]を含む計算グラフを作って、それについてbackward()を呼べば、
特に問題はありません。普通はこういう使い方をするでしょう。

torch.autograd.grad()は、複数の変数に関して偏微分をとった結果、
つまり勾配をtupleとして返してきます。
このtupleの要素を組み合わせて計算グラフをつくり、
それについてまた微分をとる、という使い方が普通だと思います。

例えば、WGANで使う勾配のL2ノルムの場合も、
torch.autograd.grad()が返してきたtupleの要素を組み合わせて、
計算グラフを作っていることになります。

おわりに

PyTorchの自動微分、便利です。