はじめに
PyTorch で Dataset を使用するときのクロスバリデーション(交差検証)のやり方を説明します。
Subsetを使用した分割
torch.utils.data.dataset.Subset
を使用するとインデックスを指定してDatasetを分割することが出来ます。これとscikit-learnのsklearn.model_selection
を組み合わせます。
train_test_split
sklearn.model_selection.train_test_split
を使用してインデックスをtrain_index
とvalid_index
に分割し、Subset
を使用してDatasetを分割します。
class DummyDataset(Dataset):
def __init__(self):
pass
def __len__(self):
return 100
def __getitem__(self, idx):
image = f"dummy image {idx}"
label = f"dummy label {idx}"
return image, label
dataset = DummyDataset()
seed = 0
train_index, valid_index = train_test_split(
range(len(dataset)),
test_size=0.3,
random_state=seed
)
train_dataset = Subset(dataset, train_index)
valid_dataset = Subset(dataset, valid_index)
batch_size = 4
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size, shuffle=False)
# ここに学習コード
for imgs, labels in train_dataloader:
print(imgs) # ('dummy image 55', 'dummy image 42', 'dummy image 23', 'dummy image 41')
print(labels) # ('dummy label 55', 'dummy label 42', 'dummy label 23', 'dummy label 41')
break # dummy
KFoldクロスバリデーション
sklearn.model_selection.KFold
を使用してインデックスをtrain_index
とvalid_index
に分割し、Subset
を使用してDatasetを分割します。
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import Subset
class DummyDataset(Dataset):
def __init__(self):
pass
def __len__(self):
return 100
def __getitem__(self, idx):
image = f"dummy image {idx}"
label = f"dummy label {idx}"
return image, label
dataset = DummyDataset()
kf = KFold(n_splits=3)
for _fold, (train_index, valid_index) in enumerate(kf.split(range(len(dataset)))):
train_dataset = Subset(dataset, train_index)
valid_dataset = Subset(dataset, valid_index)
batch_size = 4
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size, shuffle=False)
num_epochs = 10
for i in range(num_epochs):
for imgs, labels in train_dataloader:
print(imgs)
print(labels)
break # dummy
break # dummy
break # dummy
クラス分類のDatasetであればdataset[:][1]
とすればy
の値を取得することができるはずなので、StratifiedKFold
もできるはずです。[dataset.__getitem__(i)[1] for i in range(len(dataset))]
でも良いです。
DataFrameを使用した分割
ディープラーニングは時間がかかるため、ミスがないかチェックしつつ実験をするために1回のスクリプト実行でクロスバリデーションの1fold分だけ行うということもあると思います。その場合は予めアノテーションファイルをfoldに分割しておいて、それぞれのアノテーションファイルで学習をすればよいです。以下はアノテーションファイルの分割まで含めたコードです。
import os
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import KFold
class DummyDataset(Dataset):
def __init__(
self,
annotation_df,
img_dir,
transform=None,
target_transform=None
):
self.records = annotation_df.to_dict(orient="records")
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.records)
def __getitem__(self, idx):
record = self.records[idx]
img_path = os.path.join(self.img_dir, record["filename"])
# image = Image.open(img_path)
image = img_path # dummy
label = record["label"]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
annotation_df = pd.DataFrame({
"filename": ["image1.jpg", "image2.jpg", "image3.jpg", "image4.jpg"],
"label": [0, 1, 0, 1]
})
# ホールド・アウトの場合
# seed = 0
# train_index, val_index = train_test_split(
# range(len(annotation_df)),
# test_size=0.3,
# random_state=seed
# )
# K-foldの場合
kf = KFold(n_splits=3)
split_n = 0 # 実験対象のfold
train_index, val_index = list(kf.split(range(len(annotation_df))))[split_n]
train_df = annotation_df.iloc[train_index]
val_df = annotation_df.iloc[val_index]
train_dataset = DummyDataset(train_df, "data")
valid_dataset = DummyDataset(val_df, "data")
train_batch_size = 4
val_batch_size = 4
train_dataloader = DataLoader(train_dataset, train_batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, val_batch_size, shuffle=False)
# ここに学習コード
for imgs, labels in train_dataloader:
print(imgs)
print(labels)
break # dummy
以上