LoginSignup
4
3

More than 3 years have passed since last update.

【Pytorch】Dataset/DataLoaderについてメモ

Last updated at Posted at 2020-08-01

概要

Pytorchでデータセット作成時に使用するDataset/DataLoaderあたりをメモ

参考:

データの前処理

データの前処理についてはtorchvision.transformsまたはalbumentationsあたりのライブラリがある。
どちらも基本的な動作は同じ。
前処理のクラスインスタンスをリストに詰め込んで、Compose()の引数にしてインスタンス作成。
Composeは、__call__(self, img)メソッドを持つので、作成したインスタンスの引数に画像を入れれば前処理される。

import albumentations as albu

def get_augmentation(phase):
   """ Get augmentation for each phase

   Args:
      phase (str): train or val

   Returns:
      albumentations.Compose: Composed transforms
   """
   transform_list = []
   if phase == 'train':
        transform_list.extend([albu.HorizonFlip(p=0.5),
                             albu.VerticalFlip(p=0.5)])
   transform_list.extend([albu.Normalize(mean=(0.485, 0.456, 0.406),
                                         std=(0.229, 0.224, 0.225),
                                         p=1),
                          albu.ToTensor()
                        ])
    return albu.Compose(transform_list)

Dataset

入力データとそれに対応するラベルを1組ずつ取ってくるモジュール。データの前処理をする場合は、transformsを用いて前処理をかけたデータを返すようにする。

<必要条件>

  • Datsetの継承
  • __getitem____len__の実装

以上を満たしていれば基本的にはOK!
Dataset継承クラスのインスタンスがDataLoderの第1引数となる。(DataLodaerについては後で)

例えばデータセットが以下のようなディレクトリ構成だとする。

datasets/ ____ train_images/
           |__ test_images/
           |__ train.csv

今回はデータセットに対して、.csvファイルがデータパスとラベルの情報を持っていることを想定してる。

import os.path as osp

import cv2
import pandas as pd
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class MyDataManager(Dataset):
    """My Dataset

    Args:
        root(str): root path of dataset directory
        df(DataFrame): DataFrame object from csv file
        phase(str): train or test
    """

    def __init__(self, root, df, phase):
       super(MyDataManager, self).__init__()
       self.root = root
       self.df = df
       self.phase = phase
       self.transfoms = get_augmentation()

    def __getitem__(self, idx):
       img_path = osp.join(self.root, self.df.iloc[idx].name)
       img = cv2.imread(img_path)
       img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
       img = self.transform(image=img)

       label = osp.join(eslf.root, self.df.iloc[idx].value)

       ret = {'image': img, 'label': label}

       return ret

    def __len__(self):
       return len(self.df)

今回は、戻り値をdict型にしているが、return image, labelとしても問題ない。
Segmentationを行う場合などは、ラベルをマスク画像として与える必要があるので、その場合はマスク画像もtransfomする。

DataLoader

Datsetで取ってきたデータをDataLoaderの引数とすればいい。
DataLoaderの引数構造は以下、

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

したがって、以下のような関数を作成する。


def dataloader(dir_path,phase,batch_size, num_workers, shuffle=False):
    df_path = osp.join(dir_path, 'train.csv')
    df = pd.read_csv(df_path)

    dataset = MyDataManager(dir_path, df, phase)
    dl = DataLoader(dataset,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory=True,
                    shuffle=shuffle)

    return dl

※まだまだ知識不足な点があるかもしれないのでコメントで教えていただけると幸いです。

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3