1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

PyTorchで、BatchNorm・LayerNormの出力が自前で計算した値と一致するかを確認する

Last updated at Posted at 2024-06-14

はじめに

PyTorchで、BatchNormとLayerNormの出力が自前で計算した値と一致するかを確認してみました。

比較の結果、いずれも1e-7のオーダーで一致しました。また、BatchNormに関しては、1回目と2回目の反復実行でも移動平均と分散が自前の計算値と完全に一致することが確認できました。

  • BatchNormの分散の算出時は、Nで割る形、移動分散の算出時は、N-1で割る形(不偏分散)。
  • LayerNormの分散の算出時は、Nで割る形。

と分かりました。

以降、計算の詳細を記載します。

PyTorch BatchNorm・LayerNormの動作を確認する

BatchNorm・LayerNormの出力値が、自前で計算した値と一致するか確認。

結果: BatchNorm

  • 比較結果、自前で計算した値と1e-7のオーダーで一致。float型の計算精度は約7桁の有効数字なので、良好。
  • 1回目、2回目、反復して実行した時も、移動平均・分散は、自前で計算した値と完全に一致、良好。

結果: LayerNorm

  • 比較結果、自前で計算した値と1e-7のオーダーで一致。float型の計算精度は約7桁の有効数字なので、良好。

Libs & Paths

import numpy as np
import torch
import torch.nn as nn
!pip show torch
Name: torch
Version: 2.2.2
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: D:\...\WPy64-31230\python-3.12.3.amd64\Lib\site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: accelerate, botorch, kornia, lightning, linear-operator, pyro-ppl, pytorch-lightning, torchaudio, torchmetrics, torchvision

BatchNormの動作を確認する

BatchNorm出力値 (1回目)

torch.nn.BatchNorm2dの計算

bn1 = nn.BatchNorm2d(3, affine=False)
# affine=False: スケーリングとバイアスは使用せず、標準のレイヤー正規化を実施します。

# 入力
x = torch.randn((5, 3, 2, 2)) # n_batch x ch x H x W
# 出力
y = bn1(x)
y = y.detach() # 勾配計算なしへ変換
# 一部を見る
print(y.shape)
print(y[0, 0])
torch.Size([5, 3, 2, 2])
tensor([[ 0.3026, -0.5926],
        [-0.5748, -0.7544]])

自前で計算

# 自前で計算
N = x.size(0) * x.size(2) * x.size(3)
x_ave = x.sum(dim=(0,2,3), keepdim=True) / N
x_var = ((x - x_ave)**2).sum(dim=(0,2,3), keepdim=True) / N
eps = bn1.eps
y2 = (x - x_ave) / (torch.sqrt(x_var + eps))

# 一部を見る
print(y2[0, 0])
tensor([[ 0.3026, -0.5926],
        [-0.5748, -0.7544]])
# 比較
print((y2 - y)[0, 0])
print((y2 - y).abs().max())
tensor([[0.0000e+00, 1.1921e-07],
        [5.9605e-08, 1.1921e-07]])
tensor(2.3842e-07)

比較結果、1e-7のオーダーで一致する。float型の計算精度は約7桁の有効数字なので、良好。

生成AIへ質問: pytorchのfloat型の計算精度はどれぐらい?
PyTorchで使用されるfloat型(具体的にはtorch.float32)の計算精度は、32ビット浮動小数点精度(単精度)です。この形式では、約7桁の有効数字が保持されます。これは、数値計算において非常に広く使用されている標準的な精度です。

移動平均・分散 (1回目)

torch.nn.BatchNorm2dの計算

bn1.running_var
tensor([0.9739, 1.0061, 1.0119])

自前で計算

# 初期値
x_var_running_prev = torch.ones(3)

# 自前で計算
x_var2 = ((x - x_ave)**2).sum(dim=(0,2,3), keepdim=True) / (N - 1) # 不偏分散
# x_var2 = torch.var(x, dim=(0,2,3), keepdim=True, unbiased=True) # 上記と同じ値になる
r = bn1.momentum
x_var_running = x_var_running_prev * (1 - r) + x_var2.view(-1) * r

print(x_var_running)
tensor([0.9739, 1.0061, 1.0119])
# 比較
print(x_var_running - bn1.running_var)
print((x_var_running - bn1.running_var).abs().max())
tensor([0., 0., 0.])
tensor(0.)

比較結果、完全に一致する。

BatchNorm出力値 (2回目)

torch.nn.BatchNorm2dの計算

# 入力
x = torch.randn((5, 3, 2, 2)) # n_batch x ch x H x W
# 出力
y = bn1(x)
y = y.detach() # 勾配計算なしへ変換
# 一部を見る
print(y.shape)
print(y[0, 0])
torch.Size([5, 3, 2, 2])
tensor([[ 0.2065, -0.2257],
        [-0.1923, -0.8799]])

