More than 3 years have passed since last update.

PyTorch StratifiedKFold + Creating a Dataset

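When you build datasets with PyTorch's Dataset, combining them with k-fold cross-validation works differently from scikit-learn and has a few quirks, so here is a memo. The idea is to record each row's fold ID in a fold column first, then select rows by that column when building the datasets.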

  1. Add a fold column to train_df.
```python
train_df["fold"] = 0  # placeholder; each row is overwritten with its real fold ID in step 2
```
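  2. Use the val_index returned by StratifiedKFold to assign a fold ID to each row.
```python
from sklearn.model_selection import StratifiedKFold

train_df = train_df.reset_index(drop=True)  # loc looks rows up by label, so make labels match kf's positional indices
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for i, (train_index, val_index) in enumerate(kf.split(train_df, train_df["class_num"])):  # kf.split(X, y), stratified on y
    train_df.loc[val_index, "fold"] = i  # every row appears in exactly one val_index
    print(f"FOLD: {i}, train: {len(train_index)}, val: {len(val_index)}")
train_df.head()
```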
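To see what the fold column from step 2 ends up holding, here is a minimal pure-Python sketch of stratified fold assignment. It is a simplified stand-in, not scikit-learn's actual StratifiedKFold algorithm: it shuffles the row indices of each class with a fixed seed and deals them round-robin into n_splits folds, so every row gets exactly one fold ID and each fold receives about the same share of every class. The helper name assign_folds and the toy labels are hypothetical.

```python
import random
from collections import Counter, defaultdict


def assign_folds(labels, n_splits=5, seed=0):
    """Return one fold ID (0..n_splits-1) per row.

    Hypothetical helper with simplified stratification (not sklearn's algorithm):
    rows of each class are shuffled, then dealt round-robin across the folds.
    """
    rng = random.Random(seed)
    rows_by_class = defaultdict(list)
    for row, label in enumerate(labels):
        rows_by_class[label].append(row)

    folds = [0] * len(labels)
    offset = 0  # continue dealing where the previous class stopped, to balance fold sizes
    for label in sorted(rows_by_class):
        rows = rows_by_class[label]
        rng.shuffle(rows)
        for k, row in enumerate(rows):
            folds[row] = (offset + k) % n_splits
        offset += len(rows)
    return folds


# Toy labels: 10 rows of class 0 and 5 rows of class 1 (hypothetical data)
labels = [0] * 10 + [1] * 5
folds = assign_folds(labels, n_splits=5)
for k in range(5):
    val_rows = [r for r, f in enumerate(folds) if f == k]
    # prints "FOLD: k, train: 12, val: 3, val classes: {0: 2, 1: 1}" for every k
    print(f"FOLD: {k}, train: {len(labels) - len(val_rows)}, val: {len(val_rows)}, "
          f"val classes: {dict(Counter(labels[r] for r in val_rows))}")
```

Each validation fold gets 2 rows of class 0 and 1 row of class 1, the same 2:1 ratio as the full data. StratifiedKFold is documented to preserve class percentages per fold in the same way, though it picks the rows differently.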
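  3. When creating the datasets, split the rows by referring to the fold column filled in step 2.
```python
# MyDataset, transforms_train and transforms_valid are your own Dataset class and transforms (not shown here)
for fold in range(5):
    # rows outside the current fold -> training set, rows inside it -> validation set;
    # reset_index(drop=True) re-indexes each subset as 0..len-1
    train_dataset = MyDataset(train_df[train_df["fold"] != fold].reset_index(drop=True), transform=transforms_train, mode="train")
    valid_dataset = MyDataset(train_df[train_df["fold"] == fold].reset_index(drop=True), transform=transforms_valid, mode="valid")
```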
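The memo does not show MyDataset itself. As a hedged illustration of what step 3 relies on, here is a minimal sketch of a map-style dataset: PyTorch's DataLoader can consume objects implementing __len__ and __getitem__, which is the interface torch.utils.data.Dataset subclasses provide. To keep the sketch runnable without PyTorch or pandas, the class name FoldDataset, the list-of-dicts rows, the "image" key, and the string-based transforms are all hypothetical; in real code you would subclass torch.utils.data.Dataset and pass the DataFrame subsets as in the loop above.

```python
class FoldDataset:
    """Hypothetical map-style dataset: implements __len__ and __getitem__."""

    def __init__(self, rows, transform=None, mode="train"):
        self.rows = rows            # only the rows belonging to this split
        self.transform = transform  # applied per sample, e.g. augmentation for "train"
        self.mode = mode            # kept to mirror MyDataset's call; unused in this sketch

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = self.rows[idx]
        x = row["image"]
        if self.transform is not None:
            x = self.transform(x)
        return x, row["class_num"]


# Toy table where each row already carries a fold ID, as after step 2 (hypothetical data)
rows = [{"image": f"img_{i}", "class_num": i % 2, "fold": i % 5} for i in range(10)]


def transforms_train(x):  # stand-in for a real augmentation pipeline
    return x + "_aug"


def transforms_valid(x):  # validation data is left unaugmented
    return x


for fold in range(5):
    train_dataset = FoldDataset([r for r in rows if r["fold"] != fold], transform=transforms_train, mode="train")
    valid_dataset = FoldDataset([r for r in rows if r["fold"] == fold], transform=transforms_valid, mode="valid")
    print(f"FOLD: {fold}, train: {len(train_dataset)}, valid: {len(valid_dataset)}, "
          f"first valid sample: {valid_dataset[0]}")
```

Because the split is decided by the fold column rather than by index arrays, each fold's training and validation sets are disjoint, and every row lands in exactly one validation set across the five folds. Passing a separate transform per split keeps augmentation out of validation.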