0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

PyTorchに入門しよう② DatasetとDataLoader

Posted at

前回の記事

PyTorchに入門しよう① Tensorの使い方

はじめに

今回はDatasetとDataLoaderとは何か、どのようにして使用するか、
そして、FashionMNISTのDatasetをロードする方法とカスタムデータセットの作成方法を学びます。

DatasetとDataLoader

データサンプルを処理するコードというのは煩雑になりやすく、メンテナスも大変です。
これを解消するためには、データセットのコードをモデルの学習コードから切り離して管理するのが理想的です。
PyTorchでは以下の2つのデータセットプリミティブが提供されており、プリロードされたデータセットや独自のデータを使用することができます。

torch.utils.data.Dataset
サンプルデータとそれに対応するラベルを格納する
torch.utils.data.DataLoader
Datasetをラップして、イテレート可能にする

PyTorchのドメインライブラリで提供される多くのプリロードデータセットはtorch.utils.data.Datasetのサブクラスとして提供されています。
これらのデータセットはモデルのプロトタイプやベンチマークに使用することが推奨されています。
詳細な一覧は以下を参照してください
Image Datasets
Text Datasets
Audio Datasets

Datasetのロード

ここでは、TorchVisionからFashion-MNISTデータセットをロードする方法を記載します。
Fashion-MNISTは、60,000のトレーニングデータと10,000のテストデータを持つ、28 × 28のグレースケール画像です。
また、クラスは10クラスに分類されています。

各引数が何を表しているかは以下のとおりです。

  • root : トレーニングとテストデータが格納されているパスを指定する
  • train : トレーニングとテストのどちらのデータセットを使用するかを指定する
  • download : rootに対象のデータが存在しない場合にダウンロードするかどうか
  • transform : ラベルと特徴量をどのように変換するかを指定する
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

Datasetの可視化

Datasetにはtraining_data[idx]のようにして、手動でアクセスすることができます。
以下のコードでは、Datasetからいくつかのデータを取得してmatplotlibで可視化しています。
軽く説明すると、まずtorch.randint()で0からlen(training_data)までの範囲のランダムな整数を1つ持つTensorを生成し、item()でそれをint型に変換します。
その後、training_data[sample_idx]にて先ほど取得した乱数をインデックスとしてDatasetのデータを取得しています。

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

ローカルファイルからカスタムデータセットを作成する

カスタムデータセットを作成するには次の3つの関数を実装する必要があります。
各関数の詳細は次のセクションで解説します。

  1. __init__
  2. __len__
  3. __getitem__

__init__

__init__はDatasetオブジェクトをインスタンス化した際に一度だけ実行されます。
このメソッドで画像を保存しているディレクトリやアノテーションファイル、変換関数の初期化を行います。

__len__

データセットに含まれるサンプル数を返すメソッドです。

__getitem__

__getitem__はDatasetから与えられたidxに対応するサンプルを返す関数です。
transformが設定されている場合は、ここでイメージとラベルの変換処理を行います。
 

下のCustomImageDatasetでは、annotations_fileで指定したcsvファイルをラベルとして読み込み、
img_dirで指定されたディレクトリをイメージの保存先ディレクトリとして読み込みます。
ここで指定するcsvファイルは下のような構造にします。

image1.png,0
image2.png,3
...
custom_dataset.py
import os
import torch
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torchvision.io import read_image

# カスタムデータセットの作成
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file, header=None)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx,0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx,1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

dataset = CustomImageDataset(
    annotations_file = 'data/CustomImage/labels.csv',
    img_dir = 'data/CustomImage/images',
)

# pyplotによる画像出力
img, label = dataset[0]

figure = plt.figure()
figure.add_subplot(1,1,1)
plt.title(label)
plt.axis("off")
plt.imshow(img.permute(1,2,0))
plt.show()

補足:plt.imshow()にてimg.permute()を実行している理由
 matplotlibのimshow()メソッドでRGB画像を表示するには、画像を(x,y,3)の次元にしておく必要がありますが、
 read_image()で読み込んだTensorは(3,x,y)の次元になっています。
そのため、permute()メソッドを使用して、Tensorの次元を入れ替えています。permute()は引数で指定された次元の順に入れ替えを行うメソッドです。

# permuteの例
t = torch.tensor([[1,2,3],[4,5,6]])
t.shape  # (2,3)
t2 = t.permute(1,0)
t2       # tensor([[1,4], [2,5], [3,6]])
t2.shape # (3,2)

DataLoaderを使用してトレーニングデータを準備する

先ほどはDatasetを使用して一つずつサンプルを取得しましたが、実際にモデルの学習を行う際はミニバッチ単位でサンプルを取得して、過学習を防ぐためにエポック毎にデータをランダムに取得します。
DataLoaderはこの複雑な処理を簡単なAPIで抽象化してくれるイテレータとなっています。
DataLoaderの生成方法は下記のようになっています。

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader  = DataLoader(test_data, batch_size=64, shuffle=True)

DataLoaderを使用してイテレートする

DataLoaderはiter()を使用することでイテレータに変換できます。
下のコードではnext(iter(train_dataloader))とすることでミニバッチの始めの要素を取得しています。

train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")  # torch.Size([64, 1, 28, 28])
print(f"Labels batch shape: {train_labels.size()}")     # torch.Size([64])
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

あとがき

 Datasetの作成方法やDataLoaderによるミニバッチ用データの取得方法は理解できたでしょうか。
今回の記事はこれで終わりです。次はモデルの学習に使用するTransformについて記載していきますので、
ゆっくり待っていただけると幸いです。
 最後まで読んだ方はご指摘やいいねをぜひお願いいたします。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?