ミニバッチ学習の際、NNの入力の数が合わない
解決したいこと
Nnでミニバッチ学習をさせようとしています。
こちら(https://free.kikagaku.ai/tutorial/basic_of_deep_learning/learn/pytorch_basic)
を参考に、入力の3変数から出力1変数を推定するNNを作成しています。
サンプル数は1036個です。
構築したモデル
import torch.nn as nn
# Networkをclassとして作成する.
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.relu = nn.ReLU()
self.fc1 = nn.Linear(3, 3)
self.fc2 = nn.Linear(3, 1)
def forward(self, x):
x = x.reshape(1, -1)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
import torch.optim as optim
device = torch.device("cpu")
net = Net()
net = net.to(device)
criterion = nn.MSELoss()
入力
teacher_X = (shape = (1036, 3), dtype= np.float32のnp.ndarray)
teacher_y = (shape = (1036, 1), dtype= np.float32のnp.ndarray)
dataset = torch.utils.data.TensorDataset(teacher_X, teacher_y)
#説明変数とラベルをくっつける
train, test = train_test_split(dataset, test_size=0.5)
#trainデータとtestデータに分ける
trainloader = torch.utils.data.DataLoader(train, batch_size = 2*37, shuffle = True)
testloader = torch.utils.data.DataLoader(test, batch_size = 2*37, shuffle = False)
#バッチサイズ74に分割する
for epoch in range(1):
for batch in trainloader:
inputs, labels = batch
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = net(inputs)
output_list.append(outputs.item())
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
例)
---> 35 x = self.fc1(x)
---> 20 outputs = net(inputs)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x222 and 3x3)
1x222は74x3なので、バッチサイズが74で説明変数が3列なので入力サイズが222個になるのはわかります。それでinputsを一列にすると3x3にあわなくなるんだと思います。でも、参考にしたサイトではself.fc1 = nn.Linear(4, 4)としてバッチサイズ10でミニバッチ学習をしているのにうまくいっています。
for (inputs, labels) in trainloader:
break
print(inputs)
print(labels)
---出力---
tensor([[-2.4840e-01, -4.5650e-02, -1.0503e+00],
[ 8.9746e-01, 1.1928e+00, 2.3086e+00],
[ 3.2453e-01, 6.6884e-01, -1.0383e+00],
[-8.2133e-01, -1.1412e+00, -7.7657e-01],
[-2.4840e-01, 9.7248e-02, -7.0233e-01],
[-8.2133e-01, -1.2841e+00, -7.0156e-01],
[ 8.9746e-01, 1.1452e+00, -1.0240e+00],
[ 3.2453e-01, 8.5937e-01, -6.6869e-01],
[-8.2133e-01, -7.1251e-01, -7.0272e-01],
[ 2.6163e+00, 1.8597e+00, -6.9924e-01],
[-2.4840e-01, -4.5650e-02, -6.8068e-01],
[ 1.4704e+00, 1.3357e+00, 3.1652e-01],
[-2.4840e-01, 9.7248e-02, -1.0287e+00],
[-2.4840e-01, 1.4488e-01, -1.0256e+00],
[-2.4840e-01, 9.7248e-02, -1.0503e+00],
[ 2.0433e+00, 1.3833e+00, -7.0233e-01],
[ 3.2453e-01, 8.5937e-01, -6.8802e-01],
[ 3.2453e-01, 8.5937e-01, -1.0503e+00],
[-2.4840e-01, 1.9823e-03, -1.0252e+00],
[-8.2133e-01, -7.1251e-01, -6.9924e-01],
[ 8.9746e-01, 1.0975e+00, -1.0383e+00],
[ 2.6163e+00, 1.8597e+00, -6.9924e-01],
[ 3.2453e-01, 6.6884e-01, -9.6990e-01],
[-2.4840e-01, 1.4488e-01, -1.0472e+00],
[-8.2133e-01, -7.1251e-01, 1.7461e-01],
[ 3.2453e-01, 8.5937e-01, -6.9924e-01],
[ 3.2453e-01, 8.5937e-01, -1.0383e+00],
[-8.2133e-01, -7.1251e-01, -1.0240e+00],
[-8.2133e-01, -1.2841e+00, -1.0306e+00],
[-2.4840e-01, 1.4488e-01, -1.0368e+00],
[-2.4840e-01, -4.5650e-02, -6.9924e-01],
[-8.2133e-01, -1.2841e+00, 1.2808e+00],
[-8.2133e-01, -7.1251e-01, 1.2546e+00],
[-8.2133e-01, -8.0777e-01, -6.9924e-01],
[ 1.4704e+00, 1.3357e+00, -1.0472e+00],
[-2.4840e-01, 1.9251e-01, -1.0391e+00],
[ 8.9746e-01, 1.1452e+00, -1.0383e+00],
[-8.2133e-01, -9.9830e-01, 2.2367e+00],
[ 2.6163e+00, 1.8597e+00, -1.0472e+00],
[ 8.9746e-01, 1.0499e+00, -6.6947e-01],
[-2.4840e-01, 1.4488e-01, -1.0356e+00],
[ 8.9746e-01, 1.1452e+00, 2.0787e-01],
[ 8.9746e-01, 1.2404e+00, -1.0472e+00],
[-8.2133e-01, -7.1251e-01, -1.0364e+00],
[ 2.6163e+00, 1.8597e+00, -1.0480e+00],
[-8.2133e-01, -9.0304e-01, -6.9924e-01],
[ 8.9746e-01, 1.1452e+00, -1.0472e+00],
[-8.2133e-01, -1.0459e+00, -1.0279e+00],
[ 8.9746e-01, 1.1928e+00, -1.0472e+00],
[-8.2133e-01, -1.0459e+00, -6.9924e-01],
[-2.4840e-01, 1.4488e-01, -1.0267e+00],
[-8.2133e-01, -7.1251e-01, -7.0272e-01],
[-2.4840e-01, 9.7248e-02, -1.0472e+00],
[-2.4840e-01, -2.8381e-01, -7.0233e-01],
[-2.4840e-01, 9.7248e-02, -1.0472e+00],
[-2.4840e-01, 9.7248e-02, -1.0472e+00],
[-2.4840e-01, 1.4488e-01, 2.4145e+00],
[ 8.9746e-01, 1.0499e+00, -6.6947e-01],
[-2.4840e-01, -2.8381e-01, -7.0078e-01],
[ 3.2453e-01, 8.5937e-01, -1.0167e+00],
[ 3.2453e-01, 8.5937e-01, 1.6319e+00],
[-2.4840e-01, 9.7248e-02, -6.9924e-01],
[-8.2133e-01, -7.1251e-01, -6.7913e-01],
[-8.2133e-01, -1.0459e+00, -1.0418e+00],
[ 8.9746e-01, 1.0499e+00, 5.6785e-01],
[-8.2133e-01, -7.1251e-01, -1.0349e+00],
[-2.4840e-01, 1.9823e-03, -6.7875e-01],
[-8.2133e-01, -7.1251e-01, -7.0233e-01],
[-8.2133e-01, -9.9830e-01, -1.0264e+00],
[ 1.4704e+00, 1.3357e+00, -1.0472e+00],
[-2.4840e-01, -2.3618e-01, -6.7875e-01],
[-2.4840e-01, 1.9823e-03, 2.2565e-01],
[ 2.6163e+00, 1.8597e+00, -1.0472e+00],
[ 3.2453e-01, 6.6884e-01, 4.6074e-01]])
tensor([-0.1082, 0.0319, 0.0195, 0.0483, 0.2727, -0.0772, 0.0319, 0.0294,
0.0039, 0.0032, -0.4534, 0.0392, 0.0242, -0.0643, 0.1256, 0.0039,
-0.0290, 0.0392, 0.2644, 0.0690, 0.0232, -0.0290, 0.0394, 0.1216,
0.0429, -0.0664, -0.0841, 0.0420, -0.0397, -0.0127, 0.0039, 0.0036,
0.3039, 0.0319, 0.1322, -0.0994, 0.0392, 0.0015, 0.0195, 0.0561,
-0.0546, -0.0596, 0.0472, -0.0190, -0.0119, 0.0733, 0.0008, 0.2031,
0.1742, -0.0218, 0.3039, -0.1418, -0.0641, 0.0671, 0.4332, 0.0111,
0.3039, 0.1665, 0.0903, 0.0294, -0.0290, -0.0218, 0.2644, 0.0653,
0.0251, -0.0397, -0.0982, 0.0227, -0.1652, 0.3822, -0.0301, -0.0841,
0.0733, 0.0394])
どうかお力添えいただけるとありがたいです。