albumentationsについて、自らのメモの意味も込めてブログを書いてみることにしました。data augmentation(データ拡張)については、人によって色々やり方あって、使うライブラリも千差万別だと思います。自分も最近まではtorchvisionとか主流だったのですが、ここに来てalbumentationsの便利さに惹かれ、専らこちらを使っています。
もしもこの記事がお役に立てた時は是非Qiitaのイイねボタンを押して下さい!
前段
Data Augmentation とは?
- 手元にあるオリジナルデータに何らかの処理を施して、データのバリエーションを増やすこと
- 学習時のデータを増やすことで、より汎用的で頑健なモデルを構築することが目的
- また、学習時のみならず推論時にもデータ拡張を実施することでよりバリエーションにとんだ推論結果を出すことが可能(Test Time Augmentation: TTA)
- 上記の結果をアンサンブルすることで、より高い精度が得られることがある
torchvisionについて
- PyTorchを使う際に、セットでよく使われる
- 上記のdata augmentationをする際にもかなり有用
albumentationsについて
- data augmentationでよく使われる機能が豊富に揃っている
- しかもかなり簡単なコードでかける
- Kerasでも使える
例えばalbumentationsのデフォルト機能を使えば、下の写真に天候補正も簡単に行うことができます。
ライブラリのインストールは pip install albumentations
もしくは pip install -U git+https://github.com/albu/albumentations
でいけます。
基本操作
まずはライブラリと今回サンプルで使うデータです。環境はGoogle Colaboratoryを想定しています。
import os, sys
import cv2
import matplotlib.pyplot as plt
%matplotlib inline
from torch.utils.data import Dataset
import albumentations as albu
IMG_DIR = './image/'
MASK_DIR = './mask/'
img = cv2.imread(IMG_DIR + 'image.jpg')
mask = cv2.imread(MASK_DIR + 'mask.png')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
img_origin, mask_origin = img.copy(), mask.copy()
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
axes[0].imshow(img)
axes[1].imshow(mask);
基本的な操作は以下の通りです。下記ではHorizontalFlip(左右反転)を適用しています。
img = albu.HorizontalFlip(p=1)(image=img)['image']
plt.figure(figsize=(8, 5))
plt.imshow(img);
複数の処理をまとめる時は、Composeで下記の様にします。例ではHorizontalFlipにプラスして、VerticalFlip(上下反転)を適用しています。
def get_augmentation():
train_transform = [
albu.HorizontalFlip(p=1),
albu.VerticalFlip(p=1),
]
return albu.Compose(train_transform)
transforms = get_augmentation()
img = img_origin.copy()
augmented = transforms(image=img)
img = augmented['image']
plt.figure(figsize=(8, 5))
plt.imshow(img);
複数の処理からランダムでどれか1つを適用させたい場合、OneOfを使い、以下の様に記述します。
def get_augmentation():
train_transform = [
albu.HorizontalFlip(p=1),
albu.VerticalFlip(p=1),
]
return albu.OneOf(train_transform) # <- Compose
transforms = get_augmentation()
img = img_origin.copy()
augmented = transforms(image=img)
img = augmented['image']
plt.figure(figsize=(8, 5))
plt.imshow(img);
セグメンテーションなどで、maskを教師ラベルとして返す時は以下のイメージです。(マスクのチャンネル数とかテンソル化とかは省略しています)
class QiitaDataset(Dataset):
def __init__(self, img_dir, mask_dir, transforms):
self.img_dir = img_dir
self.mask_dir = mask_dir
self.transforms = transforms
self.img_ids = sorted(os.listdir(self.img_dir))
self.mask_ids = sorted(os.listdir(self.mask_dir))
def __len__(self):
return len(self.img_ids)
def __getitem__(self, idx):
img_name = self.img_ids[idx]
img = cv2.imread(os.path.join(self.img_dir, img_name))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
mask_name = self.mask_ids[idx]
mask = cv2.imread(os.path.join(self.mask_dir, mask_name))
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
augmented = self.transforms(image=img, mask=mask)
img, mask = augmented['image'], augmented['mask']
return img, mask
dataset = QiitaDataset(IMG_DIR, MASK_DIR, transforms)
img, mask = dataset.__getitem__(0)
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
axes[0].imshow(img)
axes[1].imshow(mask);
様々な変換とそのプロット
ここから先は実際に可視化して、各機能を確認していきたいと思います。調べてみると余りに沢山あり過ぎたので、いくつかピックアップしていきたいと思います。より詳しく知りたい方は公式を参照してください。
https://github.com/albu/albumentations
'''
プロットするための関数を用意
'''
def plot_image(img, mask, img_origin, mask_origin):
imgs = [img, mask, img_origin, mask_origin]
fig, axes = plt.subplots(2, 2, figsize=(15, 13))
for i, img in enumerate(imgs):
axes[i//2][i%2].imshow(img)
def plot_augmentation(img, mask, augmentation):
augmented = augmentation(image=img, mask=mask)
img_aug, mask_aug = augmented['image'], augmented['mask']
plot_image(img_aug, mask_aug, img, mask)
img = img_origin.copy()
mask = mask_origin.copy()
それではプロットしていきます。上段に変換後の画像データ。下段にオリジナルデータがプロットされます。
plot_augmentation(img, mask, albu.RandomRotate90(p=1))
plot_augmentation(img, mask, albu.RandomRotate90(p=1))
RandomRotateで画像を回転させています。画像によっては縦横の関係や上下の関係が、大切になってくる場合もあるので、線形変換を使うときは注意が必要です。
plot_augmentation(img, mask, albu.Blur(blur_limit=21, p=1))
お馴染みの平滑化フィルターでぼかしています。その際のカーネルサイズは3 ~ blur_limitで指定した数字の中からランダムでセレクトされます。因みにガウシアンフィルターもデフォルトで用意されています。
plot_augmentation(img, mask, albu.CLAHE(clip_limit=6.0, tile_grid_size=(8, 8), p=1))
上の画像はヒストグラム平滑化をすることで、コントラストが際立っています。左下の木の模様を見れば、その変化が分かりやすいかと思います。
CLAHEは画像全体の画素値をみて平滑化するのではなく、いくつかのタイル(上の例では 8 × 8 )に分割した上で、その中で平滑化を行なっています。こうすることで、近い画素値の範囲が色飛びしてしまうのを防いでいます。OpenCVのチュートリアルの例が非常に分かりやすいかと思います。
plot_augmentation(img, mask, albu.ChannelDropout(channel_drop_range=(1, 1), fill_value=0, p=1))
上の例では、一部のチャンネルを0 (= fill_value) にしています。
plot_augmentation(img, mask, albu.CenterCrop(height=256, width=256, p=1.0))
plot_augmentation(img, mask, albu.RandomCrop(height=256, width=256, p=1.0))
定番のCrop系も豊富に揃っています
plot_augmentation(img, mask, albu.IAASharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=1))
上記では画像を先鋭化しています。オリジナルに比べ、よりシャープになっているのが分かります。先鋭化画像は元画像の平滑化フィルターと元画像の差分を元画像にオーバーレイすることで得られます。
- 元画像(青)と平滑化画像(赤)におけるエッジ部分
- 青 - 赤 の波形
- ①に②を足した後のエッジの波形
plot_augmentation(img, mask, albu.OpticalDistortion(distort_limit=3, shift_limit=0.2, interpolation=1, border_mode=4, p=1))
上記の様に歪曲収差の原理を使って、歪んだ画像にすることも可能です。
plot_augmentation(img, mask, albu.GridDistortion(num_steps=5, distort_limit=0.3, interpolation=1, border_mode=4, p=1))
上の画像は、その脇に関して、左右対称に像が映し出されているのが分かるかと思います。
plot_augmentation(img, mask, albu.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, interpolation=1, border_mode=4, p=1))
上の例だと、分かりにくいですが以下の様な変形をしています
画像出典:https://github.com/albu/albumentations
plot_augmentation(img, mask, albu.Resize(256, 256, interpolation=cv2.INTER_NEAREST, always_apply=False, p=1))
最後はResizeです。おそらくこれが一番使う機会が多いと思います。下記は代表的なinterpolation(補間)方法です。
interpolation | 説明 |
---|---|
cv2.INTER_NEAREST | 最近某補間。計算速度は早いが、滑らかなエッジがギザギザになって現れるジャギーが発生しやすい。 |
cv2.INTER_LINEAR | バイリニア補間(デフォルト)。周り四点の画素の重み付き平均。ニアレストネイバーの様なジャギーが目立たなくなるが、一方でエッジが鈍ってしまう傾向がある(平滑化フィルターのイメージ) |
cv2.INTER_AREA | 平均画素法。OpenCV公式によると縮小に向いている。 |
cv2.INTER_CUBIC | cv2.INTER_CUBIC。OpenCV公式によると拡大に向いている。処理が遅い。 |
cv2.INTER_LANCZOS4 | 8×8 の近傍領域を利用する Lanczos法(ランチョス)の補間。処理が遅い。 |
関係ないけど雑感
- 実際に画像解析に取り組む際には面倒臭がらず、可視化して観察することは大切だと思いました。