##背景
Data Augmentationした後の画像を表示したい!
と思って実装してみました。
Data Augmentationとは、1枚の画像を水増しする技術であり、以下のような操作を加えます。
- Random Crop(画像をランダムに切り取る)
- Random Horizontal Flip(画像を一定の確率で左右反転する)
- Random Erasing(画像の一部にランダムにノイズを付加する)
- Random Affine(画像をランダムに拡大・縮小・回転する)
この他にもいろいろあります。
##実装
今回は、CIFAR-10の訓練画像データセットを読み込んで、transformsにRandomHorizontalFlipとRandomErasingを組み込んでみました。
import torch
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset,DataLoader
import torchvision.datasets as dsets
import matplotlib.pyplot as plt
#画像の読み込み
batch_size = 100
train_data = dsets.CIFAR10(root='./tmp/cifar-10', train=True, download=False, transform=transforms.Compose([transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor(), transforms.RandomErasing(p=0.5, scale=(0.02, 0.4), ratio=(0.33, 3.0))]))
train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True)
test_data = dsets.CIFAR10(root='./tmp/cifar-10', train=False, download=False, transform=transforms.Compose([transforms.ToTensor(),]))
test_loader = DataLoader(test_data,batch_size=batch_size,shuffle=False)
def image_show(data_loader,n):
#Augmentationした画像データを読み込む
tmp = iter(data_loader)
images,labels = tmp.next()
#画像をtensorからnumpyに変換
images = images.numpy()
#n枚の画像を1枚ずつ取り出し、表示する
for i in range(n):
image = np.transpose(images[i],[1,2,0])
plt.imshow(image)
plt.show()
image_show(train_loader,10)
image_show関数がAugmentation後の画像を表示する関数です。
iter()により、DataLoaderからミニバッチ1つ分を取得します。
そして、.next()により画像データをimagesに、ラベルをlabelsに格納します。
images = images.numpy()では、画像データをテンソルからnumpyに変換しています。
この時点でimagesは**[バッチサイズ, チャンネル数, 幅, 高さ]という構造になっていますが、matplotlibのpyplotで画像を表示するには[幅, 高さ, チャンネル数]**とする必要があります。
よって、np.transposeをつかって変形しています。
##実行結果例
左右反転されていたりRandom Erasingでノイズが付加されていたりすることが確認できました。