41
20

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

PyTorchで高階偏微分係数

Last updated at Posted at 2018-11-09

はじめに

この記事ではPyTorchを使って高階の偏微分を求める方法を説明しています。

高階ではない微分係数

例えば、関数$f(x)=x^3$について、$\left. \frac{df(x)}{dx} \right|_{x=4}$を求めたいとします。

>>> import torch
>>> x = torch.tensor([4.], requires_grad=True)
>>> f = x ** 3
>>> f.backward()
>>> print(x.grad)
tensor([48.])

backward()を呼べばいいだけで、これは皆さん普通に実行していることです。
(以下、import torchは省略します。)

1つの変数だけを扱う場合

2階の微分係数

先ほどの関数$f(x)=x^3$について、$x=4$における2階の微分係数$\left. \frac{d^2f(x)}{dx^2} \right|_{x=4}$を求めたいとします。

>>> x = torch.tensor([4.], requires_grad=True)
>>> f = x ** 3
>>> g = torch.autograd.grad(f, x, create_graph=True)
>>> g
(tensor([48.], grad_fn=<ThMulBackward>),)
>>> g.backward()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'tuple' object has no attribute 'backward'
>>> g[0].backward()
>>> x.grad
tensor([24.])

torch.autograd.grad()を呼ぶときにcreate_graph=Trueとしているのがポイントです。
こうすると、微分係数(上の場合は48)だけでなく、
$f$の$x$に関する微分について計算グラフを作って、それも返してくれます。
すると、その計算グラフを使うことで、2階の微分係数が計算できるようになります。

上の例で、gと入力したとき

(tensor([48.], grad_fn=<ThMulBackward>),)

と表示されています。
48は1階の微分係数です。create_graph=Trueと設定しなかったら、これしか返してきません。実際、

>>> x = torch.tensor([4.], requires_grad=True)
>>> f = x ** 3
>>> g = torch.autograd.grad(f, x)
>>> g
(tensor([48.]),)

となります。先ほどの例にあるgrad_fn=<ThMulBackward>というあやしい(?)表示が、
計算グラフも一緒に作ってくれている証拠です。

また、わざとエラーを出してみました:laughing:
今は、$x$というひとつの変数についてしか微分していないのですが、
torch.autograd.grad()では、複数の変数の各々で偏微分する状況を扱うのが基本です。
そのため、返って来るのがいつでもtupleなんですね。
ひとつの変数しか考慮していなくても、要素数がひとつだけのtupleが帰ってきます。
そのため、

>>> g[0].backward()

と、tupleの最初の要素を使う旨、添え字0で指定しないと、エラーになります。

3階の微分係数

先ほどの関数$f(x)=x^3$について、$x=4$における3階の微分係数$\left. \frac{d^3f(x)}{dx^3} \right|_{x=4}$を求めたいとします。

>>> x = torch.tensor([4.], requires_grad=True)
>>> f = x ** 3
>>> g = torch.autograd.grad(f, x, create_graph=True)
>>> h = torch.autograd.grad(g, x, create_graph=True)
>>> h
(tensor([24.], grad_fn=<ThMulBackward>),)
>>> h[0].backward()
>>> x.grad
tensor([6.])

最初にtorch.autograd.grad()を呼んで戻ってきた計算グラフをそのまま、
次のtorch.autograd.grad()の呼び出しで使っています。
こうすると、この2回目の呼び出しは、2階の微分の計算グラフを返してきます。
それについてbackward()すれば、3階の微分係数を計算できます。
答えは6.と表示されています。これは$x^3$を3回$x$で微分すると定数$6$になるためです。

2個以上の変数を扱う場合

2階の偏微分係数

関数$f(x,y)=(x+2w)^3$について、まず、$\left.\frac{\partial^2f(x,y)}{\partial x^2}\right|_{x=4,y=3}$と

$\left. \frac{\partial^2 f(x,y)}{\partial x \partial y} \right|_{x=4, y=3}$を求めてみます。

>>> x = torch.tensor([4.], requires_grad=True)
>>> y = torch.tensor([3.], requires_grad=True)
>>> f = (x + 2 * y) ** 3
>>> g = torch.autograd.grad(f, x, create_graph=True)
>>> g
(tensor([300.], grad_fn=<ThMulBackward>),)
>>> g[0].backward()
>>> x.grad
tensor([60.])
>>> y.grad
tensor([120.])

