はじめに
torchvision.transformsで画像を90°単位でランダム回転させたい!
そう思ったのが事の始まりです。
RandomRotationじゃダメなの?
torchvision.transformsでランダム回転をさせるにはRandomRotationがあるのですが、90°単位で回転させるとなるとちょっと痒い所に手が届かない感じ。
主に次の二点が問題。
-
指定した角度の範囲内でランダムに回転させるものとなっていること
※
角度の範囲指定を(90, 90)とかにすればできなくはないのですが、90°, 180°, 270°でそれぞれ用意する必要があるためあまりスマートではないですよね。 -
回転対象がTensor型の場合、その配列がPILで読み込める(例えば配列中身のデータ型がint8である)ものでないとエラーになってしまうこと
※
例えば地理空間データとかを画像っぽくCNNで扱うとき(画素値がfloat型だったり、そもそもチャンネル数が1や3でなかったり)とかエラーになってしまうのです。。。
ということで、そういったものにも柔軟に対応できる自作のtransformを実装してみました。
実装
以下が実装したものです。
使い方は後半に用意しています。
import random
import torch
from PIL import Image
from torchvision import transforms
class RandomRotation90:
def __init__(self, p=0.5):
self.p = p
def __call__(self, x):
if random.random() < self.p:
i = random.randint(1, 3)
if isinstance(x, Image.Image):
x = transforms.RandomRotation((90*i, 90*i), expand=True)(x)
elif isinstance(x, torch.Tensor):
x = torch.rot90(x, i, [1, 2])
else:
raise TypeError(f'{type(x)} is unexpected type.')
return x
動作確認
scilit-imageのサンプルデータに用意がある猫のチェルシー君を回転させて動作を見て行きます。
import skimage
origin = skimage.data.chelsea()
origin = Image.fromarray(origin)
origin
PIL
まずはPILで読み込まれた画像の回転。
引数pは回転をさせる確率です。
今回は動作確認なので100%回転するよう1.0を指定しています(以後の動作確認も同様)。
transformer = RandomRotation90(p=1.0)
rotated = transformer(origin)
print(f'{type(origin)} >> {type(rotated)}')
rotated
## 出力: <class 'PIL.Image.Image'> >> <class 'PIL.Image.Image'>
Tensor
# origin を Tensor型に変換
tensor = transforms.ToTensor()(origin)
# 回転
transformer = RandomRotation90(p=1.0)
rotated = transformer(tensor)
print(f'{type(tensor)} >> {type(rotated)}')
# 表示するために再度PILに変換
rotated = transforms.ToPILImage()(rotated)
rotated
## 出力: <class 'torch.Tensor'> >> <class 'torch.Tensor'>
パイプライン化
torchvision.Composeでパイプラインに組み込むこともできます。
ここでは先ほどTensor型の動作確認のときに、PIL画像をTensor型に変換していたところをパイプラインで一気に処理が流れるようにします。
transformer = transforms.Compose([
transforms.ToTensor(),
RandomRotation90(p=1.0)
])
rotated = transformer(origin)
print(f'{type(origin)} >> {type(rotated)}')
# 表示するために再度PILに変換
rotated = transforms.ToPILImage()(rotated)
rotated
## 出力: <class 'PIL.Image.Image'> >> <class 'torch.Tensor'>
numpy
ToTensor()をかませることでnumpyでもOK!
# numpy配列で読み込み直す
origin = skimage.data.chelsea()
# 回転
transformer = transforms.Compose([
transforms.ToTensor(),
RandomRotation90(p=1.0)
])
rotated = transformer(origin)
print(f'{type(origin)} >> {type(rotated)}')
# 表示するために再度PILに変換
rotated = transforms.ToPILImage()(rotated)
rotated
## 出力: <class 'numpy.ndarray'> >> <class 'torch.Tensor'>
最後に
ここまでご覧いただきありがとうございます。
「実はこういったやり方あるよ」とかありましたらコメントいただけると幸いです。