Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
149
Help us understand the problem. What is going on with this article?
@mathlive

pyTorchのtransforms,Datasets,Dataloaderの説明と自作Datasetの作成と使用

2019/9/29 投稿
2019/11/8 やや見やすく編集(主観)

0. この記事の対象者

  • pythonを触ったことがあり,実行環境が整っている人
  • pyTorchをある程度触ったことがある人
  • pyTorchとtorchvisionのtransforms,Datasets,dataloaderを深く理解したい人
  • 既存のDatasetから自作のDatasetを作成したい人

1. はじめに

昨今では機械学習に対してpython言語による研究が主である.なぜならpythonにはデータ分析や計算を高速で行うためのライブラリ(moduleと呼ばれる)がたくさん存在するからだ.
その中でも今回はpyTorchと呼ばれるmoduleを使用し,そこで提供されるDatasetから,自作Datasetの作成,その使用までを行う.

この記事では自作Datasetの作成とその使用のみを行い,実際にNetworkを使って学習をしたりはしない.
そちらに興味がある場合は以下のLinkを参照してほしい.

pyTorchでCNNsを徹底解説

2. 事前知識

pythonには他言語同様,「」というものが定義した変数には割り当てられており,中でも「list型」,「tuple型」,「dictionary型」がよく出てくるように思う.さらに,moduleとして「numpy」というものもあり,このnumpyが持つ特殊な型,「ndarray型」もよく出てくる.

そして,pyTorchにはTensor型という特殊な型が用意されており,機械学習ではこのdataを多用する.
Tensor型についての説明は以下Linkを見てほしい.

pyTorchのTensor型とは

ここではあえて説明はしないが,わからない人は是非検索をしてそれぞれをしっかり理解しておいてほしい.

3. pyTorchのインストール

pyTorchを初めて使用する場合,pythonにはpyTorchがまだインストールされていないためcmdでのインストールをしなければならない.
下記のLinkに飛び,ページの下の方にある「QUICK START LOCALLY」で自身の環境のものを選択し,現れたコマンドをcmd等で入力する(コマンドをコピペして実行で良いはず).

pytorch 公式サイト

さらに,今回は「torchvision」というmoduleも使用するためこちらもインストールしておいてほしい.
コマンドは以下に示す(condaを使用している場合).

conda install torchvision

4. pyTorchのimport

ここからはcmd等ではなくpythonファイルに書き込んでいく.
下記のコードを書くことでmoduleの使用をする.

filename.py
import torch
import torchvision

ついでにnumpyもimportしておく.

filename.py
import numpy as np

5. Datasetの使い方とDatasetの自作

今回はtorchvisionに用意されているCIFAR10というDatasetを用いて,dataの部分はgray scaleに,labelを通常のCIFAR10のcolor scaleにする.
こういったDatasetはAuto EncoderやUNetのexerciseでよく使用するもので,とても重要なものである.
 
まずは以下にpyTorchがどうやってDatasetを扱うかを詳しく説明し,その後自作Datasetを作成する.

5-1. pyTorchの通常のDataset使用

torchvisionには主要なDatasetがすでに用意されており,たった数行のコードでDatasetのダウンロードから前処理までを可能とする.

結論から言うと3行のコードでDatasetの運用が可能となり,ステップごとに言えば,

  1. transformsによる前処理の定義
  2. Datasetsによる前処理&ダウンロード
  3. DataloaderによるDatasetの使用

という流れになる.

以下にそれぞれを説明する.

5-1-1. transformsによる前処理の定義

以下にtransformsの例を示す.

filename.py
trans = torchvision.transforms.ToTensor()

これは画像であるPIL image または ndarrayのdata「Height×Width×Channel」をTensor型のdata「Channel×Height×Width」に変換するというもので,transという変数がその機能を持つことを意味する.
なぜChannelの順が入れ替わっているかというと,機械学習をしていく上でChannelが最初のほうが都合が良いからだと思ってもらって良い.
さらに実は画像の各輝度値の範囲を自動で [0.0,1.0]にしてくれている.

この「torchvision.transforms.ToTensor()」はclassでtransはクラスインスタンスのようなものだ.

使い方は以下のようにすればよい.

filename.py
Tensor型data = trans(PILまたはndarray)

このようにtransformsは「trans(data)」のように使えるということが重要である.
これは「trans()」がその機能を持つclass 「torchvision.transforms.ToTensor()」の何かを呼び出しているのだ.

ここで例えばTensor変換だけでなく正規化を同時にしたい場合は以下のようにする.

