terasima712
@terasima712 (ゆき 寺島)

Are you sure you want to delete the question?

If your question is resolved, you may close it.

Leaving a resolved question undeleted may help others!

We hope you find it useful!

ミニバッチ学習の際、NNの入力の数が合わない

解決したいこと

Nnでミニバッチ学習をさせようとしています。
こちら(https://free.kikagaku.ai/tutorial/basic_of_deep_learning/learn/pytorch_basic)
を参考に、入力の3変数から出力1変数を推定するNNを作成しています。
サンプル数は1036個です。

構築したモデル

import torch.nn as nn

# Networkをclassとして作成する.
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__() 

        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(3, 3)
        self.fc2 = nn.Linear(3, 1)

    def forward(self, x):
        x = x.reshape(1, -1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
import torch.optim as optim

device = torch.device("cpu")
net = Net()
net = net.to(device)
criterion = nn.MSELoss()

入力

teacher_X = (shape = (1036, 3), dtype= np.float32のnp.ndarray)
teacher_y = (shape = (1036, 1), dtype= np.float32のnp.ndarray)

dataset = torch.utils.data.TensorDataset(teacher_X, teacher_y)
#説明変数とラベルをくっつける

train, test = train_test_split(dataset, test_size=0.5)
#trainデータとtestデータに分ける

trainloader = torch.utils.data.DataLoader(train, batch_size = 2*37, shuffle = True)
testloader = torch.utils.data.DataLoader(test, batch_size = 2*37, shuffle = False)
#バッチサイズ74に分割する

for epoch in range(1):
    for batch in trainloader:
      inputs, labels = batch
      inputs, labels = inputs.to(device), labels.to(device) 
      optimizer.zero_grad() 
      outputs = net(inputs)
      output_list.append(outputs.item())
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

例)

---> 35         x = self.fc1(x)
---> 20       outputs = net(inputs)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x222 and 3x3)

1x222は74x3なので、バッチサイズが74で説明変数が3列なので入力サイズが222個になるのはわかります。それでinputsを一列にすると3x3にあわなくなるんだと思います。でも、参考にしたサイトではself.fc1 = nn.Linear(4, 4)としてバッチサイズ10でミニバッチ学習をしているのにうまくいっています。

for (inputs, labels) in trainloader:
    break
print(inputs)
print(labels)
---出力---
tensor([[-2.4840e-01, -4.5650e-02, -1.0503e+00],
        [ 8.9746e-01,  1.1928e+00,  2.3086e+00],
        [ 3.2453e-01,  6.6884e-01, -1.0383e+00],
        [-8.2133e-01, -1.1412e+00, -7.7657e-01],
        [-2.4840e-01,  9.7248e-02, -7.0233e-01],
        [-8.2133e-01, -1.2841e+00, -7.0156e-01],
        [ 8.9746e-01,  1.1452e+00, -1.0240e+00],
        [ 3.2453e-01,  8.5937e-01, -6.6869e-01],
        [-8.2133e-01, -7.1251e-01, -7.0272e-01],
        [ 2.6163e+00,  1.8597e+00, -6.9924e-01],
        [-2.4840e-01, -4.5650e-02, -6.8068e-01],
        [ 1.4704e+00,  1.3357e+00,  3.1652e-01],
        [-2.4840e-01,  9.7248e-02, -1.0287e+00],
        [-2.4840e-01,  1.4488e-01, -1.0256e+00],
        [-2.4840e-01,  9.7248e-02, -1.0503e+00],
        [ 2.0433e+00,  1.3833e+00, -7.0233e-01],
        [ 3.2453e-01,  8.5937e-01, -6.8802e-01],
        [ 3.2453e-01,  8.5937e-01, -1.0503e+00],
        [-2.4840e-01,  1.9823e-03, -1.0252e+00],
        [-8.2133e-01, -7.1251e-01, -6.9924e-01],
        [ 8.9746e-01,  1.0975e+00, -1.0383e+00],
        [ 2.6163e+00,  1.8597e+00, -6.9924e-01],
        [ 3.2453e-01,  6.6884e-01, -9.6990e-01],
        [-2.4840e-01,  1.4488e-01, -1.0472e+00],
        [-8.2133e-01, -7.1251e-01,  1.7461e-01],
        [ 3.2453e-01,  8.5937e-01, -6.9924e-01],
        [ 3.2453e-01,  8.5937e-01, -1.0383e+00],
        [-8.2133e-01, -7.1251e-01, -1.0240e+00],
        [-8.2133e-01, -1.2841e+00, -1.0306e+00],
        [-2.4840e-01,  1.4488e-01, -1.0368e+00],
        [-2.4840e-01, -4.5650e-02, -6.9924e-01],
        [-8.2133e-01, -1.2841e+00,  1.2808e+00],
        [-8.2133e-01, -7.1251e-01,  1.2546e+00],
        [-8.2133e-01, -8.0777e-01, -6.9924e-01],
        [ 1.4704e+00,  1.3357e+00, -1.0472e+00],
        [-2.4840e-01,  1.9251e-01, -1.0391e+00],
        [ 8.9746e-01,  1.1452e+00, -1.0383e+00],
        [-8.2133e-01, -9.9830e-01,  2.2367e+00],
        [ 2.6163e+00,  1.8597e+00, -1.0472e+00],
        [ 8.9746e-01,  1.0499e+00, -6.6947e-01],
        [-2.4840e-01,  1.4488e-01, -1.0356e+00],
        [ 8.9746e-01,  1.1452e+00,  2.0787e-01],
        [ 8.9746e-01,  1.2404e+00, -1.0472e+00],
        [-8.2133e-01, -7.1251e-01, -1.0364e+00],
        [ 2.6163e+00,  1.8597e+00, -1.0480e+00],
        [-8.2133e-01, -9.0304e-01, -6.9924e-01],
        [ 8.9746e-01,  1.1452e+00, -1.0472e+00],
        [-8.2133e-01, -1.0459e+00, -1.0279e+00],
        [ 8.9746e-01,  1.1928e+00, -1.0472e+00],
        [-8.2133e-01, -1.0459e+00, -6.9924e-01],
        [-2.4840e-01,  1.4488e-01, -1.0267e+00],
        [-8.2133e-01, -7.1251e-01, -7.0272e-01],
        [-2.4840e-01,  9.7248e-02, -1.0472e+00],
        [-2.4840e-01, -2.8381e-01, -7.0233e-01],
        [-2.4840e-01,  9.7248e-02, -1.0472e+00],
        [-2.4840e-01,  9.7248e-02, -1.0472e+00],
        [-2.4840e-01,  1.4488e-01,  2.4145e+00],
        [ 8.9746e-01,  1.0499e+00, -6.6947e-01],
        [-2.4840e-01, -2.8381e-01, -7.0078e-01],
        [ 3.2453e-01,  8.5937e-01, -1.0167e+00],
        [ 3.2453e-01,  8.5937e-01,  1.6319e+00],
        [-2.4840e-01,  9.7248e-02, -6.9924e-01],
        [-8.2133e-01, -7.1251e-01, -6.7913e-01],
        [-8.2133e-01, -1.0459e+00, -1.0418e+00],
        [ 8.9746e-01,  1.0499e+00,  5.6785e-01],
        [-8.2133e-01, -7.1251e-01, -1.0349e+00],
        [-2.4840e-01,  1.9823e-03, -6.7875e-01],
        [-8.2133e-01, -7.1251e-01, -7.0233e-01],
        [-8.2133e-01, -9.9830e-01, -1.0264e+00],
        [ 1.4704e+00,  1.3357e+00, -1.0472e+00],
        [-2.4840e-01, -2.3618e-01, -6.7875e-01],
        [-2.4840e-01,  1.9823e-03,  2.2565e-01],
        [ 2.6163e+00,  1.8597e+00, -1.0472e+00],
        [ 3.2453e-01,  6.6884e-01,  4.6074e-01]])
