LoginSignup
359
259

More than 3 years have passed since last update.

PyTorch transforms/Dataset/DataLoaderの基本動作を確認する

Last updated at Posted at 2019-01-08

概要

PyTorchの前処理とデータのロードを担当するtransforms/Dataset/DataLoaderの動作を簡単な例で確認する。

この記事の対象読者

  • これからPyTorchを勉強しようとしている人
  • PyTorchのtransforms/Dataset/DataLoaderの役割を知りたい人
  • オリジナルのtransforms/Dataset/DataLoaderを実装したい人

前置き

DeepLearningのフレームワークではだいたい以下のような機能をサポートしている。

  • データの前処理
  • データセットのロード
  • モデル構築
  • ロスの計算
  • オプティマイザによる重みの更新

この中で「データの前処理」と「データセットのロード」は自分達の環境によってカスタマイズすることがよくあるので、フレームワークがどのような機能をサポートしているのかを把握することが実装の効率化に繋がる。

今回はPyTorchの「データの前処理」と「データセットのロード」を実現するためのモジュールtransforms/Dataset/DataLoaderの動きを簡単なデータセットを使って確認してみる。

PyTorch Tutorial

今回は、PyTorchのTutorialのDATA LOADING AND PROCESSING TUTORIALで紹介されている内容を参考にしている。

このTutorialでは実際のデータセット(顔画像と顔の特徴点)を使用してtransforms/Dataset/DataLoaderについて解説している。

Tutorialの内容は分かりやすい構成になっているが、画像データやcsvファイルを扱うので、コード上にこれらのファイル特有の処理が入っており、transforms/Dataset/DataLoaderだけの動きを知りたい場合は、やや冗長な部分がある。

よって今回は数字の1次元のリスト形式のデータセットを使って、transforms/Dataset/DataLoaderを動かしていく。

transforms/Dataset/DataLoaderの役割

  • transforms
    • データの前処理を担当するモジュール
  • Dataset
    • データとそれに対応するラベルを1組返すモジュール
    • データを返すときにtransformsを使って前処理したものを返す。
  • DataLoader
    • データセットからデータをバッチサイズに固めて返すモジュール

上記説明にあるとおり
Datasetはtransformsを制御して
DataLoaderはDatasetを制御する
という関係になっている。

なので流れとしては、
1.Datasetクラスをインスタンス化するときに、transformsを引数として渡す。
2.DataLoaderクラスをインスタンス化するときに、Datasetを引数で渡す。
3.トレーニングするときにDataLoaderを使ってデータとラベルをバッチサイズで取得する。
という流れになる。

以下各詳細を、transforms、Dataset、DataLoaderの順に動作を見ていく。

transforms

transformsはデータの前処理を行う。
PyTorchではあらかじめ便利な前処理がいくつか実装されている。
例えば、画像に関する前処理はtorchvision.transformsにまとまっており、CropやFlipなどメジャーな前処理があらかじめ用意されている。

今回は自分で簡単なtransformsを実装することで処理内容の理解を深めていく。

transformsを実装するのに必要な要件

予め用意されているtransformsの動作に習うために「コール可能なクラス」として実装する必要がある。
(「コール可能」とは__call__を実装しているクラスのこと)
なぜ「コール可能なクラス」にする必要があるのかというと、Tutorialでは以下のように説明している。

We will write them as callable classes instead of simple functions so that parameters of the transform need not be passed everytime it’s called.

つまり、クラスにしておけば、インスタンス化時に前処理に使うパラメータを全部渡しておけるので、前処理を実行するたびにパラメータを渡す手間が省ける、ということで「コール可能なクラス」を推奨している。

今回はデータとして数字の1次元配列を使うので入力値を二乗するtransformsを実装する。

実装
class Square(object):
    def __init__(self):
        pass

    def __call__(self, sample):
        return sample ** 2

実装はこれだけ。

使い方
transform = Square()
print(transform(1)) # -> 1
print(transform(2)) # -> 4
print(transform(3)) # -> 9
print(transform(4)) # -> 16

渡した数値が二乗になっていることが確認できる。
もし画像データ用のtransformsを実装したい場合は、__call__の中に画像処理を実装すればいい。

Dataset

Datasetは、入力データとそれに対応するラベルを1組返すモジュール。
データはtransformsで前処理を行った後に返す。そのためDatasetを作るときは引数でtransformsを渡す必要がある。

PyTorchでは有名なデータセットがあらかじめtorchvision.datasetsに定義されている。(MNIST/CIFAR/STL10など)

