torchvision.transforms
transforms.ToTensor
torchvision.transforms.ToTensor
は画像ファイルから読み込んだNumPy
やPillow
形式の配列をPyTorch
形式に変換するtorchvision
のクラスです。torchvision.transforms.ToTensor
の処理は下記のようなプログラムを動かすことで確認することができます。
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as transforms
file_path = "images/plane.jpg"
img = Image.open(file_path).convert("RGB")
img_numpy = np.array(img)
img_transformed = transforms.ToTensor()(img)
print(type(img))
print(type(img_numpy))
print(type(img_transformed))
print("=====")
print(np.min(img_numpy), np.max(img_numpy))
print(img_transformed.min().item(), img_transformed.max().item())
・実行結果
<class 'PIL.Image.Image'>
<class 'numpy.ndarray'>
<class 'torch.Tensor'>
=====
0 255
0.0 1.0
実行結果より、torchvision.transforms.ToTensor
によって0〜255の整数が0.0〜1.0の小数に変換されることが確認できます。値の変換にあたっての対応については下記のプログラムを実行することで確認できます。
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
file_path = "images/plane.jpg"
output_path = "figure/hist.png"
img = Image.open(file_path).convert("RGB")
img_numpy = np.array(img)
img_transformed = transforms.ToTensor()(img)
fig, ax= plt.subplots(1, 2, figsize=(10, 5))
ax[0].hist(img_numpy.reshape([-1, 1]), bins=20)
ax[1].hist(img_transformed.numpy().reshape([-1, 1]), bins=20)
plt.savefig(output_path)
実行結果よりtorchvision.transforms.ToTensor
を用いた処理では値の正規化などは行われていないことが確認できます。