filename.py
trans = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5,), (0.5,))])

torchvision.transforms.Composeは引数で渡されたlist型の[~~,~~,...]というのを先頭から順に実行していくものである.そのためlist内の前処理の順番には十分注意する.
こうすることでtransという変数はTensor変換と正規化を一気にしてくれるハイブリッドな変数になった.
その他のtransformsは以下のLinkより確認してほしい.

torchvision.transforms 公式サイト

より理解を深めるために,自身で簡単にtransformsを自作してみる.
以下サンプル.

filename.py
class Plus2(object):
    def __init__(self):
        pass
    def __call__(self, x):
        data = x + 2
        return data

これは入力xに対して2を足すだけのtransformsで使用方法は以下のようにすれば良い.

filename.py
trans = Plus2()
x = 9
data = trans(x)
print(x)
print(data)

------'''以下出力結果'''--------
9
11

このように全く同じように動作しているのがわかった.
つまり,transformsは「__call__()」という関数が重要で,この中に書いた処理が実行されているのだ.

5-1-2. Datasetsによる前処理&ダウンロード

以下にダウンロードを示す.

filename.py
trainset = torchvision.datasets.MNIST(root = 'path', train = True, download = True, transform = trans)

まずは引数の説明をしていく.

  • root」はDatasetを参照(または保存)するディレクトリを「path」の部分に指定する.そのディレクトリに取得したいDatasetが存在すればダウンロードせずにそれを使用する.

  • train」はTraining用のdataを取得するかどうかを選択する.FalseにすればTest用のdataを取得するが,この2つの違いはdata数の違いと思ってくれて良い.

  • download」は参照したディレクトリにDatasetがない場合ダウンロードするかどうかを決めることができる.

  • transform」は定義した前処理を渡す.こうすることでDataset内のdataを「参照する際」にその前処理を自動で行ってくれる.

今回はMNISTを使用したが,他の使用できるDatasetは下記のLinkより参照して使用して欲しい.その時のコードも大体同じである.

torchvision.datasets 公式サイト

取得したtrainsetをそのまま出力してみると以下のようなDatasetの内容が表示されるはずだ.

filename.py
print(trainset)

------'''以下出力結果'''--------
Dataset MNIST
    Number of datapoints: 60000
    Root location: rootで指定したpathが出るはず
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5,), std=(0.5,))
           )

これだけ見るとDatasetなのにどうやってdataを見ているの?となるが,dataの参照は以下のようにすれば良い.

filename.py
print(trainset[0])

------'''以下出力結果'''--------
(tensor([data内容]), そのdataに対応する正解label)

これでDatasetの0番目のdataを参照できる.
実は実際にDatasetのdataを使用するときも配列の参照のようにdataを参照している.

ちなみにdataの前処理であるTransformによるエラーは,この参照をして初めて出力される.なぜなら先でも言ったがTransformはデータが参照される時に使用されるからだ.
実際のコードではデータ参照はかなり見えないところで行われるためエラーが出てもどこのエラーなのか読み取りづらい.
Transformの処理に不安があるならば必ず事前に参照出力させておくとそこでエラー確認ができるため安全だと思う.

より理解を深めるために,自身で簡単なdatasetsを自作してみる.
以下サンプル.

filename.py
class Mydatasets(torch.utils.data.Dataset):
    def __init__(self, transform = None):
        self.transform = transform

        self.data = [1, 2, 3, 4, 5, 6]
        self.label = [0, 1, 0, 1, 0, 1]

        self.datanum = 6

    def __len__(self):
        return self.datanum

    def __getitem__(self, idx):
        out_data = self.data[idx]
        out_label = self.label[idx]

        if self.transform:
            out_data = self.transform(out_data)

        return out_data, out_label

self.data = [1, 2, 3, 4, 5, 6]」と「self.label = [0, 1, 0, 1, 0, 1]」より,このDatasetは1から6の数字のdataと,奇数が0偶数が1となるlabelを持つということがわかる.
使用は以下のようにする.

filename.py
t = Plus2()
dataset = Mydatasets(t)
print(len(dataset))
print(dataset[5])
print(dataset)

------'''以下出力結果'''--------
6
(8, 1)
<__main__.Mydatasets object at 0x7f3c7e0eb6d8>

出力からわかるように,classの「def __len__()」は「len(dataset)」とすると実行され,dataの長さを返す関数である.
def __getitem__()」は「dataset[5]」とするとその番号のdataとlabelを返す関数となっている.ただしこのときにtransformであるPlus2()が実行していることに注意する.
このようにdatasetsは「def __len__()」と「def __getitem__()」が必須なのである.

