
Cross-Validation: CombinatorialPurgedCV in Python [Memo]

Posted at 2021-02-18

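Background

I tried a cross-validation scheme commonly used in finance in the Kaggle competition [Jane Street Market Prediction].

The data was very noisy and even the LB (leaderboard) could not be trusted, so cross-validation was extremely important.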

Overview

  • Red: train
  • Blue: validation

The figure shows the case where 1/4 of the data is used for validation, with a 10-day gap between train and validation.

cpkf-image.png
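
As a rough sanity check of the setup in the figure, the sketch below counts the splits and the share of data held out for validation. It assumes the same layout as the implementation in this memo (the dates are cut into N_COMB * n_block blocks and N_COMB of them are held out per split). The 500-day date axis is a made-up value for illustration only.

```python
import itertools
import numpy as np

N_COMB = 2      # blocks held out for validation in each split
n_block = 4     # gives N_COMB * n_block = 8 date blocks
n_gap_day = 10  # gap between train and validation, in days

# Hypothetical date axis: 500 consecutive day indices (illustration only).
dates = np.arange(500)
blocks = np.array_split(dates, N_COMB * n_block)

# Every pair of blocks is used once as the validation set.
pairs = list(itertools.combinations(range(len(blocks)), N_COMB))

n_splits = len(pairs)                   # number of train/validation splits
valid_fraction = N_COMB / len(blocks)   # share of blocks held out, before the gap is trimmed
min_block_len = min(len(b) for b in blocks)

print(n_splits, valid_fraction, min_block_len)
```

With n_block=4 this gives 28 splits, each holding out 2 of 8 blocks (1/4 of the dates) before the gap is removed.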

Implementation

For simplicity, N_COMB (the number of blocks held out for validation in each split) is fixed at 2.

import numpy as np
import pandas as pd
import itertools
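class CombinatorialPurgedKFold:
    # Number of blocks held out for validation in each split (fixed for simplicity).
    N_COMB = 2

    def __init__(self, data: pd.DataFrame, date_col_nm: str, n_block: int, n_gap_day: int = 5):
        self.n_block = n_block
        self.n_gap_day = n_gap_day

        # Cut the sorted unique dates into N_COMB * n_block contiguous blocks and
        # enumerate every way of choosing N_COMB of them as the validation set.
        self.uni_date = np.unique(data[date_col_nm].values)
        self.date_blocks = np.array_split(self.uni_date, self.N_COMB * n_block)
        self.valid_comb = list(itertools.combinations(range(self.N_COMB * n_block), self.N_COMB))

        # Each block must be long enough to lose n_gap_day dates from both ends.
        min_block_len = min(len(db) for db in self.date_blocks)
        if min_block_len < self.n_gap_day * 2:
            raise ValueError(
                f'Each date block needs at least {self.n_gap_day * 2} dates (2 * n_gap_day), '
                f'but the shortest block has {min_block_len}.'
            )

        self.__i = -1
        self.__data = data
        self.__date_col_nm = date_col_nm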

    def __iter__(self):
        self.__i = -1
        return self

    def __next__(self):
        self.__i += 1
        if self.__i < 0 or len(self.valid_comb) <= self.__i:
            raise StopIteration()
        return self.splits()

    def splits(self):
        # Return (train_df, valid_df) for the current pair of validation blocks.
        val_idx1, val_idx2 = self.valid_comb[self.__i]
        block1, block2 = self.date_blocks[val_idx1], self.date_blocks[val_idx2]
        gap = self.n_gap_day
        last_idx = len(self.date_blocks) - 1
        adjacent = (val_idx1 + 1) == val_idx2

        # Trim n_gap_day dates from every validation edge that touches training data.
        # End positions are explicit because a slice end of -0 (n_gap_day=0) would give an empty block.
        s1 = 0 if val_idx1 == 0 else gap
        e1 = len(block1) if adjacent else len(block1) - gap
        s2 = 0 if adjacent else gap
        e2 = len(block2) if val_idx2 == last_idx else len(block2) - gap

        val_dates = np.concatenate([block1[s1:e1], block2[s2:e2]])
        # Training uses every date outside the two (untrimmed) validation blocks.
        tra_dates = np.setdiff1d(self.uni_date, np.concatenate([block1, block2]))

        # isin avoids embedding a long list literal in a query string.
        date_values = self.__data[self.__date_col_nm]
        tra_df = self.__data[date_values.isin(tra_dates)]
        val_df = self.__data[date_values.isin(val_dates)]
        return tra_df, val_df

Usage example

cpkf = CombinatorialPurgedKFold(data=train, date_col_nm='date', n_block=4, n_gap_day=10)

for tr_df, va_df in cpkf:
    print(tr_df, va_df)
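
To check that the gap behaves as intended without the competition data, the following is a minimal, self-contained sketch on synthetic data. The helper `purged_pair_dates` is a hypothetical stand-in written for this illustration (it is not part of the class above, and it only handles pairs of validation blocks); the 400-day DataFrame and its `feature` column are made up.

```python
import itertools
import numpy as np
import pandas as pd


def purged_pair_dates(uni_date, n_block, n_gap_day):
    """Hypothetical helper: yield (train_dates, valid_dates) for every pair of blocks."""
    blocks = np.array_split(np.asarray(uni_date), 2 * n_block)
    last = len(blocks) - 1
    for i, j in itertools.combinations(range(len(blocks)), 2):
        adjacent = (i + 1) == j
        b1, b2 = blocks[i], blocks[j]
        # Trim the gap from each validation edge that borders training data.
        s1 = 0 if i == 0 else n_gap_day
        e1 = len(b1) if adjacent else len(b1) - n_gap_day
        s2 = 0 if adjacent else n_gap_day
        e2 = len(b2) if j == last else len(b2) - n_gap_day
        valid = np.concatenate([b1[s1:e1], b2[s2:e2]])
        train = np.setdiff1d(uni_date, np.concatenate([b1, b2]))
        yield train, valid


# Synthetic data (assumption): 400 integer dates with 3 rows per date.
rng = np.random.default_rng(0)
train = pd.DataFrame({
    "date": np.repeat(np.arange(400), 3),
    "feature": rng.normal(size=1200),
})

n_gap_day = 10
min_gaps, sizes = [], []
for tra_dates, val_dates in purged_pair_dates(np.unique(train["date"]), n_block=4, n_gap_day=n_gap_day):
    tr_df = train[train["date"].isin(tra_dates)]
    va_df = train[train["date"].isin(val_dates)]
    sizes.append((len(tr_df), len(va_df)))
    # Smallest distance in days between any training date and any validation date.
    min_gaps.append(np.abs(tra_dates[:, None] - val_dates[None, :]).min())

print(len(min_gaps), min(min_gaps))
```

If the purge works, the smallest train-to-validation distance in every split should exceed n_gap_day, because the trimmed dates sit between the two sets.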

