PyTorchでデータの水増し(Data Augmentation)
PyTorchでデータを水増しをする方法をまとめます。PyTorch自体に関しては、以前ブログに入門記事を書いたので、よければ以下参照下さい。
注目のディープラーニングフレームワーク「PyTorch」入門
データ水増しを実施する理由や、具体例などは以下記事参照下さい。
フリー素材で遊びながら理解するディープラーニング精度向上のための画像データ水増し(Data Augmentation)手法
またこの記事は「Google Colaboratory(Google Colab)」で実行することを前提に書かれています。Google Colab自体に関してはこの記事では説明しません。知らない方は、以下記事を参照してみてください。
Google Colaboratoryを使えば環境構築不要・無料でPythonの機械学習ができて最高
この記事で使用したコードは、以下のノートブックにまとめています。
pytorch_data_preprocessing.ipynb
真ん中にある「Open in Colab」というアイコンをクリックすると、Google Colabで開いてそのまま実行することができます。
PyTorchでのデータの扱い
まず最初にPyTorchでのデータの扱いを確認していきます。
教師データのダウンロード
最初に教師データをダウンロードします。説明は省略します。
!git clone https://github.com/karaage0703/janken_dataset datasets
!rm -rf /content/datasets/.git
!rm /content/datasets/LICENSE
ディレクトリは、以下のような構成になっています。choki, gu, pa、それぞれのディレクトリに、チョキ、グー、パーの手の形の写真が入っています。
datasets
├── choki
├── gu
└── pa
以下のようにdataset_root_dir
を定義します。
dataset_root_dir = '/content/datasets'
データセットの作成
最初に、必要なライブラリをインポートしておきます。
import torch
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import PIL
ImageFolderを使って、フォルダの画像を dataset として読み込みます。
dataset = datasets.ImageFolder(root=dataset_root_dir)
データセットの確認
datasetはgetitemで中身を確認できます。(#以下は実行結果です)。
print(dataset.__getitem__(0))
print(dataset.__getitem__(100))
print(dataset.__getitem__(150))
# (<PIL.Image.Image image mode=RGB size=320x240 at 0x7F11DB6DC160>, 0)
# (<PIL.Image.Image image mode=RGB size=320x240 at 0x7F11DB6DCF28>, 1)
# (<PIL.Image.Image image mode=RGB size=320x240 at 0x7F12297D2C50>, 2)
matplotlibで中身を確認する場合は以下です。
image_numb = 6 # 3の倍数を指定してください
for i in range(0, image_numb):
ax = plt.subplot(image_numb / 3, 3, i + 1)
plt.tight_layout()
ax.set_title(str(i))
plt.imshow(dataset[i][0])
torchvision.transforms
PyTorchではtransformsで、Data Augmentation含む様々な画像処理の前処理を行えます。
代表的な、左右反転・上下反転ならtransformsは以下のような形でかきます。
data_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
])
あとは、ImageFolderのtransformの引数に指定すれば、transformsで指定された画像処理をしたデータセットが定義されます。
dataset_augmentated = datasets.ImageFolder(root=dataset_root_dir, transform=data_transform)
データを確認してみます。
image_numb = 6 # 3の倍数を指定してください
for i in range(0, image_numb):
ax = plt.subplot(image_numb / 3, 3, i + 1)
plt.tight_layout()
ax.set_title(str(i))
plt.imshow(dataset_augmentated[i][0])
上下反転されています。
その他のtransformsの関数の実施例は、Google Colabのノートブックを参照下さい。Random Erasing等の手法も標準で実装されています。全てを知りたい場合は、公式ドキュメントを参照下さい。
albumentations 実装
albumentationsというData Augmentation用のライブラリをPyTorchで手軽に使う方法です。
最初に、以下コマンドでalbumentationsをインストールします。
! pip install albumentations
必要なライブラリをインポートします。
import albumentations as albu
import numpy as np
from PIL import Image
transformのときと同様に、ImageFolderでalbumentationでのデータ水増しを行いたいところですが、ちょっとテクニックが必要です。
以下のようにかけば、albumentationsの機能をImageFolderで簡単に使えます。
albu_transforms = albu.Compose([
albu.RandomRotate90(p=0.5),
albu.RandomGamma(gamma_limit=(85, 115), p=0.2),
])
def albumentations_transform(image, transform=albu_transforms):
if transform:
image_np = np.array(image)
augmented = transform(image=image_np)
image = Image.fromarray(augmented['image'])
return image
data_transform = transforms.Compose([
transforms.Lambda(albumentations_transform),
])
dataset_augmentated = datasets.ImageFolder(root=dataset_root_dir, transform=data_transform)
データの中身を確認してみます。
image_numb = 6 # 3の倍数を指定してください
for i in range(0, image_numb):
ax = plt.subplot(image_numb / 3, 3, i + 1)
plt.tight_layout()
ax.set_title(str(i))
plt.imshow(dataset_augmentated[i][0])
albumentationsの画像処理がされていることが分かります。
少し調べた感じではalbumentationsを使うときは、ImageFolderを使わずdatasetを独自実装する場合が多いようですが、ImageFolderで手軽に試したい場合に便利なテクニックです。
Albumentations参考情報
albumentationsにどのような機能があるかは、@Kazuhito さんが、GitHubで公開しているalbumentations-examplesのJupyter Notebookが非常に参考になります。
また @Kazuhito さんのJupyter NotebookをGoogle Colabで動くように改変したノートブックを以下で公開していますので、実際に自分の手で動かしたい人は参考にしてみて下さい。
albumentations_examples.ipynb(Google Colab対応版)
ひたすら試した例としては、以下が参考になりそうです。
Albumentationsのaugmentaitonをひたすら動かす
mixup
性能出ることで話題のデータ水増し手法mixupをPyTorchで使用する際は、以下のGitHubリポジトリが参考になりました。
mixupする方法とmixupした後のデータの確認方法の詳細は、Google Colabノートブック参照下さい。
pytorch_data_preprocessing.ipynb
Kerasの場合は、以下の記事が参考になりそうです。
まとめ
PyTorchでデータ水増し(Data Augmentation)する方法と、データの確認方法をまとめました。もっと便利な機能があったり、スマートな方法があったりしたら是非教えて下さい。
関連記事
TensorFlowのObject Detection APIのData Augmentationで何をやっているか動かして確認
変更履歴
- 2020/07/13 Albumentationsに関して追記