56
46

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

PyTorchでクロスバリデーション(交差検証)

Last updated at Posted at 2020-08-31

はじめに

PyTorch で Dataset を使用するときのクロスバリデーション(交差検証)のやり方を説明します。

Subsetを使用した分割

torch.utils.data.dataset.Subsetを使用するとインデックスを指定してDatasetを分割することが出来ます。これとscikit-learnのsklearn.model_selectionを組み合わせます。

train_test_split

sklearn.model_selection.train_test_splitを使用してインデックスをtrain_indexvalid_indexに分割し、Subsetを使用してDatasetを分割します。

class DummyDataset(Dataset):
    def __init__(self):
        pass

    def __len__(self):
        return 100

    def __getitem__(self, idx):
        image = f"dummy image {idx}"
        label = f"dummy label {idx}"
        return image, label


dataset = DummyDataset()

seed = 0
train_index, valid_index = train_test_split(
    range(len(dataset)),
    test_size=0.3,
    random_state=seed
)

train_dataset = Subset(dataset, train_index)
valid_dataset = Subset(dataset, valid_index)

batch_size = 4
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size, shuffle=False)

# ここに学習コード
for imgs, labels in train_dataloader:
    print(imgs)  # ('dummy image 55', 'dummy image 42', 'dummy image 23', 'dummy image 41')  
    print(labels)  # ('dummy label 55', 'dummy label 42', 'dummy label 23', 'dummy label 41')
    break  # dummy

KFoldクロスバリデーション

sklearn.model_selection.KFoldを使用してインデックスをtrain_indexvalid_indexに分割し、Subsetを使用してDatasetを分割します。

from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import Subset


class DummyDataset(Dataset):
    def __init__(self):
        pass

    def __len__(self):
        return 100

    def __getitem__(self, idx):
        image = f"dummy image {idx}"
        label = f"dummy label {idx}"
        return image, label


dataset = DummyDataset()

kf = KFold(n_splits=3)
for _fold, (train_index, valid_index) in enumerate(kf.split(range(len(dataset)))):
    train_dataset = Subset(dataset, train_index)
    valid_dataset = Subset(dataset, valid_index)

    batch_size = 4
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size, shuffle=False)

    num_epochs = 10
    for i in range(num_epochs):
        for imgs, labels in train_dataloader:
            print(imgs)
            print(labels)
            break  # dummy
        break  # dummy
    break  # dummy

クラス分類のDatasetであればdataset[:][1]とすればyの値を取得することができるはずなので、StratifiedKFoldもできるはずです。[dataset.__getitem__(i)[1] for i in range(len(dataset))]でも良いです。

DataFrameを使用した分割

ディープラーニングは時間がかかるため、ミスがないかチェックしつつ実験をするために1回のスクリプト実行でクロスバリデーションの1fold分だけ行うということもあると思います。その場合は予めアノテーションファイルをfoldに分割しておいて、それぞれのアノテーションファイルで学習をすればよいです。以下はアノテーションファイルの分割まで含めたコードです。

import os

import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import KFold


class DummyDataset(Dataset):
    def __init__(
        self,
        annotation_df,
        img_dir,
        transform=None,
        target_transform=None
    ):
        self.records = annotation_df.to_dict(orient="records")
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        record = self.records[idx]

        img_path = os.path.join(self.img_dir, record["filename"])
        # image = Image.open(img_path)
        image = img_path  # dummy

        label = record["label"]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label


annotation_df = pd.DataFrame({
    "filename": ["image1.jpg", "image2.jpg", "image3.jpg", "image4.jpg"],
    "label": [0, 1, 0, 1]
})

# ホールド・アウトの場合
# seed = 0
# train_index, val_index = train_test_split(
#     range(len(annotation_df)),
#     test_size=0.3,
#     random_state=seed
# )

# K-foldの場合
kf = KFold(n_splits=3)
split_n = 0  # 実験対象のfold
train_index, val_index = list(kf.split(range(len(annotation_df))))[split_n]

train_df = annotation_df.iloc[train_index]
val_df = annotation_df.iloc[val_index]
train_dataset = DummyDataset(train_df, "data")
valid_dataset = DummyDataset(val_df, "data")

train_batch_size = 4
val_batch_size = 4
train_dataloader = DataLoader(train_dataset, train_batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, val_batch_size, shuffle=False)

# ここに学習コード
for imgs, labels in train_dataloader:
    print(imgs)
    print(labels)
    break  # dummy

以上

56
46
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
56
46

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?