自前で計算

# 自前で計算
N = x.size(0) * x.size(2) * x.size(3)
x_ave = x.sum(dim=(0,2,3), keepdim=True) / N
x_var = ((x - x_ave)**2).sum(dim=(0,2,3), keepdim=True) / N
eps = bn1.eps
y2 = (x - x_ave) / (torch.sqrt(x_var + eps))

# 一部を見る
print(y2[0, 0])
tensor([[ 0.2065, -0.2257],
        [-0.1923, -0.8799]])
# 比較
print((y2 - y)[0, 0])
print((y2 - y).abs().max())
tensor([[ 2.9802e-08,  0.0000e+00],
        [ 0.0000e+00, -5.9605e-08]])
tensor(1.1921e-07)

比較結果、2回目も、1e-7のオーダーで一致する。float型の計算精度は約7桁の有効数字なので、良好。

移動平均・分散 (2回目)

torch.nn.BatchNorm2dの計算

bn1.running_var
tensor([0.9504, 0.9933, 0.9971])

自前で計算

# 前回値
x_var_running_prev = x_var_running

# 自前で計算
x_var2 = ((x - x_ave)**2).sum(dim=(0,2,3), keepdim=True) / (N - 1) # 不偏分散
# x_var2 = torch.var(x, dim=(0,2,3), keepdim=True, unbiased=True) # 上記と同じ値になる
r = bn1.momentum
x_var_running = x_var_running_prev * (1 - r) + x_var2.view(-1) * r

print(x_var_running)
tensor([0.9504, 0.9933, 0.9971])
# 比較
print(x_var_running - bn1.running_var)
print((x_var_running - bn1.running_var).abs().max())
tensor([0., 0., 0.])
tensor(0.)

比較結果、2回目も、完全に一致する。

LayerNormの動作を確認する

LayerNorm出力値 (1回目)

torch.nn.LayerNormの計算

ln1 = nn.LayerNorm([3, 2, 2], elementwise_affine=False)
# elementwise_affine=False: スケーリングとバイアスは使用せず、標準のレイヤー正規化を実施します。

# 入力
x = torch.randn((5, 3, 2, 2)) # n_batch x ch x H x W
# 出力
y = ln1(x)
y = y.detach() # 勾配計算なしへ変換
# 一部を見る
print(y.shape)
print(y[0, 0])
torch.Size([5, 3, 2, 2])
tensor([[-0.9299,  0.5133],
        [-1.9721, -0.6337]])

自前で計算

# 自前で計算
N = x.size(1) * x.size(2) * x.size(3)
x_ave = x.sum(dim=(1,2,3), keepdim=True) / N
x_var = ((x - x_ave)**2).sum(dim=(1,2,3), keepdim=True) / N
eps = ln1.eps
y2 = (x - x_ave) / (torch.sqrt(x_var + eps))

# 一部を見る
print(y2[0, 0])
tensor([[-0.9299,  0.5133],
        [-1.9721, -0.6337]])
# 比較
print((y2 - y)[0, 0])
print((y2 - y).abs().max())
tensor([[0., 0.],
        [0., 0.]])
tensor(4.7684e-07)

比較結果、1e-7のオーダーで一致する。float型の計算精度は約7桁の有効数字なので、良好。

生成AIへ質問: pytorchのfloat型の計算精度はどれぐらい?
PyTorchで使用されるfloat型(具体的にはtorch.float32)の計算精度は、32ビット浮動小数点精度(単精度)です。この形式では、約7桁の有効数字が保持されます。これは、数値計算において非常に広く使用されている標準的な精度です。

LayerNorm出力値 (2回目)

torch.nn.LayerNormの計算

# 入力
x = torch.randn((5, 3, 2, 2)) # n_batch x ch x H x W
# 出力
y = ln1(x)
y = y.detach() # 勾配計算なしへ変換
# 一部を見る
print(y.shape)
print(y[0, 0])
torch.Size([5, 3, 2, 2])
tensor([[ 2.1837,  0.0087],
        [-0.0279,  0.5970]])

自前で計算

# 自前で計算
N = x.size(1) * x.size(2) * x.size(3)
x_ave = x.sum(dim=(1,2,3), keepdim=True) / N
x_var = ((x - x_ave)**2).sum(dim=(1,2,3), keepdim=True) / N
eps = ln1.eps
y2 = (x - x_ave) / (torch.sqrt(x_var + eps))

# 一部を見る
print(y2[0, 0])
tensor([[ 2.1837,  0.0087],
        [-0.0279,  0.5970]])
# 比較
print((y2 - y)[0, 0])
print((y2 - y).abs().max())
tensor([[0., 0.],
        [0., 0.]])
tensor(1.1921e-07)

比較結果、2回目も、1e-7のオーダーで一致する。float型の計算精度は約7桁の有効数字なので、良好。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?