tensor([-0.1082,  0.0319,  0.0195,  0.0483,  0.2727, -0.0772,  0.0319,  0.0294,
         0.0039,  0.0032, -0.4534,  0.0392,  0.0242, -0.0643,  0.1256,  0.0039,
        -0.0290,  0.0392,  0.2644,  0.0690,  0.0232, -0.0290,  0.0394,  0.1216,
         0.0429, -0.0664, -0.0841,  0.0420, -0.0397, -0.0127,  0.0039,  0.0036,
         0.3039,  0.0319,  0.1322, -0.0994,  0.0392,  0.0015,  0.0195,  0.0561,
        -0.0546, -0.0596,  0.0472, -0.0190, -0.0119,  0.0733,  0.0008,  0.2031,
         0.1742, -0.0218,  0.3039, -0.1418, -0.0641,  0.0671,  0.4332,  0.0111,
         0.3039,  0.1665,  0.0903,  0.0294, -0.0290, -0.0218,  0.2644,  0.0653,
         0.0251, -0.0397, -0.0982,  0.0227, -0.1652,  0.3822, -0.0301, -0.0841,
         0.0733,  0.0394])

どうかお力添えいただけるとありがたいです。

0

1Answer

x = x.reshape(x.shape[0], -1)にしたら直ると思います.

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x222 and 3x3)

と言うエラーに対して

1x222は74x3なので、バッチサイズが74で説明変数が3列なので入力サイズが222個になるのはわかります。それでinputsを一列にすると3x3にあわなくなるんだと思います。

で見るべきは,1x222です.これは,x.reshape(1, -1)の結果でこうなっています.詳しくはreshapeの仕様を確認してください.

さて,この次のレイヤとして宣言したself.fc1 = nn.Linear(3, 3)3x3の行列を持つので,222に対して演算(ここではmultiplied)できる列がなくエラーになります.

参考にしたサイトではself.fc1 = nn.Linear(4, 4)としてバッチサイズ10でミニバッチ学習をしているのにうまくいっています。

バッチサイズはnn.Linearの演算に直接関係するべきではありません.なので,これは当然の話です.

バッチサイズx.shape[0]に相当する箇所をいじるようなx.reshape(1, -1)という操作をしてしまったため起きているのが現状です.

2Like

Comments

  1. PyTorchはあまりにも自由度が高く,機械学習の原理から深く学んでないのであればこういうところで挫折しかねないので,より抽象度を上げたKerasから入ることを推奨しておきます.

Your answer might help someone💌