概要
DeepLearningのタスクの1つであるセマンティックセグメンテーション(Semantic Segmentation)では、分類や検出のタスクと異なって、教師データが画像形式になっている。そのためデータ拡張する場合(クロップや反転など)、入力画像と教師データそれぞれに同じように画像処理を行う必要がある。
この記事では入力画像と教師データの両方に同様のランダムなデータ拡張を実行する方法を紹介する記事。
セマンティックセグメンテーションとは
セマンティックセグメンテーションについては以下が参考になります。
U-NetでPascal VOC 2012の画像をSemantic Segmentationする (TensorFlow)
今回はサンプル画像としてVOCデータセットの画像を使用する。
解像度はどちらも500x281になっている。
ランダムなデータ拡張
今回はPyTorchで予め用意されているtorchvision.transforms.RandomCrop(size, padding=0, pad_if_needed=False)
を使って画像に対してクロップを実行する。
関数の中では、乱数でクロップする位置を決めて、指定したsize
でクロップしている。(最終的には内部でtorchvision.transforms.functional.crop(img, i, j, h, w)
がコールされている。)
詳細な使い方やパラメータについてはPyTorchのリファレンスを参照してください。
PyTorch TORCHVISION.TRANSFORMS
課題
torchvision.transforms.RandomCrop
は内部で乱数を発生させているため、実行するたびに結果が異なってしまう。
よって、以下のように実行すると入力画像と教師画像が異なる位置がクロップされてしまう。
from PIL import Image
from torchvision import transforms
trans_crop = transforms.RandomCrop((224,224))
img = Image.open(img_path) # 入力画像
target = Image.open(target_path) # 教師画像
img = trans_crop(img) # 入力画像を(224,224)でランダムクロップ
target = trans_crop(target) # 教師画像を(224,224)でランダムクロップ
img.show()
target.show()
このように入力画像と教師画像が一致しないため学習ができなくなってしまう。
解決策1 乱数シードを固定する
PyTorchのクロップ関数の内部ではrandom.randint()
で乱数を発生させているので、random.seed()
を使って乱数シードを設定すれば同じ結果が得られる。
from PIL import Image
from torchvision import transforms
import random
trans = transforms.RandomCrop((224,224))
img = Image.open(img_path)
target = Image.open(target_path)
seed = random.randint(0, 2**32) # 乱数で乱数シードを決定
random.seed(seed) # 乱数シードを固定
img = trans(img)
random.seed(seed) # こっちでも乱数シードを固定
target = trans(target)
img.show()
target.show()
この実装のよくない点はPyTorchの中で「random.randint()
で乱数を発生させている」ということを前提としていること。
PyTorchの中で乱数の発生のさせ方が変わると急に上手く動作しなくなったりする。
解決策2 transforms.RandomCrop.get_params(img, output_size))を使う
transforms.RandomCrop.get_params(img, output_size))
は乱数で決めたクロップする位置とサイズを返してくれる関数。
torchvision.transforms.RandomCrop
の中でもこの関数でクロップ位置を決めた後、torchvision.transforms.functional.crop(img, i, j, h, w)
でクロップしている。
なので、これと同じような処理の流れを自分で実装すれば解決できる。
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as tvf
import random
trans = transforms.RandomCrop((224,224))
img = Image.open(img_path)
target = Image.open(target_path)
# クロップ位置を乱数で決定
i, j, h, w = transforms.RandomCrop.get_params(img, output_size=(224,224))
img = tvf.crop(img, i, j, h, w) # 入力画像を(224,224)でクロップ
target = tvf.crop(target, i, j, h, w) # 教師画像を(224,224)でクロップ