層化分割(Stratified Split)とは
機械学習をしていると、データセットを学習用データとバリデーション用データに分割することがよくあります。特に分類問題の場合、クラスラベルを考慮せずランダムに分割してもいいのですが、分割後のデータのクラスラベルの分布が元データと同じになるように分割するのが望ましいです。このように各クラスの比率を保ったまま分割することを、層化抽出とか層化分割(Stratified Split)と言います。
PyTorchでの実装例
scikit-learnではsklearn.model_selection.train_test_split()
という関数にstratify
オプションを渡すことでStratified Splitを行うことができます。
一方、PyTorchにはそのような仕組みがありません。torch.utils.data.random_split()
のような関数を使えばデータセットをランダムに分割することはできますが、ストレートにStratified Splitを行うことはできません。そこで、scikit-learnのtrain_test_split()
と組み合わせることで、Stratified Splitを実現します。
例えば、次のようなコードでStratified Splitを行うことができます。
import torch
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
transformer = transforms.Compose([
transforms.ToTensor(),
])
# 画像を読み込む
dataset = torchvision.datasets.ImageFolder(root='directory_name', transform=transformer)
# データセットをtrainとvalidationに分割
train_indices, val_indices = train_test_split(list(range(len(dataset.targets))), test_size=0.2, stratify=dataset.targets)
train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)
# DataLoaderを作成
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=True, num_workers=4)
順に説明します。
初めに、ImageFolderで画像を読み込んでDatasetを作成しています。
transformer = transforms.Compose([
transforms.ToTensor(),
])
dataset = torchvision.datasets.ImageFolder(root='directory_name', transform=transformer)
次にtrain_test_split()
でデータを分割しますが、Datasetを直接渡すことはできないので、list(range(len(dataset.targets)))
でDatasetのインデックス配列([0,1,2,3,...データ数]
)を生成し、それを代わりに渡します。そして、このインデックス配列に対するクラスラベルdataset.targets
をstratify
オプションとして渡すことで、元データのクラスラベルの比率を保ったまま、インデックス配列を学習用とバリデーション用に分割することができます。
train_indices, val_indices = train_test_split(list(range(len(dataset.targets))), test_size=0.2, stratify=dataset.targets)
分割したのはあくまでインデックス配列なので、そのインデックスを元にデータセットを分割します。Subsetはその名の通りデータのサブセットを作るためのクラスで、元となるDatasetとインデックス配列を渡すことで、インデックスに対応するDatasetを生成できます。
train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)
あとはいつも通りDataLoaderにDatasetを渡してあげるだけです。
# DataLoaderを作成
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=True, num_workers=4)
参考サイト