97
73

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

PyTorchでValidation Datasetを作る方法

Last updated at Posted at 2019-03-31

概要

PyTorchにはあらかじめ有名なデータセットがいくつか用意されている(torchvision.datasetsを使ってMNIST/CIFARなどダウロードできる)。しかし、train/testでしか分離されていないので、ここからvalidationデータセットを作ってみる。
例題としてtorchvision.datasets.MNISTを使う。

課題

torchvision.datasets.MNISTを使うと簡単にPyTorchのDatasetを作ることができるが、train/test用のDatasetしか用意されていないためvalidation用のDatasetを自分で作る必要がある。

以下のコードはtrain/test用のDatasetを作っている。

from torchvision import datasets
trainval_dataset = datasets.MNIST('./', train=True, download=True)
test_dataset = datasets.MNIST('./', train=False, download=True)
print(len(trainval_dataset)) # 60000
print(len(test_dataset)) # 10000
print(type(trainval_dataset)) # torchvision.datasets.mnist.MNIST
print(type(test_dataset)) # torchvision.datasets.mnist.MNIST

このtrainval_datasetをtrain/validationに分割したい。
しかし、trainval_datasetは単純なリスト形式ではなく、PyTorchのDatasetになっているため、「Datasetが持つデータを取り出して、それをDatasetクラスに再構成する。」みたいなやり方だと手間がかかる上にうまくいかないことがある。(うまくいかない例としては、DatasetクラスにTransformクラスを渡している場合。)

DatasetクラスとTransformクラスについては以下の記事にまとめました。
PyTorch transforms/Dataset/DataLoaderの基本動作を確認する

解決策1 torch.utils.data.Subset

torch.utils.data.Subset(dataset, indices)を使うと簡単にDatasetを分割できる。
PyTorchの中のコードは以下のようにシンプルなクラスになっている。

class Subset(Dataset):
    """
    Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

つまり、Datasetとインデックスのリストを受け取って、そのインデックスのリストの範囲内でしかアクセスしないDatasetを生成してくれる。

文章にするとややこしいけどコードの例を見るとわかりやすい。
以下のコードはMNISTの60000のDatasetをtrain:48000とvalidation:12000のDatasetに分割している。

from torchvision import datasets
from torch.utils.data.dataset import Subset
trainval_dataset = datasets.MNIST('./', train=True, download=True)

n_samples = len(trainval_dataset) # n_samples is 60000
train_size = n_samples * 0.8 # train_size is 48000

subset1_indices = list(range(0,train_size)) # [0,1,.....47999]
subset2_indices = list(range(train_size,n_samples)) # [48000,48001,.....59999]
        
train_dataset = Subset(trainval_dataset, subset1_indices)
val_dataset   = Subset(trainval_dataset, subset2_indices)

print(len(train_dataset)) # 48000
print(len(val_dataset)) # 12000

解決策2 torch.utils.data.random_split

解決策1はランダム性がない分割の仕方だったが、torch.utils.data.random_split(dataset, lengths)を使うとランダムに分割することができる。

from torchvision import datasets

trainval_dataset = datasets.MNIST('./', train=True, download=True)

n_samples = len(trainval_dataset) # n_samples is 60000
train_size = int(len(trainval_dataset) * 0.8) # train_size is 48000
val_size = n_samples - train_size # val_size is 48000

# shuffleしてから分割してくれる.
train_dataset, val_dataset = torch.utils.data.random_split(trainval_dataset, [train_size, val_size])

print(len(train_dataset)) # 48000
print(len(val_dataset)) # 12000

Chainerのchainer.datasets.split_dataset_randomについて

ちなみにChainerのchainer.datasets.split_dataset_randomtorch.utils.data.random_splitと同じようなことをしてくれる。

Chainerの実装を参考に同じようなものを作ると以下のようなコードになる。

from torch.utils.data.dataset import Subset

def split_dataset(data_set, split_at, order=None):
    from torch.utils.data.dataset import Subset
    n_examples = len(data_set)
    
    if split_at < 0:
        raise ValueError('split_at must be non-negative')
    if split_at > n_examples:
        raise ValueError('split_at exceeds the dataset size')
        
    if order is not None:
        subset1_indices = order[0:split_at]
        subset2_indices = order[split_at:n_examples]
    else:
        subset1_indices = list(range(0,split_at))
        subset2_indices = list(range(split_at,n_examples))
        
    subset1 = Subset(data_set, subset1_indices)
    subset2 = Subset(data_set, subset2_indices)
    
    return subset1, subset2

def split_dataset_random(data_set, first_size, seed=0):
    order = np.random.RandomState(seed).permutation(len(data_set))
    return split_dataset(data_set, int(first_size), order)

これを使うとランダムに分割できる。

from torchvision import datasets

trainval_dataset = datasets.MNIST('./', train=True, download=True)
n_samples = len(trainval_dataset) # n_samples is 60000
train_size = n_samples * 0.8 # train_size is 48000

train_dataset, val_dataset = split_dataset_random(trainval_dataset, train_size, seed=0)

print(len(train_dataset)) # 48000
print(len(val_dataset)) # 12000

参考

97
73
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
97
73

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?