$\left.\frac{\partial^2f(x,y)}{\partial x^2}\right|_{x=4,y=3} = 60$、そして

$\left. \frac{\partial^2 f(x,y)}{\partial x \partial y} \right|_{x=4, y=3} = 120$です。

では、次に、$\left.\frac{\partial^2f(x,y)}{\partial y \partial x}\right|_{x=4,y=3}$と

$\left. \frac{\partial^2 f(x,y)}{\partial y^2} \right|_{x=4, y=3}$を求めてみます。

>>> x = torch.tensor([4.], requires_grad=True)
>>> y = torch.tensor([3.], requires_grad=True)
>>> f = (x + 2 * y) ** 3
>>> g = torch.autograd.grad(f, y, create_graph=True)
>>> g
(tensor([600.], grad_fn=<MulBackward>),)
>>> g[0].backward()
>>> x.grad
tensor([120.])
>>> y.grad
tensor([240.])

当然ですが、$\left. \frac{\partial^2 f(x,y)}{\partial x \partial y} \right|_{x=4, y=3}$と

$\left.\frac{\partial^2f(x,y)}{\partial y \partial x}\right|_{x=4,y=3}$は、同じ値120になります。

上のふたつの作業をまとめて実行しようとすると・・・

>>> x = torch.tensor([4.], requires_grad=True)
>>> y = torch.tensor([3.], requires_grad=True)
>>> f = (x + 2 * y) ** 3
>>> g = torch.autograd.grad(f, (x, y), create_graph=True)
>>> g
(tensor([300.], grad_fn=<ThMulBackward>), tensor([600.], grad_fn=<MulBackward>))
>>> g[0].backward()
>>> x.grad
tensor([60.])
>>> y.grad
tensor([120.])
>>> g[1].backward()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/data10/masada/anaconda3/lib/python3.6/site-packages/torch/tensor.py", line 93, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/data10/masada/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
>>> g = torch.autograd.grad(f, (x, y), create_graph=True)
>>> g[1].backward()
>>> x.grad
tensor([180.])
>>> y.grad
tensor([360.])

と、エラーが出てしまいました。
それに関して偏微分を求めたい2つの変数を、(x,y)とtupleにして
torch.autograd.grad()に渡すと、

>>> g
(tensor([300.], grad_fn=<ThMulBackward>), tensor([600.], grad_fn=<MulBackward>))

という箇所にあるように、それぞれの変数に関する1階の微分係数をtupleにして返してくれます。
しかし、g[0].backward()のあとに続けてg[1].backward()を実行したため、エラーになりました。
上の例では

>>> g = torch.autograd.grad(f, (x, y), create_graph=True)

と、もう一度torch.autograd.grad()を呼んでいます。

(注:x.grad.zero_()そしてy.grad.zero_()と、ゼロに初期化する処理は書いていないため、
x.grady.gradも、g[0].backward()g[1].backward()とでの値を
足したものになってしまっています。)

しかし、上のように、torch.autograd.grad()が返したきた勾配の
別々の要素をそれぞれ微分するといった使い方はしないでしょう。

>>> x = torch.tensor([4.], requires_grad=True)
>>> y = torch.tensor([3.], requires_grad=True)
>>> f = (x + 2 * y) ** 3
>>> g = torch.autograd.grad(f, (x, y), create_graph=True)
>>> h = g[0] + g[1]
>>> h.backward()
>>> x.grad
tensor([180.])
>>> y.grad
tensor([360.])

このように、g[0]g[1]を含む計算グラフを作って、それについてbackward()を呼べば、
特に問題はありません。普通はこういう使い方をするでしょう。

torch.autograd.grad()は、複数の変数に関して偏微分をとった結果、
つまり勾配をtupleとして返してきます。
このtupleの要素を組み合わせて計算グラフをつくり、
それについてまた微分をとる、という使い方が普通だと思います。

例えば、WGANで使う勾配のL2ノルムの場合も、
torch.autograd.grad()が返してきたtupleの要素を組み合わせて、
計算グラフを作っていることになります。

おわりに

PyTorchの自動微分、便利です。

41
20
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
41
20

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?