概要
Pytorchでデータセット作成時に使用するDataset/DataLoaderあたりをメモ
参考:
データの前処理
データの前処理についてはtorchvision.transforms
またはalbumentations
あたりのライブラリがある。
どちらも基本的な動作は同じ。
前処理のクラスインスタンスをリストに詰め込んで、Compose()
の引数にしてインスタンス作成。
Compose
は、__call__(self, img)
メソッドを持つので、作成したインスタンスの引数に画像を入れれば前処理される。
import albumentations as albu
def get_augmentation(phase):
""" Get augmentation for each phase
Args:
phase (str): train or val
Returns:
albumentations.Compose: Composed transforms
"""
transform_list = []
if phase == 'train':
transform_list.extend([albu.HorizonFlip(p=0.5),
albu.VerticalFlip(p=0.5)])
transform_list.extend([albu.Normalize(mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
p=1),
albu.ToTensor()
])
return albu.Compose(transform_list)
Dataset
入力データとそれに対応するラベルを1組ずつ取ってくるモジュール。データの前処理をする場合は、transformsを用いて前処理をかけたデータを返すようにする。
<必要条件>
-
Datset
の継承 -
__getitem__
、__len__
の実装
以上を満たしていれば基本的にはOK!
Dataset継承クラスのインスタンスがDataLoderの第1引数となる。(DataLodaerについては後で)
例えばデータセットが以下のようなディレクトリ構成だとする。
datasets/ ____ train_images/
|__ test_images/
|__ train.csv
今回はデータセットに対して、.csvファイルがデータパスとラベルの情報を持っていることを想定してる。
import os.path as osp
import cv2
import pandas as pd
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class MyDataManager(Dataset):
"""My Dataset
Args:
root(str): root path of dataset directory
df(DataFrame): DataFrame object from csv file
phase(str): train or test
"""
def __init__(self, root, df, phase):
super(MyDataManager, self).__init__()
self.root = root
self.df = df
self.phase = phase
self.transfoms = get_augmentation()
def __getitem__(self, idx):
img_path = osp.join(self.root, self.df.iloc[idx].name)
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = self.transform(image=img)
label = osp.join(eslf.root, self.df.iloc[idx].value)
ret = {'image': img, 'label': label}
return ret
def __len__(self):
return len(self.df)
今回は、戻り値をdict型にしているが、return image, label
としても問題ない。
Segmentationを行う場合などは、ラベルをマスク画像として与える必要があるので、その場合はマスク画像もtransfomする。
DataLoader
Datsetで取ってきたデータをDataLoaderの引数とすればいい。
DataLoaderの引数構造は以下、
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
したがって、以下のような関数を作成する。
def dataloader(dir_path,phase,batch_size, num_workers, shuffle=False):
df_path = osp.join(dir_path, 'train.csv')
df = pd.read_csv(df_path)
dataset = MyDataManager(dir_path, df, phase)
dl = DataLoader(dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
shuffle=shuffle)
return dl
※まだまだ知識不足な点があるかもしれないのでコメントで教えていただけると幸いです。