5-1-3. DataloaderによるDatasetの使用

DataloaderによるDatasetの使用は下記のコードで実行する.

filename.py
trainloader = torch.utils.data.DataLoader(trainset, batch_size = 100, shuffle = True, num_workers = 2)

まずは引数の説明をしていく.

  • 第1引数は先程取得したDatasetを入れる.

  • batch_size」は1回のtrainingまたはtest時に一気に何個のdataを使用するかを選択.datasetの全data数を割り切れる値にしなければならない.

  • shuffle」はdataの参照の仕方をランダムにする.

  • num_workers」は複数処理をするかどうかで,2以上の場合その値だけ並行処理をする.

取得したtrainloaderを出力しても以下のようなオブジェクトタイプしか表示されない.

filename.py
print(trainloader)

------'''以下出力結果'''--------
<torch.utils.data.dataloader.DataLoader object at 0x7fdffa11ada0>

Datasetのように中身が見たい場合,配列の参照のようにするとエラーが起こる.なぜならDataLoaderは配列ではなくiteratorというものを返しているためである.
無理やり中身を見ようとするならば以下のようにすれば良い.

filename.py
for data,label in trainloader:
    break
print(data)
print(label)

------'''以下出力結果'''--------
tensor([[data1], [data2],..., [data100]])
tensor([label1, label2,..., label100])

このように1回の取得でdataとlabelはバッチサイズだけ取得され,もちろん各dataとlabelは対応しあっている.
ただし,この確認は絶対に学習する前に同プログラム内ではやってはいけない.
なぜならtrainloaderはiteratorであるため今回呼び出したdataは全データを見きるまで二度と見られることがなくなってしまう.
つまりこの参照をしてしまった100個のdataは2周目に入るまで見られなくなってしまう.
そこに十分注意してほしい.

dataloaderは自作する意味はあまりなく,datasetさえ作ってしまえばいつものdataloaderの使い方でできるので,今回はそのまま使用する.
(なぜ自作のdatasetをdataloaderが同じように使えるのかは,dataloaderが見ることができるようにdatasetやtransformを作成していくからである)

5-2. Datasetの作成

冒頭でも述べたように,今回は既存のCIFAR10というdatasetからdataをgray scaleに,labelをcolor scaleにしたdatasetにする.

ただし,とても簡単な方法でやるため,あまりスマートではない.
よりスマートな方法はもちろんあるので,是非自身で考えてプログラムを組んでほしい.

5-2-1. transformの準備

まずtransformを準備する.

filename.py
trans1 = torchvision.transforms.ToTensor()
trans2 = torchvision.transforms.Compose([torchvision.transforms.Grayscale(), torchvision.transforms.ToTensor()])

trans1」はlabel用のtransform.
trans2」はdata用のtransformでgray scaleにするためのtransformである「torchvision.transforms.Grayscale()」を用意している.

5-2-2. mydatasetの準備

自作Datasetを準備する.

filename.py
class Mydatasets(torch.utils.data.Dataset):
    def __init__(self, path, transform1 = None, transform2 = None, train = True):
        self.transform1 = transform1
        self.transform2 = transform2
        self.train = train

        self.labelset = torchvision.datasets.CIFAR10(root = path, train = self.train, download = True)
        self.dataset = torchvision.datasets.CIFAR10(root = path, train = self.train, download = True)

        self.datanum = len(dataset)

    def __len__(self):
        return self.datanum

    def __getitem__(self, idx):
        out_label = self.labelset[idx][0]
        out_data = self.dataset[idx][0]

        if self.transform1:
            out_label = self.transform1(out_label)

        if self.transform2:
            out_data = self.transform2(out_data)

        return out_data, out_label

このようにdata用とlabel用で二度CIFAR10を取得し,それぞれのdata部分だけを取り出してtransformを適用している(もともとのCIFAR10のlabelは取得はするが不必要なので無視).

使用と確認は以下のようにすれば良い.

filename.py
dataset = Mydatasets('自分のpath', transform1=trans1, transform2=trans2, train=True)
print(dataset[0][0].shape)
print(dataset[0][1].shape)

------'''以下出力結果'''--------
torch.Size([1, 32, 32])
torch.Size([3, 32, 32])

pyTorchで扱うTensor型dataは「Channel×Height×Width」となっており,先頭の値がChannel数である.
gray scaleは1channelでcolor scaleは3channelだから,うまく実装できた.

