はじめに
pytorch において,複素ニューラルネットワークの重みの更新は以下の式で行われます [1].
$$
w \leftarrow w - \frac{\partial L}{\partial w^*}
$$
このとき,保持される重みパラメータの勾配は複素共役をつけているのか気になり調査しました.
検証
一層の全結合層を複素数の重みで定義します.計算しやすいように重みパラメータは1にしておきます.
import torch
import torch.nn as nn
from torchcvnn.nn import ComplexMSELoss
net = nn.Linear(1,1, dtype=torch.complex128, bias=False)
nn.init.ones_(net.weight)
x = torch.tensor([1+1j], dtype=torch.complex128)
y = torch.tensor([1], dtype=torch.complex128)
loss_fn = ComplexMSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=1)
y_hat = net(x)
loss = loss_fn(y, y_hat)
y_hat.retain_grad()
loss.retain_grad()
loss.backward()
optimizer.step()
print(loss) # torch.tensor([1.+1.j], dtype=torch.float128, grad_fn=<...>)
print(loss.grad) # torch.tensor(1., dtype=torch.float64, grad_fn=<MeanBackward0>)
print(y_hat.grad) # torch.tensor([0.+2.j], dtype=torch.float128, grad_fn=<...>)
for p in net.parameters():
print(p.grad)
# tensor([[2.+2.j]], dtype=torch.complex128)
以下の式で順伝播が計算されます.
$$
\hat{y} = w x
$$
$$
L = |y - \hat{y}|^2
$$
今回の例では
$$
\hat{y} = 1 \times (1+1\mathrm{j}) = 1+1\mathrm{j}
$$
$$
L = |1 - (1+1\mathrm{j})|^2 = 1
$$
一方,逆伝播は文献 [2]を参考すると,関数 $s=f(z) $ について,
$$
\frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial s}\frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*}\frac{\partial s^*}{\partial z^*}
$$
特に,損失関数のように $f:\mathbb{C}\to\mathbb{R}$ の場合には $s=s^*$ なので
$$
\frac{\partial L}{\partial z^*} = 2\frac{\partial L}{\partial s}\frac{\partial s}{\partial z^*}
$$
となります.これを今回の場合に適用すると,$s=|y-\hat{y}|^2 = L$ であるので,
$$
\begin{align}
\frac{\partial L}{\partial \hat{y}^*} &= 2 \times 1 \times \frac{\partial s}{\partial \hat{y}^*} \\
&= 2 \times \frac{\partial }{\partial \hat{y}^*} (yy^* - y\hat{y}^* - y^*\hat{y} + \hat{y}\hat{y}^*) \\
&= 2 \times (-y+\hat{y}) \\
&= 2 \times (-1 + (1+1\mathrm{j})) \\
&= 2\mathrm{j}
\end{align}
$$
さらに,
$$
\begin{align}
\frac{\partial L}{\partial \hat{w}^*} &= \frac{\partial L}{\partial \hat{y}^*}\frac{\partial \hat{y}^*}{\partial \hat{w}^*} \\
&= 2\mathrm{j} \times x^* \\
&= 2\mathrm{j} \times (1 + 1\mathrm{j})^* \\
&= 2+2\mathrm{j}
\end{align}
$$
pytorch の結果と比較すると,y_hat.grad=0.+2.j
,w.grad=2.+2.j
であるため,pytorch が保持しているのは勾配そのものではなく,勾配の複素共役です.
結論
pytorch の重みパラメータの勾配 w.grad
は重みパラメータの勾配そのものではなく,その複素共役です.
参考文献
- [1] Autograd for Complex Numbers https://docs.pytorch.org/docs/stable/notes/autograd.html#autograd-for-complex-numbers
- [2] What about cross-domain functions? https://docs.pytorch.org/docs/stable/notes/autograd.html#what-about-cross-domain-functions
Appendix
$\frac{\partial L}{\partial \hat{y}^*}$ を定義通り計算すると,pytorch と異なる結果が得られます.具体的には,
$$
\begin{align}
\frac{\partial L}{\partial \hat{y}^*} &= \frac{\partial}{\partial \hat{y}^*}|y-\hat{y}|^2 \\
&=\frac{\partial}{\partial \hat{y}^*} (yy^* - y\hat{y}^* - y^*\hat{y} + \hat{y}\hat{y}^*) \\
&= \hat{y} - y
\end{align}
$$
となり,pytorch 実装の半分の値になります.おそらくこれは損失関数が $\mathbb{C}\to\mathbb{R}$になっている,つまり,$s$ と $s^*$ の独立性が満たされないにも関わらず,独立である前提で pytorch 実装されているのが原因かと思います.実際のところ,最適化にこれが悪影響を及ぼす可能性は少ないのでしょう.そもそも重み更新は学習率をかけた値で更新するので,2倍程度の勾配の違いはたかが知れているからです.