LoginSignup
0
0

事前知識

今回はPytorchとAlbumentationを用いて実装します。

  1. Epoch
  2. Mini-Batch
  3. Dataloader
  4. Dataset Class

Data Augmentationとは?

Data Augmentation(データ拡張)とは、モデルの学習に用いるデータを”増やす”手法で、下記のようなケースで便利です。

  1. 十分なデータが無いとき
  2. 学習用データの多様性を高めたいとき

実際には何をしているのか?

データ拡張を行う際、学習に用いるデータのサンプル数を実際に増やしているわけではありません。

何をしているのかというと、DataLoaderを用いてミニバッチごとにデータを取得する都度、ランダムで前処理を適応することで、モデルに送り付けるデータが若干変わるようにしています。

(上)データ拡張無しの場合、(下)データ拡張有りの場合
image.png

各EPOCHで、Dataloaderが取りに行くデータのサンプル自体は毎回変わらないが、モデルに最終的にわたるデータは毎回異なるため、DataAugmentationを用いない場合と比べて、一般化やオーバーフィットを防げるようになります。

注意点

注意点として、入力画像にデータ拡張を適用する場合は、正解用データにも同じ条件で適用する必要があります。

例えば、物体検出のモデルを学習するとします。
入力画像に対してランダムに画像を反転させるように処理を書いた場合、それに対応する物体検出のBounding Box(物体ここだよ!って示す箱のこと)の座標も反転した位置に変更しないとモデルの学習がうまくいきません。

実装

今回は、Albumentationを用いたDataAugmentationのパイプランの定義とDatasetClassの定義の仕方のみ共有しています。実際にDataloaderに載せる方法は、こちらを見ていただけたらわかるかと思います。

Part1 : データ拡張のパイプラインを定義

 #インポート
import albumentations as A
    
#Composeを用いて、どのような手順でデータ拡張を行うのかを定義できます。
transform = A.Compose(
    [
    # どのような処理をここで使用できるかは興味あれば調べてみてください
        A.Resize(256, 256),
        A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.5),
        A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

Part2: データセットクラス内でデータ拡張を適用

データセットクラスのgetitem()関数内でデータ拡張を適用する処理を書きます。
※AlbumentationはopenCVとnumpyの関数を裏で使用しているので、画像の読み込みもopenCVで行います。

from torch.utils.data import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, images_filenames, mask_filenames, transform=None):
        #すべての画像パス
        self.images_filenames = images_filenames
        #すべての正解用画像パス
        self.mask_filenames = mask_filenames
        #Data Augmentationのパイプライン
        self.transform = transform

    def __len__(self):
        return len(self.images_filenames)

    
    def __getitem__(self, idx):
        #画像パスを取得
        image_filename = self.images_filenames[idx]
    
        #画像を開く
        image = cv2.imread(image_filename)
    
        #openCVはBGR形式で画像を読み込むので、RGB形式に変換します。
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
        #正解用画像の読み込み
        mask_filename = self.mask_filenames[idx]
        mask = cv2.imread(mask_filename,cv2.IMREAD_UNCHANGED)
    
        #データ拡張を同じ条件で、入力用及び正解用データに対して適用します。
        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]
    
        #データ拡張済みのデータを返却
        return image, mask


Writed by F.K(20代・入社3年目)

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0