LoginSignup
8
9

More than 3 years have passed since last update.

[PyTorch]セグメンテーションのためのDataAugmentation

Last updated at Posted at 2020-08-01

0.この記事の対象者

  1. PyTorchを使って画像セグメンテーションを実装する方
  2. DataAugmentationでデータの水増しをしたい方
  3. 対応するオリジナル画像とマスク画像に全く同じ処理を施したい方
  4. 特に自前のデータセット (torchvision.datasetsにないデータ)を使用する方

1 概要

主に, 教師ありまたは半教師ありでのセグメンテーション用データセットを想定

  • 自作データセットクラスの中にDataAugmentation処理を記述
  • 対応するオリジナル画像とマスク画像の両方に全く同じ処理を実行
    • 「クロップする位置」, 「角度」, 「反転するか否か」を一致させる
  • 画像ペア(=オリジナル画像+マスク画像)毎にランダム性のある処理を実行
    • ただし, 上述の通り画像ペア内の処理は一致

2 問題点

問題のケースをみる前に, まずは問題ないケースを考えてみます

2.1 問題ないケース(物体クラス認識など)

PyTorchでDataAugmentationする際には, 通常以下のように変換を定義して


transform = torchvision.transforms.Compose([
    # 角度degreesだけ回転
    transforms.RandomRotation(degrees),
    # 水平方向に反転
    transforms.RandomHorizontalFlip(),
    # 垂直方向に反転
    transforms.RandomVerticalFlip()
])

んでデータセットの引数に入れてやります


dataset = HogeDataset.HogeDataset(
    train=True, transform=transform
)

おそらく物体クラス認識などではこれで問題ないでしょう
理由は教師データが画像じゃないので, オリジナル画像さえ加工してやればいいため

2.2 問題のケース(セグメンテーションなど)

お次に問題となるケース
先ほどのケースとの違いは教師データが画像として与えられている点です


transform = torchvision.transforms.Compose([
    # 角度degreesだけ回転
    transforms.RandomRotation(degrees),
    # 水平方向に反転
    transforms.RandomHorizontalFlip(),
    # 垂直方向に反転
    transforms.RandomVerticalFlip()
])

んでデータセットの引数に入れてやります


dataset = HogeDataset.HogeDataset(
    train=True, transform=transform, target_transform=transform
)

ただし, これだとHogeDatasetからデータを取得する際, オリジナル画像とマスク画像になされる変換が対応づいたものとはなりません
例) オリジナル画像 : 90度回転, マスク画像 : 270度回転
これでは, データを水増ししても教師データとしては機能しません.
引数target_transformよ, お前はなぜ存在するんだ?ってなりますが, こいつの存在意義はおそらくマスク画像に対してもtorchvision.transforms.Resize()torchvision.transforms.ToTensor()のような(ランダム性がない)加工を施すためにあるんじゃないかと思います

3 解決策

ということで, オリジナル画像と同じ加工をマスク画像に施すにはどうすればいいのか
解決策としては, 以下のような自作のDatasetクラスを作る方法が考えられる

HogeDataset.py
import os
import glob
import torch
from torchvision import transforms
from torchvision.transforms import functional as tvf
import random
from PIL import Image

DATA_PATH = '[オリジナル画像のディレクトリパス]'
MASK_PATH = '[マスク画像のディレクトリパス]'
TRAIN_NUM = [訓練データ数]

