LoginSignup
131

More than 1 year has passed since last update.

画像データ拡張ライブラリ ~ albumentations ~

Last updated at Posted at 2019-11-19

 albumentationsについて、自らのメモの意味も込めてブログを書いてみることにしました。data augmentation(データ拡張)については、人によって色々やり方あって、使うライブラリも千差万別だと思います。自分も最近まではtorchvisionとか主流だったのですが、ここに来てalbumentationsの便利さに惹かれ、専らこちらを使っています。

 もしもこの記事がお役に立てた時は是非Qiitaのイイねボタンを押して下さい!

前段

Data Augmentation とは?

  • 手元にあるオリジナルデータに何らかの処理を施して、データのバリエーションを増やすこと
  • 学習時のデータを増やすことで、より汎用的で頑健なモデルを構築することが目的
  • また、学習時のみならず推論時にもデータ拡張を実施することでよりバリエーションにとんだ推論結果を出すことが可能(Test Time Augmentation: TTA)
  • 上記の結果をアンサンブルすることで、より高い精度が得られることがある

torchvisionについて

  • PyTorchを使う際に、セットでよく使われる
  • 上記のdata augmentationをする際にもかなり有用

albumentationsについて

  • data augmentationでよく使われる機能が豊富に揃っている
  • しかもかなり簡単なコードでかける
  • Kerasでも使える

例えばalbumentationsのデフォルト機能を使えば、下の写真に天候補正も簡単に行うことができます。

【オリジナル】
image.jpg

【雪】
sonw.jpg

【雨】
rain.jpg

【太陽光】
sun.jpg

ライブラリのインストールは pip install albumentations もしくは pip install -U git+https://github.com/albu/albumentations でいけます。

基本操作

まずはライブラリと今回サンプルで使うデータです。環境はGoogle Colaboratoryを想定しています。


import os, sys
import cv2
import matplotlib.pyplot as plt
%matplotlib inline

from torch.utils.data import Dataset
import albumentations as albu


IMG_DIR = './image/'
MASK_DIR = './mask/'

img = cv2.imread(IMG_DIR + 'image.jpg')
mask = cv2.imread(MASK_DIR + 'mask.png')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)

img_origin, mask_origin = img.copy(), mask.copy()

fig, axes = plt.subplots(1, 2, figsize=(15, 6))
axes[0].imshow(img)
axes[1].imshow(mask);

download.png

基本的な操作は以下の通りです。下記ではHorizontalFlip(左右反転)を適用しています。

img = albu.HorizontalFlip(p=1)(image=img)['image']

plt.figure(figsize=(8, 5))
plt.imshow(img);

download-1.png

 複数の処理をまとめる時は、Composeで下記の様にします。例ではHorizontalFlipにプラスして、VerticalFlip(上下反転)を適用しています。

def get_augmentation():
    train_transform = [
        albu.HorizontalFlip(p=1),
        albu.VerticalFlip(p=1),
    ]
    return albu.Compose(train_transform)

transforms = get_augmentation()

img = img_origin.copy()
augmented = transforms(image=img)

img = augmented['image']
plt.figure(figsize=(8, 5))
plt.imshow(img);

download.png

 複数の処理からランダムでどれか1つを適用させたい場合、OneOfを使い、以下の様に記述します。

def get_augmentation():
    train_transform = [
        albu.HorizontalFlip(p=1),
        albu.VerticalFlip(p=1),
    ]
    return albu.OneOf(train_transform)  # <- Compose

transforms = get_augmentation()

img = img_origin.copy()
augmented = transforms(image=img)

img = augmented['image']
plt.figure(figsize=(8, 5))
plt.imshow(img);

download.png

 セグメンテーションなどで、maskを教師ラベルとして返す時は以下のイメージです。(マスクのチャンネル数とかテンソル化とかは省略しています)

class QiitaDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transforms):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transforms = transforms
        self.img_ids = sorted(os.listdir(self.img_dir))
        self.mask_ids = sorted(os.listdir(self.mask_dir))

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_name = self.img_ids[idx]
        img = cv2.imread(os.path.join(self.img_dir, img_name))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask_name = self.mask_ids[idx] 
        mask = cv2.imread(os.path.join(self.mask_dir, mask_name))
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)

        augmented = self.transforms(image=img, mask=mask)
        img, mask = augmented['image'], augmented['mask']
        return img, mask    


dataset = QiitaDataset(IMG_DIR, MASK_DIR, transforms)

img, mask = dataset.__getitem__(0)

fig, axes = plt.subplots(1, 2, figsize=(15, 6))
axes[0].imshow(img)
axes[1].imshow(mask);

download.png

様々な変換とそのプロット

 ここから先は実際に可視化して、各機能を確認していきたいと思います。調べてみると余りに沢山あり過ぎたので、いくつかピックアップしていきたいと思います。より詳しく知りたい方は公式を参照してください。
https://github.com/albu/albumentations

'''
プロットするための関数を用意
'''

