Layered folding of a PyTorch Dataset

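When splitting an existing Dataset into N folds, this wrapper assigns samples to folds in round-robin order from the start of the dataset: 1, 2, ..., N, 1, 2, ..., N, 1, 2, ..., N, 1, 2, ... Use it when splitting time-series data so that no fold is biased toward a particular month or season.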
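
To illustrate the idea, here is a minimal sketch in plain Python that compares round-robin fold labels with contiguous blocks. The sample count, the number of folds, and the monthly interpretation are illustrative assumptions, and the labels are 0-based rather than 1-based.

```python
# Illustrative assumptions: 24 monthly observations (two years), 5 folds.
n_samples = 24
n_splits = 5

# Layered: sample i gets label i % n_splits, so labels cycle 0,1,2,3,4,0,1,...
layered = [i % n_splits for i in range(n_samples)]

# Contiguous: consecutive blocks of samples share a label;
# the last fold absorbs the remainder.
block = n_samples // n_splits
contiguous = [min(i // block, n_splits - 1) for i in range(n_samples)]


def months_in_fold(labels, fold):
    """Return the sorted calendar months (0-11) that appear in the given fold."""
    return sorted({i % 12 for i, label in enumerate(labels) if label == fold})


print(months_in_fold(layered, 0))     # months spread across the year
print(months_in_fold(contiguous, 0))  # only the first few months
```

In this toy setting the layered fold draws from months spread over the year, while the contiguous fold contains only the first four months. The spreading effect depends on the number of folds not lining up with the period of the data: with monthly data and 12 folds, each fold would contain a single calendar month.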

from torch.utils.data import Dataset


class LayeredFoldWrapper(Dataset):
    """Expose either the training or the validation part of one layered fold."""

    def __init__(self, dataset, n_splits=5, fold=0, valid=False):
        self.dataset = dataset
        self.n_splits = n_splits
        self.fold = fold
        # valid=True exposes the validation subset, valid=False the training subset.
        self.valid = valid
        self.valid_index = list(self._valid_index(len(dataset), n_splits, fold))
        # Sort so that training samples keep their original order.
        self.train_index = sorted(set(range(len(dataset))) - set(self.valid_index))

    def __len__(self):
        return len(self._get_index_list(self.valid))

    def __getitem__(self, i):
        return self.dataset[self._get_index_list(self.valid)[i]]

    def _valid_index(self, N, n_splits, fold):
        """
        N: total number of samples
        n_splits: number of folds
        fold: which fold to select, 0 <= fold <= n_splits - 1
        """
        assert 0 <= fold <= n_splits - 1
        # Stop at N rather than N + 1: index N is out of range for the dataset.
        return range(n_splits - fold - 1, N, n_splits)

    def _get_index_list(self, valid):
        if valid:
            return self.valid_index
        else:
            return self.train_index
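
As a quick sanity check of the index arithmetic, the following torch-free sketch mirrors the offset convention used in `_valid_index` (fold 0 starts at `n_splits - 1`) with `N` as the exclusive upper bound, so that every produced index is valid. The function `layered_valid_indices` and the sizes used are hypothetical and only serve the illustration.

```python
def layered_valid_indices(n_samples, n_splits, fold):
    """Hypothetical helper: validation indices of one fold.

    Uses the same offset convention as _valid_index: fold 0 starts at
    index n_splits - 1, and the last fold starts at index 0.
    """
    if not 0 <= fold <= n_splits - 1:
        raise ValueError("fold must satisfy 0 <= fold <= n_splits - 1")
    return list(range(n_splits - fold - 1, n_samples, n_splits))


n_samples, n_splits = 12, 5  # illustrative sizes
folds = [layered_valid_indices(n_samples, n_splits, f) for f in range(n_splits)]
for f, idx in enumerate(folds):
    print(f, idx)

# Every sample index should appear in exactly one validation fold.
all_indices = sorted(i for idx in folds for i in idx)
assert all_indices == list(range(n_samples))
```

With the wrapper itself, the same fold would typically be built twice: once with `valid=False` for the training subset and once with `valid=True` for the validation subset, looping `fold` over `0` to `n_splits - 1`.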