既存のDataset
をN-foldするときに、先頭から順に1,2,...N,1,2,...,N,1,2,...,N,1,2,...とデータセットを分割する。時系列データを分割するときに、特定の月や季節にデータが偏らないようにするときに使用する。
from torch.utils.data import Dataset
class LayeredFoldWrapper(Dataset):
def __init__(self, dataset, n_splits=5, fold=0, valid=False):
self.dataset = dataset
self.n_splits = n_splits
self.fold = fold
self.valid = valid
self.valid_index = list(self._valid_index(len(dataset), n_splits, fold))
self.train_index = list(set(range(len(dataset))) - set(self.valid_index))
def __len__(self):
return len(self._get_index_list(self.valid))
def __getitem__(self, i):
return self.dataset.__getitem__(self._get_index_list(self.valid)[i])
def _valid_index(self, N, n_splits, fold):
"""
N: 全データの数
n_splits: foldのスプリットの数
fold: 各foldを指定する値 0<=fold<=n_splits-1
"""
assert(0<=fold<=n_splits-1)
return range(n_splits - fold - 1, N+1, n_splits)
def _get_index_list(self, valid):
if valid:
return self.valid_index
else:
return self.train_index