どうもエンジニアのirohasです。
最近さらにブームが巻き起こっているAI。
そのAI開発において開発手法として用いられている機械学習やディープラーニングにおいて、DataAugumentation(データ拡張)というのはすごく重要になります。
そこで今回は、前回実装したkerasのDataAugumentationとは別にPytorchでのDataAugumentationを実装してみたので、備忘録も兼ねて記事にしようと思います。
目次
1.はじめに
2.環境
3.Data Augumentationって何?
4.事前準備
5.実装
6.機能説明
7.まとめ
8.参考文献
1. はじめに
みなさんはPytorchというものをご存じでしょうか?
PyTorchは、コンピュータビジョンや自然言語処理で利用されているTorchを元に作られた、Pythonのオープンソースの機械学習ライブラリです。
最初はFacebookの人工知能研究グループAI Research lab(FAIR)により開発され、フリーでオープンソースのソフトウェアとしてAI開発において現在も広く使用されています。
PyTorchの機能:
・強力なGPUサポートを備えた(NumPyのような)テンソル演算
・自動微分エンジンの上に構築された深層学習トレーニング(torch.autograd)
・自動微分化や自動ベクトル化などの(Google JAX(英語版)のような)関数変換(functorch)
・実行時コンパイル (torch.jit)
・機械学習モデル形式「ONNX」へのエクスポート(torch.onnx)
今回はこのPytorchでのDataAugumentationを実装していきます。
2. 環境
PC: MacOS
言語: Python v3.10
ライブラリ:
[標準ライブラリ]
glob
[外部ライブラリ]
Pillow 9.1.0
torch 1.11.0
3. DataAugumentationって何?
前回のkerasでのData Augmentationの記事で説明しましたが、ここにも記載しておきます。
Data Augmentation(データ拡張)とは、学習用の画像データに対して「変換」を施すことでデータを水増しする手法です。
機械学習、特にディープラーニングによる画像認識では大量の画像データを必要とする場合が多いです。
時にその量は数万枚や数十万枚(またはそれ以上)に登る場合もあるので、人間の手でデータを集めたり、スクレイピングで集めるのも至難の業となってきます。
そこで活躍するのが上記のData Augmentation(データ拡張)という技術です。
Data Augmentation(データ拡張)では、一枚の画像を回転させたり、反転させたり、明るさを調整したりしてデータを増やすことができます。
4. 事前準備
PytorchでData Augumentationをするために外部ライブラリのインストールが必要となります。
pip3 install pillow torch
上記コマンドを叩けばインストールできます。
もしくはrequirements.txtを作成してインストールするのもありかもしれません。
5. 実装
from PIL import Image
from torchvision import transforms
import glob
def data_augumentation(name):
path = glob.glob('./datasets/train/{}/*jpg'.format(name))
# Augumentation
transform = transforms.Compose([
transforms.Grayscale(),
transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=[-0.5, 0.5]),
transforms.RandomInvert(p=1.0)
])
num = 0
i = 1
for num in range(100):
for p in path:
# 画像の読み込み
img = Image.open(p)
# オーグメンテーションの実行
img = transform(img)
# 編集した画像を保存
img.save('./datasets/train/{}/rekekoiAug_'.format(name)+ str(i) + '.jpg')
i += 1
num += 1
if num >= 100:
break
else:
print(str(num) + 'epoch')
return img
if __name__ == '__main__':
himuro_dataset = data_augumentation('himuro')
ibarada_dataset = data_augumentation('ibarada')
kanade_dataset = data_augumentation('kanade')
kosuke_dataset = data_augumentation('kosuke')
yukimura_dataset = data_augumentation('yukimura')
6. 機能説明
上記コードの説明をしていきます。
transform = transforms.Compose([
transforms.Grayscale(),
transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=[-0.5, 0.5]),
transforms.RandomInvert(p=1.0)
])
ここでは、データをどのように拡張するかを指定しています。
Transform はデータに対して行う前処理を行うオブジェクトです。
torchvision では、画像のリサイズや切り抜きといった処理を行うための Transform が用意されています。
今回はその中でも、画像を白黒にするグレースケール変換、明度・彩度・色相・コントラストを変えるColorJitter、色調の反転をするRandomInvertを行なっています。
RandomInvertのpは確率を表しており、ここの値をいじることで指定の確率で変更を行なってくれます。
他にもいろいろ機能はありますが、多すぎるので下記に公式のリンクを貼っておきますので、ぜひご覧いただき、必要に応じて追加してください。
Pytorch Data Augumentation: https://pytorch.org/vision/stable/transforms.html
7. まとめ
このようにPytorchでも比較的簡単にData Augumentationは実装できます。
kerasよりも少し複雑には見えますが、やってることは何も変わらないので、Pytorch独特の書き方にさえなれれば誰でも簡単に実装ができると思います。
ぜひData Augumentationをマスターして素敵なAI開発ライフを送ってください!
8. 参考文献
Pytorch: https://pytorch.org/
Pytorch Data Augumentation: https://pytorch.org/vision/stable/transforms.html