記事概要
torchvision.transforms の使い方メモです。なお、transforms.functional はカバーしていません。
torchvision.transforms とは
画像データの入力を加工してくれる callable なクラス群です。画像オブジェクトを渡すだけで前処理 (オーグメント/正規化等) を行ってくれる便利なクラスですが、PIL オブジェクトにのみ対応している transforms と torch.Tensor のみに対応している transforms が存在するため、適用順を間違えるとエラーが出ます。 どちらに対応したクラスかは 公式のリファレンス に載っているため、使う前にチェックするといいと思います。基本的には PIL で操作して、正規化だけは torch.Tensor にすることを覚えておけばよいです。
また、detection や segmentation など ground truth に座標を含むようなタスクではこれらのクラスを利用できない (イメージの変換はできるが、座標の変換はできない) ため、自分でクラスを作る必要があります。
PIL のみに対応したクラス
サイズ系
CenterCropFiveCropPadRandomAffineRandomCropRandomResizedCropResizeTenCrop
色系
ColorJitterGrayscaleRandomAffineRandomGrayscale
反転系
RandomHorizontalFlipRandomVerticalFlipTenCrop
回転系
RandomPerspectiveRandomRotation
汎用
RandomApplyRandomChoiceRandomOrder
torch.Tensor のみに対応したクラス
線形変換
LinearTransformation
正規化
Normalize
オーグメント
-
RandomErasing(https://arxiv.org/pdf/1708.04896.pdf)
その他
型変換
-
ToPILImage(numpy.ndarray or torch.Tensor -> PIL) -
ToTensor(numpy.ndarray or PIL -> torch.Tensor)
汎用
Lambda