39
29

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchのDataSetとDataLoaderを理解する(1)

Last updated at Posted at 2019-11-07

はじめに

気がつけばあまり理解せずに使っていたPyTorchのDataLoaderとDataSetです。
少し凝ったことがしたくなったら参考にしていただければ幸いです。

後編はこちら

PyTorchのExampleの確認

PyTorchを使っていれば、当然DataLoaderを見たことがあると思います。
誰もが機械学習で使うMNISTのPyTorchのExampleでもこんな記述があります。

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('~/dataset/MNIST',
                    train=True,
                    download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
    batch_size=256,
    shuffle=True)

あるいはQiitaなどで検索するとこんな書き方も見ると思います。

train_dataset = datasets.MNIST(
    '~/dataset/MNIST',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]))
    
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True)

そしてその後、For文で回すことでバッチでデータを取得し、学習する様子を見ると思います。

for epoch in epochs:
    for img, label in train_loader:
        # この中で学習の処理を記載する

口語的に書けば、DataLoaderというのは、あるルールに従って、DataSetで記載の通りにデータを運んできてくれる便利なヤツです。
例えば上記の例では、MNISTのデータがNormalizeされた状態で256個(ミニバッチ)ずつimgとlabelにいい感じに入ってきてくれます。
ではどうやってそれが実現されているのか、中身を見に行きましょう。

torch.utils.data.DataLoaderを見てみる

DataLoaderの実装を見てみます。
クラスになっているのがすぐにわかります。

class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.
    """
    # 省略

詳細は割愛しますが、イテレータとしての情報は少し深く見ていくと下記のような実装が見つかります。

def __next__(self):
    index = self._next_index()  # may raise StopIteration
    data = self.dataset_fetcher.fetch(index)  # may raise StopIteration
    if self.pin_memory:
        data = _utils.pin_memory.pin_memory(data)
    return data

この__next__が呼ばれると、dataが返される仕組みです。
そしてこのdataは、datasetにindexを渡して作られているようです。
この段階ではindexがどう作られるか、datasetがどう呼ばれているか、はそこまで神経質にならなくても大丈夫ですが、せっかくなのでもう一歩奥を見に行きましょう。

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
            if self.auto_collation:
                data = [self.dataset[idx] for idx in possibly_batched_index]
            else:
                data = self.dataset[possibly_batched_index]
            return self.collate_fn(data)

datasetにindexが渡されていますね。このようにクラスのインスタンスを呼んでいるということは、datasetでは__getitem__が呼ばれているはずです。(こちらが詳しいです。ではこれを踏まえてdatasetを見に行ってみましょう。

datasets.MNISTを見てみる

MNISTの定義を見に行くとすぐにClassであることがわかります。

class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
    """

    # 省略

では__getitem__を見に行きましょう。

def __getitem__(self, index):
    """
    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    img, target = self.data[index], int(self.targets[index])

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img.numpy(), mode='L')

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target

indexが渡され、それに対応するデータをreturnしている様子がわかりやすく書かれています。なるほどこうやってMNISTのデータを返しているのです。
少し処理を見てみると、PILのImage.fromarrayなんかも書いてあります。つまりこの__getitem__を工夫して書いてあげれば、自在なデータをリターンすることが可能だということです。

torch.utils.data.DataLoaderをもう1回見てみる

だけどまだわからないことがあります。どんなふうにindexが作られているのでしょうか。そのヒントはここにあります。

if sampler is None:  # give default samplers
    if self.dataset_kind == _DatasetKind.Iterable:
        # See NOTE [ Custom Samplers and IterableDataset ]
        sampler = _InfiniteConstantSampler()
    else:  # map-style
        if shuffle:
            sampler = RandomSampler(dataset)
        else:
            sampler = SequentialSampler(dataset)

self.sampler = sampler
@property
def _index_sampler(self):
    # The actual sampler used for generating indices for `_DatasetFetcher`
    # (see _utils/fetch.py) to read data at each time. This would be
    # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
    # We can't change `.sampler` and `.batch_sampler` attributes for BC
    # reasons.
    if self._auto_collation:
        return self.batch_sampler
    else:
        return self.sampler

indexはsamplerを通じて作られているようです。
samplerはデフォルトではshuffleという引数のTrue,Falseによって切り替わっています。例えばshuffle=Falseのときの実装を見てみましょう。

class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

ここでいうdata_sourceは、datasetのことです。ここまで来て大体全容が掴めたように見えます。つまり、datasetの長さだけ繰り返すということです。逆に言えばdatasetの中で__len__という特殊メソッドを用意しておく必要がありそうですね。

datasets.MNISTをもう一度見てみる

ではdatasets.MNISTの__len__を確認しておきましょう。

def __len__(self):
        return len(self.data)

dataの長さを返していますね。MNISTでのdataは60000x28x28のサイズなので、60000が返ることになります。だいぶすっきりしてきました。

次回に続く

記事が長くなって来たので、前編はここまで。
後編ではdatasetの自作を行います。

39
29
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
39
29

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?