自前のデータを扱いたいときは自分のデータをリードして返してくれるDatasetを実装する必要がある。

扱うデータが画像でクラスごとにフォルダ分けされている場合はtorchvision.datasets.ImageFolderという便利なクラスもある。(KerasのImageDataGeneratorflow_from_directory()のような機能)

Datasetを実装するのに必要な要件

オリジナルDatasetを実装するときに守る必要がある要件は以下3つ。

  • torch.utils.data.Datasetを継承する。
  • __len__を実装する。
  • __getitem__を実装する。

__len__は、len(obj)で実行されたときにコールされる関数。
__getitem__は、obj[i]のようにインデックスで指定されたときにコールされる関数。

今回は、データは数字のリスト、ラベルは偶数の場合だけTrueになるものを出力するDatasetを実装する。

実装
import torch
class MyDataset(torch.utils.data.Dataset):

    def __init__(self, data_num, transform=None):
        self.transform = transform
        self.data_num = data_num
        self.data = []
        self.label = []
        for x in range(self.data_num):
            self.data.append(x) # 0 から (data_num-1) までのリスト
            self.label.append(x%2 == 0) # 偶数ならTrue 奇数ならFalse

    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        out_data = self.data[idx]
        out_label =  self.label[idx]

        if self.transform:
            out_data = self.transform(out_data)

        return out_data, out_label

ポイントは__getitem__でデータを返す前にtransformでデータに前処理をしてから返しているところ。
データセットとして画像やcsvファイルを扱う場合は、__init__や__getitem__の中でファイルをオープンする必要がある。

使い方
data_set = MyDataset(10, transform=None)
print(data_set[0]) # -> (0, True)
print(data_set[1]) # -> (1, False)
print(data_set[2]) # -> (2, True)
print(data_set[3]) # -> (3, False)
print(data_set[4]) # -> (4, True)

# 先ほど実装したtransformsを渡してみる.
# データが二乗されていることに注目.
data_set = MyDataset(10, transform=Square())
print(data_set[0]) # -> (0, True)
print(data_set[1]) # -> (1, False)
print(data_set[2]) # -> (4, True)
print(data_set[3]) # -> (9, False)
print(data_set[4]) # -> (16, True)

指定したインデックスのデータとラベルがセットで取得できている。
次に説明するDataLoaderは、この仕組みを利用してバッチサイズ分のデータを生成する。

DataLoader

データセットからデータをバッチサイズに固めて返すモジュール
DataLoaderはデータセットを使ってバッチサイズ分のデータを生成する。またデータのシャッフル機能も持つ。
データを返すときは、データをtensor型に変換して返す。
tensor型は、計算グラフを保持することができる変数でDeepLearningの勾配計算に不可欠な変数になっている。

DataLoaderは、torch.utils.data.DataLoaderというクラスが既に用意されている。たいていの場合、このクラスで十分対応できるので、今回はこのクラスにこれまで実装してきたDatasetを渡して動作を見てみる。

使い方
import torch
data_set = MyDataset(10, transform=Square())
dataloader = torch.utils.data.DataLoader(data_set, batch_size=2, shuffle=True)

for i in dataloader:
    print(i)

# [tensor([ 4, 25]), tensor([1, 0])]
# [tensor([64,  0]), tensor([1, 1])]
# [tensor([36, 16]), tensor([1, 1])]
# [tensor([1, 9]), tensor([0, 0])]
# [tensor([81, 49]), tensor([0, 0])]

指定したバッチサイズでかつデータがシャッフルされていることがわかる。transformsで値が二乗に変換されている。
shuffle=Falseにすると順番にデータが出力される。

import torch
data_set = MyDataset(10, transform=Square())
dataloader = torch.utils.data.DataLoader(data_set, batch_size=2, shuffle=False)

for i in dataloader:
    print(i)

# [tensor([0, 1]), tensor([1, 0])]
# [tensor([4, 9]), tensor([1, 0])]
# [tensor([16, 25]), tensor([1, 0])]
# [tensor([36, 49]), tensor([1, 0])]
# [tensor([64, 81]), tensor([1, 0])]

学習のときはdataloaderのループをさらにepochのループでかぶせる。

epochs = 4
for epoch in epochs:
    for i in dataloader:
        # 学習処理

最後に

他にもPyTorchに関する記事を書いたのでPyTorchを勉強し始めの方は参考にしてみてください。

また、PyTorchで実装したものもGithubに公開しています。

359
259
11

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
359
259