7
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchのUserWarning: Using a target size ... that is different to the input size ...の警告を侮ってはいけない

Last updated at Posted at 2021-02-07

#はじめに
掲題の警告を無視して、3時間を浪費してしまった話。

環境

  • PyTorch 1.7.0

#PyTorchの学習コード

以下のようなコードで回帰の予測モデルを作成。


for epoch in range(max_epoch):

    losses = []
    for batch in train_loader:

        # バッチサイズ分のサンプル抽出
        x, t = batch
        
        # パラメータの勾配を初期化
        optimizer.zero_grad()

        # 予測値の算出
        y = net(x)

        # 目的関数の値を算出
        #print(y)
        #print(t)
        loss = F.mse_loss(y, t)

        # 目的関数の値を表示
        #print('loss:', loss.item())
        losses.append(loss.item())
        # 各パラメータの値を算出
        loss.backward()

        # 勾配の情報を用いたパラメータの更新
        optimizer.step()

    avg_loss = torch.tensor(losses).mean()
    print('TRAIN_LOSS={0}, VAL_LOSS={1}'.format(avg_loss, calc_mse(val_loader)))

学習を進めたところ、loss = F.mse_loss(y, t) の箇所で以下の警告が出ていたが、学習は進むので無視して学習を進めたところ、全く精度があがらず、最終的にはこの警告で記載されていることが原因であった。

/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:12: UserWarning: Using a target size (torch.Size([12])) that is different to the input size (torch.Size([12, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  if sys.path[0] == '':

#精度が上がらなかった原因

以下のように成果値と予測値のTensorの形状が異なっていたため、mse_lossが正しく計算されていなかった。

正解値

形状:バッチサイズ x 1の行列

tensor([[27.0766],
        [19.2778],
        [27.9340],
        [21.0067],
        [15.3397],
        [20.8324],
        [16.6614],
        ...
        [15.0981],
        [22.3690],
        [12.4326]], grad_fn=<AddmmBackward>)

予測値

形状:ベクトル

tensor([21.2000, 22.7000, 19.9000, 21.8000, 18.5000, 16.1000, 10.9000, 21.0000,
        13.0000, 19.5000, 20.9000, 19.5000, 17.1000, 21.5000, 36.2000, 17.8000,
        ...
        22.9000, 14.5000, 14.9000, 26.4000, 24.4000, 14.1000, 34.9000, 11.8000,
        24.4000, 20.9000,  8.8000, 22.6000, 11.8000, 37.0000, 21.7000, 17.6000,
        11.7000, 19.6000, 19.4000, 36.0000])

#対策

以下のように unsqueezeを使って予測値の形状を、N x 1 の形に変更する。

y = y.unsqueeze(1)

#おわりに
警告は無視せずにしっかり読もう!という話。

7
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?