3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Pytorchのdatasetについて説明してみた

Posted at

#動機
研究室でPytorchを使っている人がいるのでその方向けの説明用に

#本題 - 「なぜdatasetが必要なのか」

  • 深層学習の入力データはバッチごとであるため
  • バカでかいデータを一気に扱うとメモリが死ぬので細かく扱えるdatasetが便利だから
  • とにかくシャッフルとかも勝手にやってくれるから

#簡易的なコードと説明

dataset.py
import torch
from sklearn.datasets import load_iris

class Dataset(torch.utils.data.Dataset):
    def __init__(self, transform=None):
        self.iris = load_iris() #irisdatasetの読み込み
        self.data = self.iris['data']
        self.label = self.iris['target']
        self.datanum = len(self.label) #データの総数
        self.transform = transform #データに対する特別な処理

    def __len__(self):
        return self.datanum

    def __getitem__(self, index):
        data = self.data[index]
        label = self.label[index]

        if self.transform:
            data = self.transform(data)

        return data, label

if __name__ == "__main__":
    batch_size = 20
    dataset = Dataset()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for data in dataloader:
        print("データ件数: " + str(len(data[0])))
        print("data: {}".format(data[0]))
        print("レーベル件数: " + str(len(data[1])))
        print("label: {}".format(data[1]) + "\n")

以上がコードの全容です。
わかりやすくするために本当に最小限だけにしてます。

__init__はクラスを定義した際の処理です。
今回はとても小さなデータなのでinit内で全部定義しちゃってますが、
もしとても大きなデータをストレージからイテレーションしたい場合はここでパスなどを指定し、getitemで順番にイテレーションするといい感じです。

__len__はデータの総数を返します。

__getitem__はindexで指定されたデータを返します。
Dataloaderでイテレーションする際は後述するバッチサイズ分だけデータをまとめて指定して返します。

#Dataloaderについて
第一引数にDatasetクラスのインスタンスであるdatasetを渡します。
第二引数や第三引数には、バッチサイズやshuffleをするか__(True/False)__を渡します。
特別な理由がない限りshuffleはTrueにしましょう。

for 以外にもDataloaderをイテレーションする方法はありますが、ここではforを使います。
上記のプログラムを実行したのが以下です。

データ件数: 20
data: tensor([[6.8000, 3.0000, 5.5000, 2.1000],
        [6.7000, 3.1000, 5.6000, 2.4000],
        [5.4000, 3.9000, 1.3000, 0.4000],
        [5.5000, 2.4000, 3.7000, 1.0000],
        [5.1000, 3.7000, 1.5000, 0.4000],
        [4.5000, 2.3000, 1.3000, 0.3000],
        [6.6000, 2.9000, 4.6000, 1.3000],
        [6.5000, 3.0000, 5.8000, 2.2000],
        [7.0000, 3.2000, 4.7000, 1.4000],
        [4.4000, 3.2000, 1.3000, 0.2000],
        [5.0000, 3.4000, 1.5000, 0.2000],
        [5.4000, 3.4000, 1.5000, 0.4000],
        [4.9000, 2.4000, 3.3000, 1.0000],
        [6.3000, 3.4000, 5.6000, 2.4000],
        [7.7000, 2.6000, 6.9000, 2.3000],
        [6.2000, 2.8000, 4.8000, 1.8000],
        [6.2000, 3.4000, 5.4000, 2.3000],
        [5.6000, 2.7000, 4.2000, 1.3000],
        [6.1000, 3.0000, 4.9000, 1.8000],
        [6.7000, 3.0000, 5.0000, 1.7000]], dtype=torch.float64)
レーベル件数: 20
label: tensor([2, 2, 0, 1, 0, 0, 1, 2, 1, 0, 0, 0, 1, 2, 2, 2, 2, 1, 2, 1],
       dtype=torch.int32)

~~~~~途中省略~~~~~~

データ件数: 10
data: tensor([[4.8000, 3.4000, 1.6000, 0.2000],
        [6.1000, 2.8000, 4.7000, 1.2000],
        [5.1000, 3.8000, 1.9000, 0.4000],
        [6.7000, 3.3000, 5.7000, 2.1000],
        [6.4000, 2.9000, 4.3000, 1.3000],
        [7.4000, 2.8000, 6.1000, 1.9000],
        [6.4000, 3.2000, 5.3000, 2.3000],
        [5.0000, 3.3000, 1.4000, 0.2000],
        [5.0000, 3.2000, 1.2000, 0.2000],
        [5.8000, 2.7000, 4.1000, 1.0000]], dtype=torch.float64)
レーベル件数: 10
label: tensor([0, 1, 0, 2, 1, 2, 2, 0, 0, 1], dtype=torch.int32)

データもレーベルもバッチサイズとして定義した20ずつで正しく出力してくれています。
150/20はあまり10なのですがDataloaderでは、エラーなく10件出力してくれています。
特別な処理なしでこのように合わせてくれるのも便利な点です。

#最後に
簡単な例ですので別途質問があればコメントください。
誤りがあった場合もお願いします。

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?