36
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

[PyTorch]ToTensor()はどのように動いているんだろう

Last updated at Posted at 2020-06-19

はじめに

PyTorchで画像を扱っている際,tochvisionのTransformsにあるToTensor関数(1)って何をしているのかが気になったので調べてまとめておこうと思います.

要約

ToTensor関数の動きを理解しnumpyと相互変換できるようになる

sample code

ここ(2)のコードを参考にしながら,numpyで画像を読み込んだと仮定してnumpy -> tensor -> numpyに戻してみます.ダミー画像の大きさは$(W,H,C)=(4,5,1)$とします.また,動作確認のみなのため,ToTensor()と同じ機能を持つimport torchvision.transforms.functional.to_tensor()を使用しています.

import torchvision.transforms.functional as TF
import numpy as np
import torch

gray_image = np.asarray(torch.randint(0,255,(4, 5, 1), dtype=torch.uint8))
# 内部で gray_image = gray_image/255 してる.255は輝度の最大値
# numpy to tensor [0,1]の閉区間で正規化
nomalized_tensor_gray_image = TF.to_tensor(gray_image)
print(tensor_gray_image)

# tensor to numpy
nomalized_numpy_gray_image = np.asarray(nomalized_tensor_gray_image)
# (C,W,H)を(W,H,C)に変換するのと,[0,1]から[0,255]に直す.
numpy_gray_image = nomalized_numpy_gray_image.transpose(1,2,0) * 255
# データの形をもとに戻す np.float32 -> np.uint8
numpy_gray_image = numpy_gray_image.astype(np.uint8)
# もとに戻ったか確認
print(gray_image == numpy_gray_image)

ToTensorの役割

関数名だけを見るとnumpyやPIL Image(W,H,C)をTensor(C,W,H)に変換するだけだと思っていました.
しかし,実際の説明を見ても

Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has dtype = np.uint8

In the other cases, tensors are returned without scaling.

と[0,255]を[0,1.0]に変換しているよと言ってます(データの型もuint8からfloat32に代わっています).modeでは,画像の種類(RGBやGrayScale)なども選択できます.

なぜ正規化する?

正規化をする理由としては,[0,255]だとデータのばらつきが大きく画像の特徴が俯瞰しにくいことが挙げられます.また,読み込まれたときの値が,なぜ[0,255]かは,画像の輝度に由来しています.GrayScale画像で考えると,一番暗い場所を0,一番明るい場所が255になります.RGBの画像でも輝度の概念は変わらないので知っておくと何かと便利です.
計算機としては学習の際,浮動小数点を扱うため,丸め込み誤差などを考えると,なるべく小数桁を多くとりたい意図があると考えられます.

おわりに

もともとは,transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])ってなんでこんな値設定しているんだろうから始まった調べものですが,ToTensorも意外とこじゃれたことをしていることがわかり,書き残しました.感想や,議論があればコメントに書いていただけると幸いです.

参考文献

(1) Torchvision doc: Conversion Transforms
(2) Understanding transform.Normalize()

36
17
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
36
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?