LoginSignup
26
22

More than 1 year has passed since last update.

PyTorchで層化分割(Stratified Split)を行う

Last updated at Posted at 2020-05-11

層化分割(Stratified Split)とは

機械学習をしていると、データセットを学習用データとバリデーション用データに分割することがよくあります。特に分類問題の場合、クラスラベルを考慮せずランダムに分割してもいいのですが、分割後のデータのクラスラベルの分布が元データと同じになるように分割するのが望ましいです。このように各クラスの比率を保ったまま分割することを、層化抽出とか層化分割(Stratified Split)と言います。

PyTorchでの実装例

scikit-learnではsklearn.model_selection.train_test_split()という関数にstratifyオプションを渡すことでStratified Splitを行うことができます。

一方、PyTorchにはそのような仕組みがありません。torch.utils.data.random_split()のような関数を使えばデータセットをランダムに分割することはできますが、ストレートにStratified Splitを行うことはできません。そこで、scikit-learnのtrain_test_split()と組み合わせることで、Stratified Splitを実現します。

例えば、次のようなコードでStratified Splitを行うことができます。

import torch
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split

transformer = transforms.Compose([
    transforms.ToTensor(),
])

# 画像を読み込む
dataset = torchvision.datasets.ImageFolder(root='directory_name', transform=transformer)

# データセットをtrainとvalidationに分割
train_indices, val_indices = train_test_split(list(range(len(dataset.targets))), test_size=0.2, stratify=dataset.targets)
train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

# DataLoaderを作成
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=True, num_workers=4)

順に説明します。

初めに、ImageFolderで画像を読み込んでDatasetを作成しています。

transformer = transforms.Compose([
    transforms.ToTensor(),
])

dataset = torchvision.datasets.ImageFolder(root='directory_name', transform=transformer)

次にtrain_test_split()でデータを分割しますが、Datasetを直接渡すことはできないので、list(range(len(dataset.targets)))でDatasetのインデックス配列([0,1,2,3,...データ数])を生成し、それを代わりに渡します。そして、このインデックス配列に対するクラスラベルdataset.targetsstratifyオプションとして渡すことで、元データのクラスラベルの比率を保ったまま、インデックス配列を学習用とバリデーション用に分割することができます。

train_indices, val_indices = train_test_split(list(range(len(dataset.targets))), test_size=0.2, stratify=dataset.targets)

分割したのはあくまでインデックス配列なので、そのインデックスを元にデータセットを分割します。Subsetはその名の通りデータのサブセットを作るためのクラスで、元となるDatasetとインデックス配列を渡すことで、インデックスに対応するDatasetを生成できます。

train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

あとはいつも通りDataLoaderにDatasetを渡してあげるだけです。

# DataLoaderを作成
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=True, num_workers=4)

参考サイト

26
22
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
26
22