一応,以下で画像も出力して確認できる.

filename.py
from PIL import Image
import matplotlib.pyplot as plt

ts = torchvision.transforms.ToPILImage()
im = ts(dataset[0][0])
plt.imshow(np.array(im),  cmap='gray')
plt.show()

im = ts(dataset[0][1])
plt.imshow(np.array(im))

出力は自身で確認してほしい.
この「torchvision.transforms.ToPILImage()」がTensor型をPIL imageに変換しており,この際「Channel×Height×Width」を「Height×Width×Channel」に自動変換してくれている.

5-2-3. dataloaderによる使用

最後のdataloaderによる使用は以下のようにするだけである.

filename.py
trainloader = torch.utils.data.DataLoader(dataset, batch_size = 100, shuffle = True, num_workers = 2)

6. おまけ,Tensor⇒ndarrayにするtransformの作成

上で画像出力の確認をした際,「torchvision.transforms.ToPILImage()」を使用し,更にその後np.array(im)とすることでTensor型⇒PIL image⇒ndarrayとしている.
なぜここまで回りくどいことをしているかは,以下の理由があるからである.

  • 画像を出力するmoduleはTensorの「Channel×Height×Width」は扱えず,「Height×Width×Channel」しか扱えない
  • Tensor⇒ndarrayにする際は「xxx..numpy()」を使用するが,この場合「Channel×Height×Width」の形のままになってしまう
  • ToPILImage()」は「Height×Width×Channel」にしてくれる
  • 画像出力をするmatplotlib moduleはndarrayを使用するため最終的にndarrayにしたい

この変換を一括でしてくれるようなtransformsを作成する.
以下ソースコード.

filename.py
class ToNDarray(object):
    def __init__(self):
        pass

    def __call__(self, x):
        x_shape = x.shape    #x=(C,H,W)
        x = x.detach().clone().cpu()   #x=(C,H,W)
        x = x.numpy()   #x=(C,H,W)
        if x_shape[0] == 1:       #C=1の時
            x = x[0]    #x=(H,W)にする
        else:
            x = x.transpose(1,2,0)  #x=(H,W,C)にする
        return x

これはTensor型のdataを「detach().clone().cpu()」により勾配情報の無視とGPUからCPUを使用するようにしている.
numpy()」によりTensor型からndarrayに変換する.
if文はchannelが1かそれ以上かを確認している.
channelが1つの場合はgray画像となるのだが,その場合,出力するときはchannel情報が不必要なため見ないようにしている.
また「transpose(1,2,0)」は「Channel×Height×Width」を「Height×Width×Channel」の順に変換している.

使い方は以下のようにすれば良い.

filename.py
trans = ToNDarray()
im = trans(dataset[0][0])
plt.imshow(im,  cmap='gray')

出力は自身で確認してほしい.

7. まとめソースコード

以下にまとめのソースコードを示す.

filename.py
import torch
import torchvision
import numpy as np

class Mydatasets(torch.utils.data.Dataset):
    def __init__(self, path, transform1 = None, transform2 = None, train = True):
        self.transform1 = transform1
        self.transform2 = transform2
        self.train = train

        self.labelset = torchvision.datasets.CIFAR10(root = path, train = self.train, download = True)
        self.dataset = torchvision.datasets.CIFAR10(root = path, train = self.train, download = True)

        self.datanum = len(dataset)

    def __len__(self):
        return self.datanum

    def __getitem__(self, idx):
        out_label = self.labelset[idx][0]
        out_data = self.dataset[idx][0]

        if self.transform1:
            out_label = self.transform1(out_label)

        if self.transform2:
            out_data = self.transform2(out_data)

        return out_data, out_label

trans1 = torchvision.transforms.ToTensor()
trans2 = torchvision.transforms.Compose([torchvision.transforms.Grayscale(), torchvision.transforms.ToTensor()])
dataset = Mydatasets('自身のpath', transform1=trans1, transform2=trans2, train=True)
trainloader = torch.utils.data.DataLoader(dataset, batch_size = 100, shuffle = True, num_workers = 2)

8. ひとこと

今回はpyTorchによる自作Datasetの作成とその説明をさせていただいた.
読みづらい点も多かったと思うが読んでいただきありがとうございます.

149
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
mathlive
自分が今までどのようにプログラミングを理解し,どのようにコードを書いたり読んだり調べたりしたかを,誰かの助けになればいいなと思いながら記事を書いています.

Comments

No comments
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
149
Help us understand the problem. What is going on with this article?