0
0

RuntimeError: mat1 and mat2 must have the same dtype, but got ComplexFloat and Float の直し方

Last updated at Posted at 2024-02-12

はじめに

pytorchで複素ニューラルネットワークの実装をしたときのエラーとその対策のメモ書きです.
pytorchでも複素数使えるのでみんな複素ニューラルネットワークやろう!

環境

  • Python 3.10.9
  • torch==2.1.2

具体例

x = torch.randn(3, 2)
x_complex = torch.view_as_complex(x)

l1 = nn.Linear(3,3)
y = l1(x_complex)

# 以下のエラーが発生.実際はもっと長く,最後の部分だけ抜粋した.
# File hoge\venv\lib\site-packages\torch\nn\modules\linear.py:114, in # # Linear.forward(self, input)
#     113 def forward(self, input: Tensor) -> Tensor:
# --> 114     return F.linear(input, self.weight, self.bias)
# 
# RuntimeError: mat1 and mat2 must have the same dtype, but got ComplexFloat and Float

原因

nn.linear は暗黙のうちに実数入力を想定しているため.
mat1, mat2 は nn.linear の入出力の変数のことを表していると思われる.

対策

nn.linear の型を明示的に変更しておけばよい.

x = torch.randn(3, 2)
x_complex = torch.view_as_complex(x)

l1 = nn.Linear(3,3,dtype=torch.complex64)
y = l1(x_complex)

参考文献

  1. https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0