0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

transforms.ToTensor()は何をしている?

Posted at

概要

torchvision.transforms.ToTensor()は、PyTorchで画像データ(PILなど)をTensorに変換するのによく見る関数です。しかし、「このメソッドは正規化もしてくれている」という誤解が広まっていることがあります。本記事では、この点を明確に解説します。

transforms.ToTensor()の役割

公式の説明によると、transforms.ToTensor()は、2つの操作を行うようです:

  1. 画像データ(PIL Imageもしくはnp.ndarray)をPyTorchのTensor型に変換する
    画像データが標準的なPythonライブラリ(PILなど)で扱われる形式から、PyTorchで扱えるテンソル形式に変換されます。変換後に(C, H, W)の次元になっていることにも注意が必要です。

image.png

  1. 画素値のスケールを0〜255から0〜1に変換する
    画像データの画素値は一般的に0〜255の範囲で表されますが、transforms.ToTensor()を使用すると、これを0〜1の範囲にスケーリングされます。

image.png

注意:正規化とは異なる

ToTensor()が行っているのは「画素値を0〜1にスケーリングする」という処理であり、これは厳密には「正規化」とは異なります。**正規化(Normalization)**とは、画像の平均と標準偏差に基づいてデータを変換する処理です。

具体的には、以下のような形で正規化は通常行われます:

transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

※これらの数値[0.485, ...]は、ImageNetベースで計算された数値のようです。独自のデータセットを用いる場合は、独自の正規化のための係数を使用する必要があります

この操作では、画像の各チャンネルごとに平均値を引いて標準偏差で割ることで、ピクセルの値を正規化します。ToTensor()はこの処理を含まないため、追加でNormalize()を適用する必要があります。

まとめ

  • transforms.ToTensor()は画像をテンソルに変換し、画素値を0〜255から0〜1にスケールします
  • 正規化(Normalization)とは別の処理です。正規化を行いたい場合は、transforms.Normalize()を追加で使用する必要があります
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?