class HogeDataset(torch.utils.data.Dataset):
    def __init__(self, transform = None, target_transform = None, train = True):
        # transform と target_transform はテンソル化などのランダム性のない変換
        self.transform = transform
        self.target_transform = target_transform


        data_files = glob.glob(DATA_PATH + '/*.[ファイル拡張子]')
        mask_files = glob.glob(MASK_PATH + '/*.[ファイル拡張子]')

        self.dataset = []
        self.maskset = []

        # オリジナル画像読み込み
        for data_file in data_files:
            self.dataset.append(Image.open(
                DATA_PATH + os.path.basename(data_file)
            ))

        # マスク画像読み込み
        for mask_file in mask_files:
            self.maskset.append(Image.open(
                MASK_PATH + os.path.basename(mask_file)
            ))

        # 訓練データとテストデータに分割
        if train:
            self.dataset = self.dataset[:TRAIN_NUM]
            self.maskset = self.maskset[:TRAIN_NUM]
        else:
            self.dataset = self.dataset[TRAIN_NUM+1:]
            self.maskset = self.maskset[TRAIN_NUM+1:]

        # Data Augmentation
        # ランダム性のある変換はここで行う
        self.augmented_dataset = []
        self.augmented_maskset = []
        for num in range(len(self.dataset)):
            data = self.dataset[num]
            mask = self.maskset[num]
            # ランダムクロップ
            for crop_num in range(16):
                # クロップ位置を乱数で決定
                i, j, h, w = transforms.RandomCrop.get_params(data, output_size=(250,250))
                cropped_data = tvf.crop(data, i, j, h, w)
                cropped_mask = tvf.crop(mask, i, j, h, w)

                # 回転(0, 90, 180, 270度)
                for rotation_num in range(4):
                    rotated_data = tvf.rotate(cropped_data, angle=90*rotation_num)
                    rotated_mask = tvf.rotate(cropped_mask, angle=90*rotation_num)

                    # 水平反転と垂直反転のどちらか
                    # 反転(水平方向)
                    for h_flip_num in range(2):
                        h_flipped_data = transforms.RandomHorizontalFlip(p=h_flip_num)(rotated_data)
                        h_flipped_mask = transforms.RandomHorizontalFlip(p=h_flip_num)(rotated_mask)

                    """    
                    # 反転(垂直方向)
                    for v_flip_num in range(2):
                        v_flipped_data = transforms.RandomVerticalFlip(p=v_flip_num)(h_flipped_data)
                        v_flipped_mask = transforms.RandomVerticalFlip(p=v_flip_num)(h_flipped_mask)
                    """

                        # DataAugmentation済みのデータを追加
                        self.augmented_dataset.append(h_flipped_data)
                        self.augmented_maskset.append(h_flipped_mask)

        self.datanum = len(self.augmented_dataset)

    # データサイズ取得メソッド
    def __len__(self):
        return self.datanum

    # データ取得メソッド
    # ランダム性の無い変換はここで行う
    def __getitem__(self, idx):
        out_data = self.augmented_dataset[idx]
        out_mask = self.augmented_maskset[idx]

        if self.transform:
            out_data = self.transform(out_data)

        if self.target_transform:
            out_mask = self.target_transform(out_mask)

        return out_data, out_mask

やっていることは単純で, __init__()の内部でDataAugmentationしてやります
その際に各画像ペアについて

  • ランダムクロップの位置を固定する
  • 固定の角度で回転操作を行う(×4)
  • 水平反転する場合, しない場合
  • (垂直反転する場合, しない場合)

の全ての場合の加工処理を網羅的に行います

一応, こんな感じでオリジナル画像と全く同じ処理をマスク画像に施してDataAugmetationできます
[補足]回転と反転を組み合わせると重複が生じることがあるので, 反転処理は水平か垂直のどちらかだけ使用することをお勧めします!!

4 使い方

3の自作Datasetクラスを使ってみる


import torch
import torchvision
import HogeDataset

BATCH_SIZE = [バッチサイズ]

# 前処理
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
target_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224), interpolation=0), 
    torchvision.transforms.ToTensor()
])

# 訓練データとテストデータの用意
trainset = HogeDataset.HogeDataset(
    train=True,
    transform=transform, 
    target_transform=target_transform
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True
)

testset = HogeDatasets.HogeDatasets(
    train=False,
    transform=transform,
    target_transform=target_transform
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=True
)

8
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
9