概要
PyTorchにはあらかじめ有名なデータセットがいくつか用意されている(torchvision.datasets
を使ってMNIST/CIFARなどダウロードできる)。しかし、train/testでしか分離されていないので、ここからvalidationデータセットを作ってみる。
例題としてtorchvision.datasets.MNIST
を使う。
課題
torchvision.datasets.MNIST
を使うと簡単にPyTorchのDataset
を作ることができるが、train/test用のDataset
しか用意されていないためvalidation用のDataset
を自分で作る必要がある。
以下のコードはtrain/test用のDataset
を作っている。
from torchvision import datasets
trainval_dataset = datasets.MNIST('./', train=True, download=True)
test_dataset = datasets.MNIST('./', train=False, download=True)
print(len(trainval_dataset)) # 60000
print(len(test_dataset)) # 10000
print(type(trainval_dataset)) # torchvision.datasets.mnist.MNIST
print(type(test_dataset)) # torchvision.datasets.mnist.MNIST
このtrainval_dataset
をtrain/validationに分割したい。
しかし、trainval_dataset
は単純なリスト形式ではなく、PyTorchのDataset
になっているため、「Dataset
が持つデータを取り出して、それをDataset
クラスに再構成する。」みたいなやり方だと手間がかかる上にうまくいかないことがある。(うまくいかない例としては、Dataset
クラスにTransform
クラスを渡している場合。)
Dataset
クラスとTransform
クラスについては以下の記事にまとめました。
PyTorch transforms/Dataset/DataLoaderの基本動作を確認する
解決策1 torch.utils.data.Subset
torch.utils.data.Subset(dataset, indices)
を使うと簡単にDataset
を分割できる。
PyTorchの中のコードは以下のようにシンプルなクラスになっている。
class Subset(Dataset):
"""
Subset of a dataset at specified indices.
Arguments:
dataset (Dataset): The whole Dataset
indices (sequence): Indices in the whole set selected for subset
"""
def __init__(self, dataset, indices):
self.dataset = dataset
self.indices = indices
def __getitem__(self, idx):
return self.dataset[self.indices[idx]]
def __len__(self):
return len(self.indices)
つまり、Dataset
とインデックスのリストを受け取って、そのインデックスのリストの範囲内でしかアクセスしないDataset
を生成してくれる。
文章にするとややこしいけどコードの例を見るとわかりやすい。
以下のコードはMNISTの60000のDataset
をtrain:48000とvalidation:12000のDataset
に分割している。
from torchvision import datasets
from torch.utils.data.dataset import Subset
trainval_dataset = datasets.MNIST('./', train=True, download=True)
n_samples = len(trainval_dataset) # n_samples is 60000
train_size = n_samples * 0.8 # train_size is 48000
subset1_indices = list(range(0,train_size)) # [0,1,.....47999]
subset2_indices = list(range(train_size,n_samples)) # [48000,48001,.....59999]
train_dataset = Subset(trainval_dataset, subset1_indices)
val_dataset = Subset(trainval_dataset, subset2_indices)
print(len(train_dataset)) # 48000
print(len(val_dataset)) # 12000
解決策2 torch.utils.data.random_split
解決策1はランダム性がない分割の仕方だったが、torch.utils.data.random_split(dataset, lengths)
を使うとランダムに分割することができる。
from torchvision import datasets
trainval_dataset = datasets.MNIST('./', train=True, download=True)
n_samples = len(trainval_dataset) # n_samples is 60000
train_size = int(len(trainval_dataset) * 0.8) # train_size is 48000
val_size = n_samples - train_size # val_size is 48000
# shuffleしてから分割してくれる.
train_dataset, val_dataset = torch.utils.data.random_split(trainval_dataset, [train_size, val_size])
print(len(train_dataset)) # 48000
print(len(val_dataset)) # 12000
Chainerのchainer.datasets.split_dataset_random
について
ちなみにChainerのchainer.datasets.split_dataset_random
がtorch.utils.data.random_split
と同じようなことをしてくれる。
Chainerの実装を参考に同じようなものを作ると以下のようなコードになる。
from torch.utils.data.dataset import Subset
def split_dataset(data_set, split_at, order=None):
from torch.utils.data.dataset import Subset
n_examples = len(data_set)
if split_at < 0:
raise ValueError('split_at must be non-negative')
if split_at > n_examples:
raise ValueError('split_at exceeds the dataset size')
if order is not None:
subset1_indices = order[0:split_at]
subset2_indices = order[split_at:n_examples]
else:
subset1_indices = list(range(0,split_at))
subset2_indices = list(range(split_at,n_examples))
subset1 = Subset(data_set, subset1_indices)
subset2 = Subset(data_set, subset2_indices)
return subset1, subset2
def split_dataset_random(data_set, first_size, seed=0):
order = np.random.RandomState(seed).permutation(len(data_set))
return split_dataset(data_set, int(first_size), order)
これを使うとランダムに分割できる。
from torchvision import datasets
trainval_dataset = datasets.MNIST('./', train=True, download=True)
n_samples = len(trainval_dataset) # n_samples is 60000
train_size = n_samples * 0.8 # train_size is 48000
train_dataset, val_dataset = split_dataset_random(trainval_dataset, train_size, seed=0)
print(len(train_dataset)) # 48000
print(len(val_dataset)) # 12000