LoginSignup
40
23

More than 1 year has passed since last update.

画像のデータセットを最も楽に作る方法【Pytorch】

Posted at

Pytorchの学習の大まかな学習の流れは、

  1. Datasetを作る
  2. DatasetをDataloaderに取り込む
  3. 学習する

Datasetさえできていれば2はかなり簡単で1行で完結します。

dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=10000, shuffle=True)

一方、Datasetの作り方は記事を検索してもなかなか出てこず苦労しました。
そんなわけで、Pytorchで『最も簡単に』データセットを作成する方法を解説します。
ものすごく簡単でした。簡単すぎるからあまり記事が出てこなかったのかもしれません。

ImageFolderで一発で作れる

Datasetはtorchvision.datasets.ImageFolderで一発で作れます。
猫の画像と犬の画像を分類したい場合を考えます。以下のようにディレクトリを作ります。

ディレクトリ構成

root/
├ data/
│   ├ train/
│   │   ├ cat/
│   │   │   ├ cat01_train.jpg
│   │   │   ・・・
│   │   └ dog/
│   │       ├ dog01_train.jpg
│   │       ・・・
│   └ val/
│       ├ cat/
│       │   ├ cat01_val.jpg
│       │   ・・・
│       └ dog/
│           ├ dog01_val.jpg
│           ・・・
model.py

train用のフォルダとvalidation用のフォルダを作り、その中に、分類したい画像を格納します。
準備はこれだけ!
あとはmodel.pyでデータセットを作るコードを定義します。

model.py
import torch
import torchvision
from torchvision import transforms

# image_sizeやmean, stdはデータに合わせて設定してください。
image_size = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# trainデータとvalidationデータが入っているディレクトリのパスを指定
train_image_dir = './data/train'
val_image_dir = './data/val'

# trainデータ向けとvalidationデータ向けに、transformを用意します。
# 皆さんのやりたいことに合わせて適宜変更してください。
data_transform = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(
            image_size, scale=(0.5, 1.0)
        ),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=[-15, 15]),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
        transforms.RandomErasing(0.5),
    ]),
    'val': transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
}

# torchvision.datasets.ImageFolderでデータの入っているディレクトリのパスと
# transformを指定してあげるだけ。
train_dataset = torchvision.datasets.ImageFolder(root=train_image_dir, transform=data_transform['train'])
val_dataset = torchvision.datasets.ImageFolder(root=val_image_dir, transform=data_transform['val'])

# Datasetができたら、dataloaderに渡してあげればOK
batch_size = 32
train_dataLoader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
val_dataLoader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False
)

# 以降はモデルを設定して学習・・・(割愛)

データセットを作っているのはこの数行です。

train_dataset = torchvision.datasets.ImageFolder(root=train_image_dir, transform=data_transform['train'])
val_dataset = torchvision.datasets.ImageFolder(root=val_image_dir, transform=data_transform['val'])

たったこれだけで画像のデータセットを自作できます。

ImageFolderを使わないと、Datasetクラスを自分で定義してうんぬんかんぬん、、、とかなり手間のかかることをやらなければいけません。
ですが、この方法だと一発で作れるのでぜひ試してみてください。

40
23
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
40
23