Pytorchの学習の大まかな学習の流れは、
- Datasetを作る
- DatasetをDataloaderに取り込む
- 学習する
Datasetさえできていれば2はかなり簡単で1行で完結します。
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=10000, shuffle=True)
一方、Datasetの作り方は記事を検索してもなかなか出てこず苦労しました。
そんなわけで、Pytorchで『最も簡単に』データセットを作成する方法を解説します。
ものすごく簡単でした。簡単すぎるからあまり記事が出てこなかったのかもしれません。
ImageFolderで一発で作れる
Datasetはtorchvision.datasets.ImageFolder
で一発で作れます。
猫の画像と犬の画像を分類したい場合を考えます。以下のようにディレクトリを作ります。
ディレクトリ構成
root/
├ data/
│ ├ train/
│ │ ├ cat/
│ │ │ ├ cat01_train.jpg
│ │ │ ・・・
│ │ └ dog/
│ │ ├ dog01_train.jpg
│ │ ・・・
│ └ val/
│ ├ cat/
│ │ ├ cat01_val.jpg
│ │ ・・・
│ └ dog/
│ ├ dog01_val.jpg
│ ・・・
model.py
train用のフォルダとvalidation用のフォルダを作り、その中に、分類したい画像を格納します。
準備はこれだけ!
あとはmodel.pyでデータセットを作るコードを定義します。
model.py
import torch
import torchvision
from torchvision import transforms
# image_sizeやmean, stdはデータに合わせて設定してください。
image_size = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
# trainデータとvalidationデータが入っているディレクトリのパスを指定
train_image_dir = './data/train'
val_image_dir = './data/val'
# trainデータ向けとvalidationデータ向けに、transformを用意します。
# 皆さんのやりたいことに合わせて適宜変更してください。
data_transform = {
'train': transforms.Compose([
transforms.RandomResizedCrop(
image_size, scale=(0.5, 1.0)
),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(degrees=[-15, 15]),
transforms.ToTensor(),
transforms.Normalize(mean, std),
transforms.RandomErasing(0.5),
]),
'val': transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
}
# torchvision.datasets.ImageFolderでデータの入っているディレクトリのパスと
# transformを指定してあげるだけ。
train_dataset = torchvision.datasets.ImageFolder(root=train_image_dir, transform=data_transform['train'])
val_dataset = torchvision.datasets.ImageFolder(root=val_image_dir, transform=data_transform['val'])
# Datasetができたら、dataloaderに渡してあげればOK
batch_size = 32
train_dataLoader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True
)
val_dataLoader = torch.utils.data.DataLoader(
val_dataset, batch_size=batch_size, shuffle=False
)
# 以降はモデルを設定して学習・・・(割愛)
データセットを作っているのはこの数行です。
train_dataset = torchvision.datasets.ImageFolder(root=train_image_dir, transform=data_transform['train'])
val_dataset = torchvision.datasets.ImageFolder(root=val_image_dir, transform=data_transform['val'])
たったこれだけで画像のデータセットを自作できます。
ImageFolderを使わないと、Datasetクラスを自分で定義してうんぬんかんぬん、、、とかなり手間のかかることをやらなければいけません。
ですが、この方法だと一発で作れるのでぜひ試してみてください。