1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

torchvision.transformsでランダムに画像を90°単位で回転させる

Posted at

はじめに

torchvision.transformsで画像を90°単位でランダム回転させたい!

そう思ったのが事の始まりです。

RandomRotationじゃダメなの?

torchvision.transformsでランダム回転をさせるにはRandomRotationがあるのですが、90°単位で回転させるとなるとちょっと痒い所に手が届かない感じ。

主に次の二点が問題。

  1. 指定した角度の範囲内でランダムに回転させるものとなっていること

    角度の範囲指定を(90, 90)とかにすればできなくはないのですが、90°, 180°, 270°でそれぞれ用意する必要があるためあまりスマートではないですよね。

  2. 回転対象がTensor型の場合、その配列がPILで読み込める(例えば配列中身のデータ型がint8である)ものでないとエラーになってしまうこと

    例えば地理空間データとかを画像っぽくCNNで扱うとき(画素値がfloat型だったり、そもそもチャンネル数が1や3でなかったり)とかエラーになってしまうのです。。。

ということで、そういったものにも柔軟に対応できる自作のtransformを実装してみました。

実装

以下が実装したものです。

使い方は後半に用意しています。

import random
import torch
from PIL import Image
from torchvision import transforms

class RandomRotation90:
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, x):
        if random.random() < self.p:
            i = random.randint(1, 3)

        if isinstance(x, Image.Image):
            x = transforms.RandomRotation((90*i, 90*i), expand=True)(x)

        elif isinstance(x, torch.Tensor):
            x = torch.rot90(x, i, [1, 2])

        else:
            raise TypeError(f'{type(x)} is unexpected type.')

        return x

動作確認

scilit-imageのサンプルデータに用意がある猫のチェルシー君を回転させて動作を見て行きます。

import skimage
origin = skimage.data.chelsea()
origin = Image.fromarray(origin)
origin

image.png

PIL

まずはPILで読み込まれた画像の回転。

引数pは回転をさせる確率です。

今回は動作確認なので100%回転するよう1.0を指定しています(以後の動作確認も同様)。

transformer = RandomRotation90(p=1.0)
rotated = transformer(origin)
print(f'{type(origin)} >> {type(rotated)}')
rotated

## 出力: <class 'PIL.Image.Image'> >> <class 'PIL.Image.Image'>

image.png

Tensor

# origin を Tensor型に変換
tensor = transforms.ToTensor()(origin)

# 回転
transformer = RandomRotation90(p=1.0)
rotated = transformer(tensor)
print(f'{type(tensor)} >> {type(rotated)}')

# 表示するために再度PILに変換
rotated = transforms.ToPILImage()(rotated)
rotated

## 出力: <class 'torch.Tensor'> >> <class 'torch.Tensor'>

image.png

パイプライン化

torchvision.Composeでパイプラインに組み込むこともできます。

ここでは先ほどTensor型の動作確認のときに、PIL画像をTensor型に変換していたところをパイプラインで一気に処理が流れるようにします。

transformer = transforms.Compose([      
    transforms.ToTensor(),                         
    RandomRotation90(p=1.0)        
])
rotated = transformer(origin)
print(f'{type(origin)} >> {type(rotated)}')

# 表示するために再度PILに変換
rotated = transforms.ToPILImage()(rotated)
rotated

## 出力: <class 'PIL.Image.Image'> >> <class 'torch.Tensor'>

image.png

numpy

ToTensor()をかませることでnumpyでもOK!

# numpy配列で読み込み直す
origin = skimage.data.chelsea()

# 回転
transformer = transforms.Compose([      
    transforms.ToTensor(),                         
    RandomRotation90(p=1.0)        
])
rotated = transformer(origin)
print(f'{type(origin)} >> {type(rotated)}')

# 表示するために再度PILに変換
rotated = transforms.ToPILImage()(rotated)
rotated

## 出力: <class 'numpy.ndarray'> >> <class 'torch.Tensor'>

image.png

最後に

ここまでご覧いただきありがとうございます。
「実はこういったやり方あるよ」とかありましたらコメントいただけると幸いです。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?