Help us understand the problem. What is going on with this article?

PyTorchでValidation Datasetを作る方法

概要

PyTorchにはあらかじめ有名なデータセットがいくつか用意されている(torchvision.datasetsを使ってMNIST/CIFARなどダウロードできる)。しかし、train/testでしか分離されていないので、ここからvalidationデータセットを作ってみる。
例題としてtorchvision.datasets.MNISTを使う。

課題

torchvision.datasets.MNISTを使うと簡単にPyTorchのDatasetを作ることができるが、train/test用のDatasetしか用意されていないためvalidation用のDatasetを自分で作る必要がある。

以下のコードはtrain/test用のDatasetを作っている。

from torchvision import datasets
trainval_dataset = datasets.MNIST('./', train=True, download=True)
test_dataset = datasets.MNIST('./', train=False, download=True)
print(len(trainval_dataset)) # 60000
print(len(test_dataset)) # 10000
print(type(trainval_dataset)) # torchvision.datasets.mnist.MNIST
print(type(test_dataset)) # torchvision.datasets.mnist.MNIST

このtrainval_datasetをtrain/validationに分割したい。
しかし、trainval_datasetは単純なリスト形式ではなく、PyTorchのDatasetになっているため、「Datasetが持つデータを取り出して、それをDatasetクラスに再構成する。」みたいなやり方だと手間がかかる上にうまくいかないことがある。(うまくいかない例としては、DatasetクラスにTransformクラスを渡している場合。)

DatasetクラスとTransformクラスについては以下の記事にまとめました。
PyTorch transforms/Dataset/DataLoaderの基本動作を確認する

解決策1 torch.utils.data.Subset

torch.utils.data.Subset(dataset, indices)を使うと簡単にDatasetを分割できる。
PyTorchの中のコードは以下のようにシンプルなクラスになっている。

class Subset(Dataset):
    """
    Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

つまり、Datasetとインデックスのリストを受け取って、そのインデックスのリストの範囲内でしかアクセスしないDatasetを生成してくれる。

文章にするとややこしいけどコードの例を見るとわかりやすい。
以下のコードはMNISTの60000のDatasetをtrain:48000とvalidation:12000のDatasetに分割している。

from torchvision import datasets
from torch.utils.data.dataset import Subset
trainval_dataset = datasets.MNIST('./', train=True, download=True)

n_samples = len(trainval_dataset) # n_samples is 60000
train_size = n_samples * 0.8 # train_size is 48000

subset1_indices = list(range(0,train_size)) # [0,1,.....47999]
subset2_indices = list(range(train_size,n_samples)) # [48000,48001,.....59999]

train_dataset = Subset(trainval_dataset, subset1_indices)
val_dataset   = Subset(trainval_dataset, subset2_indices)

print(len(train_dataset)) # 48000
print(len(val_dataset)) # 12000

解決策2 torch.utils.data.random_split

解決策1はランダム性がない分割の仕方だったが、torch.utils.data.random_split(dataset, lengths)を使うとランダムに分割することができる。

from torchvision import datasets

trainval_dataset = datasets.MNIST('./', train=True, download=True)

n_samples = len(trainval_dataset) # n_samples is 60000
train_size = int(len(trainval_dataset) * 0.8) # train_size is 48000
val_size = n_samples - train_size # val_size is 48000

# shuffleしてから分割してくれる.
train_dataset, val_dataset = torch.utils.data.random_split(trainval_dataset, [train_size, val_size])

print(len(train_dataset)) # 48000
print(len(val_dataset)) # 12000

Chainerのchainer.datasets.split_dataset_randomについて

ちなみにChainerのchainer.datasets.split_dataset_randomtorch.utils.data.random_splitと同じようなことをしてくれる。

Chainerの実装を参考に同じようなものを作ると以下のようなコードになる。

from torch.utils.data.dataset import Subset

def split_dataset(data_set, split_at, order=None):
    from torch.utils.data.dataset import Subset
    n_examples = len(data_set)

    if split_at < 0:
        raise ValueError('split_at must be non-negative')
    if split_at > n_examples:
        raise ValueError('split_at exceeds the dataset size')

    if order is not None:
        subset1_indices = order[0:split_at]
        subset2_indices = order[split_at:n_examples]
    else:
        subset1_indices = list(range(0,split_at))
        subset2_indices = list(range(split_at,n_examples))

    subset1 = Subset(data_set, subset1_indices)
    subset2 = Subset(data_set, subset2_indices)

    return subset1, subset2

def split_dataset_random(data_set, first_size, seed=0):
    order = np.random.RandomState(seed).permutation(len(data_set))
    return split_dataset(data_set, int(first_size), order)

これを使うとランダムに分割できる。

from torchvision import datasets

trainval_dataset = datasets.MNIST('./', train=True, download=True)
n_samples = len(trainval_dataset) # n_samples is 60000
train_size = n_samples * 0.8 # train_size is 48000

train_dataset, val_dataset = split_dataset_random(trainval_dataset, train_size, seed=0)

print(len(train_dataset)) # 48000
print(len(val_dataset)) # 12000

参考

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away