概要
torchvision.transforms.ToTensor()
は、PyTorchで画像データ(PILなど)をTensorに変換するのによく見る関数です。しかし、「このメソッドは正規化もしてくれている」という誤解が広まっていることがあります。本記事では、この点を明確に解説します。
transforms.ToTensor()
の役割
公式の説明によると、transforms.ToTensor()
は、2つの操作を行うようです:
-
画像データ(PIL Imageもしくはnp.ndarray)をPyTorchのTensor型に変換する
画像データが標準的なPythonライブラリ(PILなど)で扱われる形式から、PyTorchで扱えるテンソル形式に変換されます。変換後に(C, H, W)の次元になっていることにも注意が必要です。
-
画素値のスケールを0〜255から0〜1に変換する
画像データの画素値は一般的に0〜255の範囲で表されますが、transforms.ToTensor()
を使用すると、これを0〜1の範囲にスケーリングされます。
注意:正規化とは異なる
ToTensor()
が行っているのは「画素値を0〜1にスケーリングする」という処理であり、これは厳密には「正規化」とは異なります。**正規化(Normalization)**とは、画像の平均と標準偏差に基づいてデータを変換する処理です。
具体的には、以下のような形で正規化は通常行われます:
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
※これらの数値[0.485, ...]は、ImageNetベースで計算された数値のようです。独自のデータセットを用いる場合は、独自の正規化のための係数を使用する必要があります
この操作では、画像の各チャンネルごとに平均値を引いて標準偏差で割ることで、ピクセルの値を正規化します。ToTensor()
はこの処理を含まないため、追加でNormalize()
を適用する必要があります。
まとめ
-
transforms.ToTensor()
は画像をテンソルに変換し、画素値を0〜255から0〜1にスケールします - 正規化(Normalization)とは別の処理です。正規化を行いたい場合は、
transforms.Normalize()
を追加で使用する必要があります