LoginSignup
1
2

More than 3 years have passed since last update.

PyTorchでバッチノーマライズをやってみた(+注意点)

Last updated at Posted at 2021-01-17

はじめに

バッチノーマライズがよくわからなかったのでPyTorchでやってみた。
その結果、入力データについて列単位で平均0、分散1に揃えるものだと理解した。
また動かしてみて気が付いた注意点があるのでメモっておく。

やってみる

まずはインポート

import torch
import torch.nn.functional as F
from torch import nn

入力データサイズを決めて、適当に値を生成

input_samples = 100
input_features = 10 
x = torch.rand((input_samples,input_features)) * 10

あまりばらけてはいないが、各列平均と分散が異なるデータが生成される。

平均

torch.mean(x, 0)
tensor([5.0644, 5.0873, 5.0446, 5.3872, 5.2406, 5.3518, 5.3203, 4.9909, 5.0590,
        5.2169])

分散

torch.var(x, 0)
tensor([ 9.4876,  8.6519,  8.4050,  9.8280, 10.0146,  8.6054,  7.0800,  8.6111,
         7.7851,  8.5604])

バッチノーマライズをかけてみよう

batch_norm=nn.BatchNorm1d(input_features)
y = batch_norm(x)

平均

torch.mean(y, 0)
tensor([ 1.9073e-08,  5.2452e-08, -4.7684e-09,  3.8743e-08, -3.8147e-08,
         4.1723e-08, -7.8678e-08, -5.9605e-08,  5.7220e-08,  4.2915e-08],
       grad_fn=<MeanBackward1>)

⇒ ほぼ0だ。

分散

torch.var(y, 0)
tensor([1.0101, 1.0101, 1.0101, 1.0101, 1.0101, 1.0101, 1.0101, 1.0101, 1.0101,
        1.0101], grad_fn=<VarBackward1>)

⇒ ほぼ1だ。

注意点

入力を色々変えて動作させて気づいた注意点を列挙する。

入力データが1件の場合は計算エラーとなる

1件の場合は以下のエラーが出力される。

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 10])

そもそも1件の場合は、平均はそのデータ自身であり、分散は0になるため、計算する意味はない。

列のデータが全て同じ場合は、当然分散は0になる。

平均0にするために、列内の値は全て0になり、その結果当然ながら分散も1ではなく0になる。

入力データが少ない場合は分散が1にならない

データが3件の場合は1.5、10件の場合は1.111 などデータサイズが大きくなるにつれ1になる。詳しく掘り下げてはいないが、計算式によるものだと思われるのでドキュメントを見てほしい。

その他

入力直後にバッチノーマライズをやれば、入力データの正規化に使えるのでは?と思って調べたところ以下のQ&Aが見つかった

サンプル全体の平均と分散を一度計算する方が簡単で効率的とあるが、確かにその通りである。
確かにその通りだが面倒なんだモン!

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2