Help us understand the problem. What is going on with this article?

PytorchのDataloaderとSamplerの使い方

More than 1 year has passed since last update.

Dataloaderとは

datasetsからバッチごとに取り出すことを目的に使われます。
基本的にtorch.utils.data.DataLoaderを使います。
イメージとしてはdatasetsはデータすべてのリスト、Dataloaderはそのdatasetsの中身をミニバッチごとに固めた集合のような感じだと自分で勝手思ってます。
datsets = [データセット全て]
Dataloader = [[batch_1], [batch_2], ... [batch_n]]
ですので
len(datasets)="すべてのデータの数"
len(Dataloader)="イテレーションの数"
となります。
(もちろんイテレータなのですべてのデータがそのまま入っているわけではありません。あくまでイメージです。あとスライス[:x]で取り出すこともできません。)

Dataloaderは__iter____next__が定義されているので、
iter(Dataloader).next() としてあげれば最初から1バッチずつ取り出すことができます。
デバッグする時等に便利です。

datasetsやdataloaderの挙動はPyTorch transforms/Dataset/DataLoaderの基本動作を確認するがわかりやすいです。

samplerとは

samplerとはDataloaderの引数で、datasetsのバッチの固め方を決める事のできる設定のようなものです。
基本的にsamplerはデータのインデックスを1つづつ返すようクラスになっています。

通常の学習ではtestloader = torch.utils.data.DataLoader(testset, batch_size=n,shuffle=True)で事足りると思います。
しかし訓練画像がクラスごとに大きく偏りがあって同じ割合で提供したいときや、距離学習などでそれぞれのクラスから同じ数だけ取り出してネットワークに入れるときなど、少し特殊なミニバッチを作りたいときにsamplerが役に立ちます。

torch.utils.dataに4つほどsamplerが用意されていますが、あまり使いそうなものはないので今回は自作してみたいと思います。
ただし作るのはsamplerではなくbatch_samplerになります。batch_samplerもDataloaderの引数の一つで、1つずつではなく複数のデータのインデックスを返します。
今回の想定として、先程例に出したようなすべてのクラスから何個か選んで、そのそれぞれのクラスから同じ数だけ取り出すことを考えます。選ぶクラスの数をn_classes、1つのクラスから取り出す数をn_samplesとします。すべてのデータの数はn_classes*n_samplesになります。

今回のコードはこちらを参考にしました。

Sampler.py
import numpy as np
import torch

from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler


class BalancedBatchSampler(BatchSampler):
    """
    BatchSampler - from a MNIST-like dataset, samples n_classes and within these classes samples n_samples.
    Returns batches of size n_classes * n_samples
    """

    def __init__(self, dataset, n_classes, n_samples):
        loader = DataLoader(dataset)
        self.labels_list = []
        for _, label in loader:
            self.labels_list.append(label)
        self.labels = torch.LongTensor(self.labels_list)
        self.labels_set = list(set(self.labels.numpy()))
        self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.dataset = dataset
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        while self.count + self.batch_size < len(self.dataset):
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                indices.extend(self.label_to_indices[class_][
                               self.used_label_indices_count[class_]:self.used_label_indices_count[
                                                                         class_] + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.n_classes * self.n_samples

    def __len__(self):
        return len(self.dataset) // self.batch_size

samplerもbatch_samplerも__iter__を定義する必要があります。
__iter__ではn_classes*n_samples個のインデックスを返しています。
__len__はデータセットの数//バッチサイズなのでlen(batch_sampler)="イテレーションの数"です。

実際にどのように取り出されているか確かめてみましょう。今回はmnistで実験してみたいと思います。

import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
import numpy as np
import matplotlib.pyplot as plt

n_classes = 5
n_samples = 8

mnist_train =  torchvision.datasets.MNIST(root="mnist/mnist_train", train=True, download=True, transform=transforms.Compose([transforms.ToTensor(),]))

balanced_batch_sampler = BalancedBatchSampler(mnist_train, n_classes, n_samples)

dataloader = torch.utils.data.DataLoader(mnist_train, batch_sampler=balanced_batch_sampler)
my_testiter = iter(dataloader)
images, target = my_testiter.next()


def imshow(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

imshow(torchvision.utils.make_grid(images))

image.png

0から9の内5種類が選ばれ、8枚ずつ取り出されているのがわかると思います。

参考

PyTorch transforms/Dataset/DataLoaderの基本動作を確認する
【詳細(?)】pytorch入門 〜CIFAR10をCNNする〜
pytorh公式サイトのsampler

tomp
大学生。 強化学習と物体検出、セグメンテーション等に興味があります。 備忘録的にアウトプットしていきたい。
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした