#概要
PyTorchの前処理とデータのロードを担当するtransforms/Dataset/DataLoaderの動作を簡単な例で確認する。
この記事の対象読者
- これからPyTorchを勉強しようとしている人
- PyTorchのtransforms/Dataset/DataLoaderの役割を知りたい人
- オリジナルのtransforms/Dataset/DataLoaderを実装したい人
#前置き
DeepLearningのフレームワークではだいたい以下のような機能をサポートしている。
- データの前処理
- データセットのロード
- モデル構築
- ロスの計算
- オプティマイザによる重みの更新
この中で「データの前処理」と「データセットのロード」は自分達の環境によってカスタマイズすることがよくあるので、フレームワークがどのような機能をサポートしているのかを把握することが実装の効率化に繋がる。
今回はPyTorchの「データの前処理」と「データセットのロード」を実現するためのモジュールtransforms/Dataset/DataLoaderの動きを簡単なデータセットを使って確認してみる。
PyTorch Tutorial
今回は、PyTorchのTutorialのDATA LOADING AND PROCESSING TUTORIALで紹介されている内容を参考にしている。
このTutorialでは実際のデータセット(顔画像と顔の特徴点)を使用してtransforms/Dataset/DataLoaderについて解説している。
Tutorialの内容は分かりやすい構成になっているが、画像データやcsvファイルを扱うので、コード上にこれらのファイル特有の処理が入っており、transforms/Dataset/DataLoaderだけの動きを知りたい場合は、やや冗長な部分がある。
よって今回は数字の1次元のリスト形式のデータセットを使って、transforms/Dataset/DataLoaderを動かしていく。
transforms/Dataset/DataLoaderの役割
-
transforms
- データの前処理を担当するモジュール
-
Dataset
- データとそれに対応するラベルを1組返すモジュール
- データを返すときにtransformsを使って前処理したものを返す。
-
DataLoader
- データセットからデータをバッチサイズに固めて返すモジュール
上記説明にあるとおり
Datasetはtransformsを制御して
DataLoaderはDatasetを制御する
という関係になっている。
なので流れとしては、
1.Datasetクラスをインスタンス化するときに、transformsを引数として渡す。
2.DataLoaderクラスをインスタンス化するときに、Datasetを引数で渡す。
3.トレーニングするときにDataLoaderを使ってデータとラベルをバッチサイズで取得する。
という流れになる。
以下各詳細を、transforms、Dataset、DataLoaderの順に動作を見ていく。
transforms
transformsはデータの前処理を行う。
PyTorchではあらかじめ便利な前処理がいくつか実装されている。
例えば、画像に関する前処理はtorchvision.transformsにまとまっており、CropやFlipなどメジャーな前処理があらかじめ用意されている。
今回は自分で簡単なtransformsを実装することで処理内容の理解を深めていく。
transformsを実装するのに必要な要件
予め用意されているtransformsの動作に習うために**「コール可能なクラス」として実装する必要がある。
(「コール可能」とは__call__**を実装しているクラスのこと)
なぜ「コール可能なクラス」にする必要があるのかというと、Tutorialでは以下のように説明している。
We will write them as callable classes instead of simple functions so that parameters of the transform need not be passed everytime it’s called.
つまり、クラスにしておけば、インスタンス化時に前処理に使うパラメータを全部渡しておけるので、前処理を実行するたびにパラメータを渡す手間が省ける、ということで「コール可能なクラス」を推奨している。
今回はデータとして数字の1次元配列を使うので入力値を二乗するtransformsを実装する。
実装
class Square(object):
def __init__(self):
pass
def __call__(self, sample):
return sample ** 2
実装はこれだけ。
使い方
transform = Square()
print(transform(1)) # -> 1
print(transform(2)) # -> 4
print(transform(3)) # -> 9
print(transform(4)) # -> 16
渡した数値が二乗になっていることが確認できる。
もし画像データ用のtransformsを実装したい場合は、__call__の中に画像処理を実装すればいい。
Dataset
Datasetは、入力データとそれに対応するラベルを1組返すモジュール。
データはtransformsで前処理を行った後に返す。そのためDatasetを作るときは引数でtransformsを渡す必要がある。
PyTorchでは有名なデータセットがあらかじめtorchvision.datasetsに定義されている。(MNIST/CIFAR/STL10など)
自前のデータを扱いたいときは自分のデータをリードして返してくれるDatasetを実装する必要がある。
扱うデータが画像でクラスごとにフォルダ分けされている場合はtorchvision.datasets.ImageFolder
という便利なクラスもある。(KerasのImageDataGenerator
のflow_from_directory()
のような機能)
Datasetを実装するのに必要な要件
オリジナルDatasetを実装するときに守る必要がある要件は以下3つ。
- torch.utils.data.Datasetを継承する。
- **__len__**を実装する。
- **__getitem__**を実装する。
__len__は、len(obj)で実行されたときにコールされる関数。
__getitem__は、obj[i]のようにインデックスで指定されたときにコールされる関数。
今回は、データは数字のリスト、ラベルは偶数の場合だけTrueになるものを出力するDatasetを実装する。
実装
import torch
class MyDataset(torch.utils.data.Dataset):
def __init__(self, data_num, transform=None):
self.transform = transform
self.data_num = data_num
self.data = []
self.label = []
for x in range(self.data_num):
self.data.append(x) # 0 から (data_num-1) までのリスト
self.label.append(x%2 == 0) # 偶数ならTrue 奇数ならFalse
def __len__(self):
return self.data_num
def __getitem__(self, idx):
out_data = self.data[idx]
out_label = self.label[idx]
if self.transform:
out_data = self.transform(out_data)
return out_data, out_label
ポイントは__getitem__でデータを返す前にtransformでデータに前処理をしてから返しているところ。
データセットとして画像やcsvファイルを扱う場合は、__init__や__getitem__の中でファイルをオープンする必要がある。
使い方
data_set = MyDataset(10, transform=None)
print(data_set[0]) # -> (0, True)
print(data_set[1]) # -> (1, False)
print(data_set[2]) # -> (2, True)
print(data_set[3]) # -> (3, False)
print(data_set[4]) # -> (4, True)
# 先ほど実装したtransformsを渡してみる.
# データが二乗されていることに注目.
data_set = MyDataset(10, transform=Square())
print(data_set[0]) # -> (0, True)
print(data_set[1]) # -> (1, False)
print(data_set[2]) # -> (4, True)
print(data_set[3]) # -> (9, False)
print(data_set[4]) # -> (16, True)
指定したインデックスのデータとラベルがセットで取得できている。
次に説明するDataLoaderは、この仕組みを利用してバッチサイズ分のデータを生成する。
DataLoader
データセットからデータをバッチサイズに固めて返すモジュール
DataLoaderはデータセットを使ってバッチサイズ分のデータを生成する。またデータのシャッフル機能も持つ。
データを返すときは、データをtensor型に変換して返す。
tensor型は、計算グラフを保持することができる変数でDeepLearningの勾配計算に不可欠な変数になっている。
DataLoaderは、torch.utils.data.DataLoader
というクラスが既に用意されている。たいていの場合、このクラスで十分対応できるので、今回はこのクラスにこれまで実装してきたDatasetを渡して動作を見てみる。
使い方
import torch
data_set = MyDataset(10, transform=Square())
dataloader = torch.utils.data.DataLoader(data_set, batch_size=2, shuffle=True)
for i in dataloader:
print(i)
# [tensor([ 4, 25]), tensor([1, 0])]
# [tensor([64, 0]), tensor([1, 1])]
# [tensor([36, 16]), tensor([1, 1])]
# [tensor([1, 9]), tensor([0, 0])]
# [tensor([81, 49]), tensor([0, 0])]
指定したバッチサイズでかつデータがシャッフルされていることがわかる。transformsで値が二乗に変換されている。
shuffle=False
にすると順番にデータが出力される。
import torch
data_set = MyDataset(10, transform=Square())
dataloader = torch.utils.data.DataLoader(data_set, batch_size=2, shuffle=False)
for i in dataloader:
print(i)
# [tensor([0, 1]), tensor([1, 0])]
# [tensor([4, 9]), tensor([1, 0])]
# [tensor([16, 25]), tensor([1, 0])]
# [tensor([36, 49]), tensor([1, 0])]
# [tensor([64, 81]), tensor([1, 0])]
学習のときはdataloaderのループをさらにepochのループでかぶせる。
epochs = 4
for epoch in epochs:
for i in dataloader:
# 学習処理
最後に
他にもPyTorchに関する記事を書いたのでPyTorchを勉強し始めの方は参考にしてみてください。
- PyTorchでValidation Datasetを作る方法
- PyTorch 入力画像と教師画像の両方にランダムなデータ拡張を実行する方法
- Kerasを勉強した後にPyTorchを勉強して躓いたこと
また、PyTorchで実装したものもGithubに公開しています。