def plot_image(img, mask, img_origin, mask_origin):
    imgs = [img, mask, img_origin, mask_origin]
    fig, axes = plt.subplots(2, 2, figsize=(15, 13))
    for i, img in enumerate(imgs):
        axes[i//2][i%2].imshow(img)

def plot_augmentation(img, mask, augmentation):
    augmented = augmentation(image=img, mask=mask)
    img_aug, mask_aug = augmented['image'], augmented['mask']
    plot_image(img_aug, mask_aug, img, mask)

img = img_origin.copy()
mask = mask_origin.copy()

 それではプロットしていきます。上段に変換後の画像データ。下段にオリジナルデータがプロットされます。

plot_augmentation(img, mask, albu.RandomRotate90(p=1))

download.png

plot_augmentation(img, mask, albu.RandomRotate90(p=1))

download.png

 RandomRotateで画像を回転させています。画像によっては縦横の関係や上下の関係が、大切になってくる場合もあるので、線形変換を使うときは注意が必要です。

plot_augmentation(img, mask, albu.Blur(blur_limit=21, p=1))

download.png

 お馴染みの平滑化フィルターでぼかしています。その際のカーネルサイズは3 ~ blur_limitで指定した数字の中からランダムでセレクトされます。因みにガウシアンフィルターもデフォルトで用意されています。

plot_augmentation(img, mask, albu.CLAHE(clip_limit=6.0, tile_grid_size=(8, 8), p=1))

download.png

 上の画像はヒストグラム平滑化をすることで、コントラストが際立っています。左下の木の模様を見れば、その変化が分かりやすいかと思います。
 CLAHEは画像全体の画素値をみて平滑化するのではなく、いくつかのタイル(上の例では 8 × 8 )に分割した上で、その中で平滑化を行なっています。こうすることで、近い画素値の範囲が色飛びしてしまうのを防いでいます。OpenCVのチュートリアルの例が非常に分かりやすいかと思います。

ヒストグラム その2: ヒストグラム平坦化

plot_augmentation(img, mask, albu.ChannelDropout(channel_drop_range=(1, 1), fill_value=0, p=1))

download.png

 上の例では、一部のチャンネルを0 (= fill_value) にしています。

plot_augmentation(img, mask, albu.CenterCrop(height=256, width=256, p=1.0))

download.png

plot_augmentation(img, mask, albu.RandomCrop(height=256, width=256, p=1.0))

download.png

 定番のCrop系も豊富に揃っています

plot_augmentation(img, mask, albu.IAASharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=1))

download.png

 上記では画像を先鋭化しています。オリジナルに比べ、よりシャープになっているのが分かります。先鋭化画像は元画像の平滑化フィルターと元画像の差分を元画像にオーバーレイすることで得られます。

download.png

  1. 元画像(青)と平滑化画像(赤)におけるエッジ部分
  2. 青 - 赤 の波形
  3. ①に②を足した後のエッジの波形
plot_augmentation(img, mask, albu.OpticalDistortion(distort_limit=3, shift_limit=0.2, interpolation=1, border_mode=4, p=1))

download.png

 上記の様に歪曲収差の原理を使って、歪んだ画像にすることも可能です。

plot_augmentation(img, mask, albu.GridDistortion(num_steps=5, distort_limit=0.3, interpolation=1, border_mode=4, p=1))

download.png

 上の画像は、その脇に関して、左右対称に像が映し出されているのが分かるかと思います。

plot_augmentation(img, mask, albu.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, interpolation=1, border_mode=4, p=1))

download.png

 上の例だと、分かりにくいですが以下の様な変形をしています

画像出典:https://github.com/albu/albumentations

download.jpg

plot_augmentation(img, mask, albu.Resize(256, 256, interpolation=cv2.INTER_NEAREST, always_apply=False, p=1))

download.png

最後はResizeです。おそらくこれが一番使う機会が多いと思います。下記は代表的なinterpolation(補間)方法です。

interpolation 説明
cv2.INTER_NEAREST 最近某補間。計算速度は早いが、滑らかなエッジがギザギザになって現れるジャギーが発生しやすい。
cv2.INTER_LINEAR バイリニア補間(デフォルト)。周り四点の画素の重み付き平均。ニアレストネイバーの様なジャギーが目立たなくなるが、一方でエッジが鈍ってしまう傾向がある(平滑化フィルターのイメージ)
cv2.INTER_AREA 平均画素法。OpenCV公式によると縮小に向いている。
cv2.INTER_CUBIC cv2.INTER_CUBIC。OpenCV公式によると拡大に向いている。処理が遅い。
cv2.INTER_LANCZOS4 8×8 の近傍領域を利用する Lanczos法(ランチョス)の補間。処理が遅い。

関係ないけど雑感

  • 実際に画像解析に取り組む際には面倒臭がらず、ちょいちょい可視化して見るのは大切だと思いました。

追記(2021/09/27)

 プロ野球データの可視化サイト を作りました。まだまだクオリティは低いですが、今後少しずつバージョンアップさせていく予定です。野球好きの方は是非遊びに来てください⚾️

image.png

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
131