はじめに
pytorchで複素ニューラルネットワークの実装をしたときのエラーとその対策のメモ書きです.
pytorchでも複素数使えるのでみんな複素ニューラルネットワークやろう!
環境
- Python 3.10.9
- torch==2.1.2
具体例
x = torch.randn(3, 2)
x_complex = torch.view_as_complex(x)
l1 = nn.Linear(3,3)
y = l1(x_complex)
# 以下のエラーが発生.実際はもっと長く,最後の部分だけ抜粋した.
# File hoge\venv\lib\site-packages\torch\nn\modules\linear.py:114, in # # Linear.forward(self, input)
# 113 def forward(self, input: Tensor) -> Tensor:
# --> 114 return F.linear(input, self.weight, self.bias)
#
# RuntimeError: mat1 and mat2 must have the same dtype, but got ComplexFloat and Float
原因
nn.linear
は暗黙のうちに実数入力を想定しているため.
mat1, mat2 は nn.linear
の入出力の変数のことを表していると思われる.
対策
nn.linear
の型を明示的に変更しておけばよい.
x = torch.randn(3, 2)
x_complex = torch.view_as_complex(x)
l1 = nn.Linear(3,3,dtype=torch.complex64)
y = l1(x_complex)