0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

PyTorch: mat1 and mat2 shapes cannot be multiplicatedの稀なパターン

Last updated at Posted at 2024-09-21

PyTorchの練習をしている際に出たエラーです。

コード

playground.py
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

# 中略:ここまででMockDatasetを定義している

# モデル定義とモックデータの準備
model = nn.Linear(1, 1)
data = MockDataset(torch.Tensor([1 for _ in range(100)]))

# 下準備:データローダーを作る
model.eval()
data = DataLoader(data, batch_size=64, shuffle=True)

# モデル評価
with torch.no_grad():
    for input, label in data:
        pred = model(input)
    #   ^^^^^^^^^^^^^^^^^^^ Error: mat1 and mat2...

エラー

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x64 and 1x1)

解決策

入力データを1ではなく、[1]とする。

playground.py
# data = MockDataset(torch.Tensor([1 for _ in range(100)]))
data = MockDataset(torch.Tensor([[1] for _ in range(100)]))
                            #    ^ ^

考察

PyTorchがパラメータを更新するには、配列同士の演算をする必要があります。元々の例では、inputの各要素はただのintであり、配列ではなかったため、エラーが出ていたと思われます。

参考

Qiitaの質問
配列同士の掛け算ができないという根源的な原因を教えてくれました。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?