はじめに
PyTorchで画像を扱っている際,tochvisionのTransformsにあるToTensor関数(1)って何をしているのかが気になったので調べてまとめておこうと思います.
要約
ToTensor関数の動きを理解しnumpyと相互変換できるようになる
sample code
ここ(2)のコードを参考にしながら,numpyで画像を読み込んだと仮定してnumpy -> tensor -> numpy
に戻してみます.ダミー画像の大きさは$(W,H,C)=(4,5,1)$とします.また,動作確認のみなのため,ToTensor()
と同じ機能を持つimport torchvision.transforms.functional.to_tensor()
を使用しています.
import torchvision.transforms.functional as TF
import numpy as np
import torch
gray_image = np.asarray(torch.randint(0,255,(4, 5, 1), dtype=torch.uint8))
# 内部で gray_image = gray_image/255 してる.255は輝度の最大値
# numpy to tensor [0,1]の閉区間で正規化
nomalized_tensor_gray_image = TF.to_tensor(gray_image)
print(tensor_gray_image)
# tensor to numpy
nomalized_numpy_gray_image = np.asarray(nomalized_tensor_gray_image)
# (C,W,H)を(W,H,C)に変換するのと,[0,1]から[0,255]に直す.
numpy_gray_image = nomalized_numpy_gray_image.transpose(1,2,0) * 255
# データの形をもとに戻す np.float32 -> np.uint8
numpy_gray_image = numpy_gray_image.astype(np.uint8)
# もとに戻ったか確認
print(gray_image == numpy_gray_image)
ToTensorの役割
関数名だけを見るとnumpyやPIL Image(W,H,C)をTensor(C,W,H)に変換するだけだと思っていました.
しかし,実際の説明を見ても
Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has dtype = np.uint8
In the other cases, tensors are returned without scaling.
と[0,255]を[0,1.0]に変換しているよと言ってます(データの型もuint8
からfloat32
に代わっています).modeでは,画像の種類(RGBやGrayScale)なども選択できます.
なぜ正規化する?
正規化をする理由としては,[0,255]だとデータのばらつきが大きく画像の特徴が俯瞰しにくいことが挙げられます.また,読み込まれたときの値が,なぜ[0,255]かは,画像の輝度に由来しています.GrayScale画像で考えると,一番暗い場所を0
,一番明るい場所が255
になります.RGBの画像でも輝度の概念は変わらないので知っておくと何かと便利です.
計算機としては学習の際,浮動小数点を扱うため,丸め込み誤差などを考えると,なるべく小数桁を多くとりたい意図があると考えられます.
おわりに
もともとは,transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
ってなんでこんな値設定しているんだろうから始まった調べものですが,ToTensorも意外とこじゃれたことをしていることがわかり,書き残しました.感想や,議論があればコメントに書いていただけると幸いです.
参考文献
(1) Torchvision doc: Conversion Transforms
(2) Understanding transform.Normalize()