LoginSignup
34
23

More than 3 years have passed since last update.

PyTorch 入力画像と教師画像の両方にランダムなデータ拡張を実行する方法

Last updated at Posted at 2019-01-13

概要

DeepLearningのタスクの1つであるセマンティックセグメンテーション(Semantic Segmentation)では、分類や検出のタスクと異なって、教師データが画像形式になっている。そのためデータ拡張する場合(クロップや反転など)、入力画像と教師データそれぞれに同じように画像処理を行う必要がある。

この記事では入力画像と教師データの両方に同様のランダムなデータ拡張を実行する方法を紹介する記事。

セマンティックセグメンテーションとは

セマンティックセグメンテーションについては以下が参考になります。
U-NetでPascal VOC 2012の画像をSemantic Segmentationする (TensorFlow)

今回はサンプル画像としてVOCデータセットの画像を使用する。
解像度はどちらも500x281になっている。

・入力画像
2007_000032.jpg
・教師画像
2007_000032.png

ランダムなデータ拡張

今回はPyTorchで予め用意されているtorchvision.transforms.RandomCrop(size, padding=0, pad_if_needed=False)を使って画像に対してクロップを実行する。

関数の中では、乱数でクロップする位置を決めて、指定したsizeでクロップしている。(最終的には内部でtorchvision.transforms.functional.crop(img, i, j, h, w)がコールされている。)

詳細な使い方やパラメータについてはPyTorchのリファレンスを参照してください。
PyTorch TORCHVISION.TRANSFORMS

課題

torchvision.transforms.RandomCropは内部で乱数を発生させているため、実行するたびに結果が異なってしまう。
よって、以下のように実行すると入力画像と教師画像が異なる位置がクロップされてしまう。

from PIL import Image
from torchvision import transforms

trans_crop = transforms.RandomCrop((224,224))

img = Image.open(img_path) # 入力画像
target = Image.open(target_path) # 教師画像

img = trans_crop(img) # 入力画像を(224,224)でランダムクロップ
target = trans_crop(target) # 教師画像を(224,224)でランダムクロップ

img.show()
target.show()

input.png
target.png

このように入力画像と教師画像が一致しないため学習ができなくなってしまう。

解決策1 乱数シードを固定する

PyTorchのクロップ関数の内部ではrandom.randint()で乱数を発生させているので、random.seed()を使って乱数シードを設定すれば同じ結果が得られる。

from PIL import Image
from torchvision import transforms
import random

trans = transforms.RandomCrop((224,224))

img = Image.open(img_path)
target = Image.open(target_path)

seed = random.randint(0, 2**32) # 乱数で乱数シードを決定

random.seed(seed) # 乱数シードを固定
img = trans(img)

random.seed(seed) # こっちでも乱数シードを固定
target = trans(target)

img.show()
target.show()

input2.png
target2.png

この実装のよくない点はPyTorchの中で「random.randint()で乱数を発生させている」ということを前提としていること。
PyTorchの中で乱数の発生のさせ方が変わると急に上手く動作しなくなったりする。

解決策2 transforms.RandomCrop.get_params(img, output_size))を使う

transforms.RandomCrop.get_params(img, output_size))は乱数で決めたクロップする位置とサイズを返してくれる関数。

torchvision.transforms.RandomCropの中でもこの関数でクロップ位置を決めた後、torchvision.transforms.functional.crop(img, i, j, h, w)でクロップしている。

なので、これと同じような処理の流れを自分で実装すれば解決できる。

from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as tvf
import random

trans = transforms.RandomCrop((224,224))

img = Image.open(img_path)
target = Image.open(target_path)

# クロップ位置を乱数で決定
i, j, h, w = transforms.RandomCrop.get_params(img, output_size=(224,224))

img = tvf.crop(img, i, j, h, w) # 入力画像を(224,224)でクロップ
target = tvf.crop(target, i, j, h, w) # 教師画像を(224,224)でクロップ

image.png
image.png


34
23
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
34
23