15
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchでDatasetの読み込みを実装してみた

Posted at

#概要

機械学習の勉強を始めてまだ半年の私がなんとかPyTorchでDatasetを作ったので覚え書きついでに投稿します。
GANの勉強でGitHubからコードを落として勉強していたのですが、MNISTやCIFARの読み込みしかしてなかったため、自分の持っているデータセットで実行したくなってDatasetを自作しました。
(記事投稿の練習もかねての記事なのであしからず。。。)

#環境

  • 実行環境:PyCharm
  • python:3.6(Anaconda)
  • torch:1.3.1
  • torchvision:0.4.2

Datasetの必須条件

  • PyTorchのDatasetの継承
    学習モデルに渡す際に、DataLoaderにこのDataset継承クラスのオブジェクトを渡すため

  • __getitem__と__len__のメソッド
    __getitem__はデータとラベルをタプルで返すメソッド
    __len__はそのまま意味で、データ数を返すメソッド

なので、基本的な構成としてはこうなります。


class MyDataset(torch.utils.data.Dataset):

    def __init__(self, imageSize, dir_path, transform=None):
        pass

    def __len__(self):
        pass

    def __getitem__(self, idx):
        pass

クラスへの引数としてはデータへのPath以外に、画像の入力サイズと前処理のためのtransformを渡しました。

#コンストラクタの定義
クラス生成時に自動で呼び出されるコンストラクタでは、以下の処理を行います。

  • transformの定義
  • すべてのデータパスの読み込み
  • クラス名の定義及び辞書化
    def __init__(self, imageSize, dir_path, transform=None):
        self.transform = transforms.Compose([
            transforms.Resize(imageSize), # 画像のリサイズ
            transforms.ToTensor(), # Tensor化
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 標準化
        ])

        # ここに入力データとラベルを入れる
        self.image_paths = [str(p) for p in Path(dir_path).glob("**/*.png")]

        self.data_num = len(self.image_paths) # ここが__len__の返り値になる
        self.classes = ['carpet', 'ceramic', 'cloth', 'dirt', 'drywall', 'glass', 'grass', 'gravel', 'leaf', 'metal']
        self.class_to_idx = {'carpet':0, 'ceramic':1, 'cloth':2, 'dirt':3, 'drywall':4, 'glass':5, 'grass':6,'gravel':7, 'leaf':8, 'metal':9}

マテリアル多分類のデータを持っていたので、それを使用しました。

#__getitem__の定義
__getitem__は学習時にデータとその正解ラベルを読み込むためのメソッドなので、コンストラクタで読み込んだ情報を使って実装していきます。


    def __getitem__(self, idx):
        p = self.image_paths[idx]
        image = Image.open(p)

        if self.transform:
            out_data = self.transform(image)

        out_label = p.split("\\")
        out_label = self.class_to_idx[out_label[3]]

        return out_data, out_label

画像データはコンストラクタで読み込んでもいいと思いますが、データ数が多いとメモリが心配だったのでその都度読み込むことにしました。
クラスラベルにおいてもわざわざ辞書化するというちょっとめんどくさい方法を使ってます。

DataLoaderへの引き渡し

実際にコード中で読み込むときは、以下のように使えば学習に使えます。
(DataLoaderの引数shuffleは、dataの参照の仕方をランダムにする)

    data_set = MyDataset(32, dir_path=root_data)
    dataloader = torch.utils.data.DataLoader(data_set, batch_size=100, shuffle=True)

まとめソースコード

import torch.utils.data
import torchvision.transforms as transforms
from pathlib import Path
from PIL import Image

class MyDataset(torch.utils.data.Dataset):

    def __init__(self, imageSize, dir_path, transform=None):
        self.transform = transforms.Compose([
            transforms.Resize(imageSize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        self.image_paths = [str(p) for p in Path(dir_path).glob("**/*.png")]

        self.data_num = len(self.image_paths)
        self.classes = ['carpet', 'ceramic', 'cloth', 'dirt', 'drywall', 'glass', 'grass', 'gravel', 'leaf', 'metal']
        self.class_to_idx = {'carpet':0, 'ceramic':1, 'cloth':2, 'dirt':3, 'drywall':4, 'glass':5, 'grass':6,'gravel':7, 'leaf':8, 'metal':9}


    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        p = self.image_paths[idx]
        image = Image.open(p)

        if self.transform:
            out_data = self.transform(image)

        out_label = p.split("\\")
        out_label = self.class_to_idx[out_label[3]]

        return out_data, out_label

if __name__ == "__main__":
    root_data = 'データへのPath'
    data_set = MyDataset(32, dir_path=root_data)
    dataloader = torch.utils.data.DataLoader(data_set, batch_size=100, shuffle=True)

参考サイト

以下のサイトを見ながら実装しました。
ありがとうございました。
pyTorchのtransforms,Datasets,Dataloaderの説明と自作Datasetの作成と使用
PyTorch: DatasetとDataLoader (画像処理タスク編)

15
8
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?