はじめに
この記事ではPyTorchを使って高階の偏微分を求める方法を説明しています。
高階ではない微分係数
例えば、関数$f(x)=x^3$について、$\left. \frac{df(x)}{dx} \right|_{x=4}$を求めたいとします。
>>> import torch
>>> x = torch.tensor([4.], requires_grad=True)
>>> f = x ** 3
>>> f.backward()
>>> print(x.grad)
tensor([48.])
backward()
を呼べばいいだけで、これは皆さん普通に実行していることです。
(以下、import torch
は省略します。)
1つの変数だけを扱う場合
2階の微分係数
先ほどの関数$f(x)=x^3$について、$x=4$における2階の微分係数$\left. \frac{d^2f(x)}{dx^2} \right|_{x=4}$を求めたいとします。
>>> x = torch.tensor([4.], requires_grad=True)
>>> f = x ** 3
>>> g = torch.autograd.grad(f, x, create_graph=True)
>>> g
(tensor([48.], grad_fn=<ThMulBackward>),)
>>> g.backward()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: 'tuple' object has no attribute 'backward'
>>> g[0].backward()
>>> x.grad
tensor([24.])
torch.autograd.grad()
を呼ぶときにcreate_graph=True
としているのがポイントです。
こうすると、微分係数(上の場合は48)だけでなく、
$f$の$x$に関する微分について計算グラフを作って、それも返してくれます。
すると、その計算グラフを使うことで、2階の微分係数が計算できるようになります。
上の例で、g
と入力したとき
(tensor([48.], grad_fn=<ThMulBackward>),)
と表示されています。
48は1階の微分係数です。create_graph=True
と設定しなかったら、これしか返してきません。実際、
>>> x = torch.tensor([4.], requires_grad=True)
>>> f = x ** 3
>>> g = torch.autograd.grad(f, x)
>>> g
(tensor([48.]),)
となります。先ほどの例にあるgrad_fn=<ThMulBackward>
というあやしい(?)表示が、
計算グラフも一緒に作ってくれている証拠です。
また、わざとエラーを出してみました
今は、$x$というひとつの変数についてしか微分していないのですが、
torch.autograd.grad()
では、複数の変数の各々で偏微分する状況を扱うのが基本です。
そのため、返って来るのがいつでもtupleなんですね。
ひとつの変数しか考慮していなくても、要素数がひとつだけのtupleが帰ってきます。
そのため、
>>> g[0].backward()
と、tupleの最初の要素を使う旨、添え字0
で指定しないと、エラーになります。
3階の微分係数
先ほどの関数$f(x)=x^3$について、$x=4$における3階の微分係数$\left. \frac{d^3f(x)}{dx^3} \right|_{x=4}$を求めたいとします。
>>> x = torch.tensor([4.], requires_grad=True)
>>> f = x ** 3
>>> g = torch.autograd.grad(f, x, create_graph=True)
>>> h = torch.autograd.grad(g, x, create_graph=True)
>>> h
(tensor([24.], grad_fn=<ThMulBackward>),)
>>> h[0].backward()
>>> x.grad
tensor([6.])
最初にtorch.autograd.grad()
を呼んで戻ってきた計算グラフをそのまま、
次のtorch.autograd.grad()
の呼び出しで使っています。
こうすると、この2回目の呼び出しは、2階の微分の計算グラフを返してきます。
それについてbackward()
すれば、3階の微分係数を計算できます。
答えは6.
と表示されています。これは$x^3$を3回$x$で微分すると定数$6$になるためです。
2個以上の変数を扱う場合
2階の偏微分係数
関数$f(x,y)=(x+2w)^3$について、まず、$\left.\frac{\partial^2f(x,y)}{\partial x^2}\right|_{x=4,y=3}$と
$\left. \frac{\partial^2 f(x,y)}{\partial x \partial y} \right|_{x=4, y=3}$を求めてみます。
>>> x = torch.tensor([4.], requires_grad=True)
>>> y = torch.tensor([3.], requires_grad=True)
>>> f = (x + 2 * y) ** 3
>>> g = torch.autograd.grad(f, x, create_graph=True)
>>> g
(tensor([300.], grad_fn=<ThMulBackward>),)
>>> g[0].backward()
>>> x.grad
tensor([60.])
>>> y.grad
tensor([120.])
$\left.\frac{\partial^2f(x,y)}{\partial x^2}\right|_{x=4,y=3} = 60$、そして
$\left. \frac{\partial^2 f(x,y)}{\partial x \partial y} \right|_{x=4, y=3} = 120$です。
では、次に、$\left.\frac{\partial^2f(x,y)}{\partial y \partial x}\right|_{x=4,y=3}$と
$\left. \frac{\partial^2 f(x,y)}{\partial y^2} \right|_{x=4, y=3}$を求めてみます。
>>> x = torch.tensor([4.], requires_grad=True)
>>> y = torch.tensor([3.], requires_grad=True)
>>> f = (x + 2 * y) ** 3
>>> g = torch.autograd.grad(f, y, create_graph=True)
>>> g
(tensor([600.], grad_fn=<MulBackward>),)
>>> g[0].backward()
>>> x.grad
tensor([120.])
>>> y.grad
tensor([240.])
当然ですが、$\left. \frac{\partial^2 f(x,y)}{\partial x \partial y} \right|_{x=4, y=3}$と
$\left.\frac{\partial^2f(x,y)}{\partial y \partial x}\right|_{x=4,y=3}$は、同じ値120になります。
上のふたつの作業をまとめて実行しようとすると・・・
>>> x = torch.tensor([4.], requires_grad=True)
>>> y = torch.tensor([3.], requires_grad=True)
>>> f = (x + 2 * y) ** 3
>>> g = torch.autograd.grad(f, (x, y), create_graph=True)
>>> g
(tensor([300.], grad_fn=<ThMulBackward>), tensor([600.], grad_fn=<MulBackward>))
>>> g[0].backward()
>>> x.grad
tensor([60.])
>>> y.grad
tensor([120.])
>>> g[1].backward()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/data10/masada/anaconda3/lib/python3.6/site-packages/torch/tensor.py", line 93, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/data10/masada/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
>>> g = torch.autograd.grad(f, (x, y), create_graph=True)
>>> g[1].backward()
>>> x.grad
tensor([180.])
>>> y.grad
tensor([360.])
と、エラーが出てしまいました。
それに関して偏微分を求めたい2つの変数を、(x,y)
とtupleにして
torch.autograd.grad()
に渡すと、
>>> g
(tensor([300.], grad_fn=<ThMulBackward>), tensor([600.], grad_fn=<MulBackward>))
という箇所にあるように、それぞれの変数に関する1階の微分係数をtupleにして返してくれます。
しかし、g[0].backward()
のあとに続けてg[1].backward()
を実行したため、エラーになりました。
上の例では
>>> g = torch.autograd.grad(f, (x, y), create_graph=True)
と、もう一度torch.autograd.grad()
を呼んでいます。
(注:x.grad.zero_()
そしてy.grad.zero_()
と、ゼロに初期化する処理は書いていないため、
x.grad
もy.grad
も、g[0].backward()
とg[1].backward()
とでの値を
足したものになってしまっています。)
しかし、上のように、torch.autograd.grad()
が返したきた勾配の
別々の要素をそれぞれ微分するといった使い方はしないでしょう。
>>> x = torch.tensor([4.], requires_grad=True)
>>> y = torch.tensor([3.], requires_grad=True)
>>> f = (x + 2 * y) ** 3
>>> g = torch.autograd.grad(f, (x, y), create_graph=True)
>>> h = g[0] + g[1]
>>> h.backward()
>>> x.grad
tensor([180.])
>>> y.grad
tensor([360.])
このように、g[0]
とg[1]
を含む計算グラフを作って、それについてbackward()
を呼べば、
特に問題はありません。普通はこういう使い方をするでしょう。
torch.autograd.grad()
は、複数の変数に関して偏微分をとった結果、
つまり勾配をtupleとして返してきます。
このtupleの要素を組み合わせて計算グラフをつくり、
それについてまた微分をとる、という使い方が普通だと思います。
例えば、WGANで使う勾配のL2ノルムの場合も、
torch.autograd.grad()
が返してきたtupleの要素を組み合わせて、
計算グラフを作っていることになります。
おわりに
PyTorchの自動微分、便利です。