0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

pytorch における複素の重みの勾配は複素共役を保持している

Last updated at Posted at 2025-06-16

はじめに

pytorch において,複素ニューラルネットワークの重みの更新は以下の式で行われます [1].
$$
w \leftarrow w - \frac{\partial L}{\partial w^*}
$$

このとき,保持される重みパラメータの勾配は複素共役をつけているのか気になり調査しました.

検証

一層の全結合層を複素数の重みで定義します.計算しやすいように重みパラメータは1にしておきます.

import torch
import torch.nn as nn
from torchcvnn.nn import ComplexMSELoss

net = nn.Linear(1,1, dtype=torch.complex128, bias=False)
nn.init.ones_(net.weight)
x = torch.tensor([1+1j], dtype=torch.complex128)
y = torch.tensor([1], dtype=torch.complex128)
loss_fn = ComplexMSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=1)

y_hat = net(x)
loss = loss_fn(y, y_hat)
y_hat.retain_grad()
loss.retain_grad()
loss.backward()
optimizer.step()

print(loss) # torch.tensor([1.+1.j], dtype=torch.float128, grad_fn=<...>)
print(loss.grad) # torch.tensor(1., dtype=torch.float64, grad_fn=<MeanBackward0>)
print(y_hat.grad) # torch.tensor([0.+2.j], dtype=torch.float128, grad_fn=<...>)

for p in net.parameters():
    print(p.grad)
    # tensor([[2.+2.j]], dtype=torch.complex128)

以下の式で順伝播が計算されます.
$$
\hat{y} = w x
$$

$$
L = |y - \hat{y}|^2
$$
今回の例では
$$
\hat{y} = 1 \times (1+1\mathrm{j}) = 1+1\mathrm{j}
$$

$$
L = |1 - (1+1\mathrm{j})|^2 = 1
$$

一方,逆伝播は文献 [2]を参考すると,関数 $s=f(z) $ について,
$$
\frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial s}\frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*}\frac{\partial s^*}{\partial z^*}
$$
特に,損失関数のように $f:\mathbb{C}\to\mathbb{R}$ の場合には $s=s^*$ なので
$$
\frac{\partial L}{\partial z^*} = 2\frac{\partial L}{\partial s}\frac{\partial s}{\partial z^*}
$$
となります.これを今回の場合に適用すると,$s=|y-\hat{y}|^2 = L$ であるので,
$$
\begin{align}
\frac{\partial L}{\partial \hat{y}^*} &= 2 \times 1 \times \frac{\partial s}{\partial \hat{y}^*} \\
&= 2 \times \frac{\partial }{\partial \hat{y}^*} (yy^* - y\hat{y}^* - y^*\hat{y} + \hat{y}\hat{y}^*) \\
&= 2 \times (-y+\hat{y}) \\
&= 2 \times (-1 + (1+1\mathrm{j})) \\
&= 2\mathrm{j}
\end{align}
$$
さらに,
$$
\begin{align}
\frac{\partial L}{\partial \hat{w}^*} &= \frac{\partial L}{\partial \hat{y}^*}\frac{\partial \hat{y}^*}{\partial \hat{w}^*} \\
&= 2\mathrm{j} \times x^* \\
&= 2\mathrm{j} \times (1 + 1\mathrm{j})^* \\
&= 2+2\mathrm{j}
\end{align}
$$

pytorch の結果と比較すると,y_hat.grad=0.+2.jw.grad=2.+2.j であるため,pytorch が保持しているのは勾配そのものではなく,勾配の複素共役です.

結論

pytorch の重みパラメータの勾配 w.grad は重みパラメータの勾配そのものではなく,その複素共役です.

参考文献

Appendix

$\frac{\partial L}{\partial \hat{y}^*}$ を定義通り計算すると,pytorch と異なる結果が得られます.具体的には,
$$
\begin{align}
\frac{\partial L}{\partial \hat{y}^*} &= \frac{\partial}{\partial \hat{y}^*}|y-\hat{y}|^2 \\
&=\frac{\partial}{\partial \hat{y}^*} (yy^* - y\hat{y}^* - y^*\hat{y} + \hat{y}\hat{y}^*) \\
&= \hat{y} - y
\end{align}
$$
となり,pytorch 実装の半分の値になります.おそらくこれは損失関数が $\mathbb{C}\to\mathbb{R}$になっている,つまり,$s$ と $s^*$ の独立性が満たされないにも関わらず,独立である前提で pytorch 実装されているのが原因かと思います.実際のところ,最適化にこれが悪影響を及ぼす可能性は少ないのでしょう.そもそも重み更新は学習率をかけた値で更新するので,2倍程度の勾配の違いはたかが知れているからです.

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?