0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

初めてのPytorch #データセットとデータローダ

0
Posted at

データセットとデータローダー

PyTorchにはすでにいくつかテストデータが存在する
torch.utils.data.Dataset
画像データ
テキストデータセット
音声データセット
今回はこれらを使用方法を紹介する。

データセットの読み込み

TorchVisionからFashionMNISTのデータセットの読み込む方法
→Zalandoの商品画像で構成されたデータセットで28*28ピクセルの60000個のトレーニングサンプルと10000個のテキストサンプルからできる。

root
→トレーニングとテストデータが保存される
train
→トレーニングデータセットかテストデータセットを指定
download=True
→データが利用できない場合はインターネットからデータをダウンロードしrootに保存

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="train",
    train=False,
    download=True,
    transform=ToTensor()
)

データセットの反復処理と可視化

labels_map
→0~9までの画像をわかりやすいように辞書を使い名前付け
plt.figure
→表示画面のサイズを指定
torch.randint
→トレーニングデータからランダムに値を出力
.item()
→Pythonが扱いやすい値に変換
img, label = training_data[sample_idx]
→指定した番号の値と画像・正解値を取得

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

カスタムデータセットの作成

カスタムデータセットは__init__,__len__,__getitem__の3つの関数を実装する必要がある。

import os
import pandas as pd
from torchvision.io import decode_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = decode_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

__init__

インスタンス化する際に一度だけ実行される
この関数では画像、注釈ファイル、変換ファイルを保存するディレクトリを初期化する処理。

__len__

データセットないのサンプル数を返す。

__getitem__

  1. 取り出し(Fetch)
    指定された番号(idx)を元に、「どの画像ファイルを開くか」を特定し、ディスク(SSD/HDD)から画像を読み込みます。同時に、CSVなどの管理表からその画像に対応する「正解ラベル」もセットで取得します。

  2. 加工(Pre-process)
    読み込んだ生データを、AIが計算できる形式(テンソル)に変換します。必要に応じて、画像のサイズ変更や色の調整などの「変換処理(Transforms)」を自動的に適用します。

  3. 返却(Return)
    「加工済みの画像」と「正解ラベル」を1つのセット(タプル)にして返します。

データローダを使用したトレーニング用データの準備

モデルをトレーニングする際はサンプルをミニバッチで渡したりモデルの過学習を減らすために小分けにしてデータをシャッフルする必要があるがdataloaderを使用してAPIが勝手に実施してくれる
コードは以下

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

データローダーを反復処理する

データセットをdataloaderにロードし必要に応じてデータセットを反復処理できる。

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?