#概要
機械学習の勉強を始めてまだ半年の私がなんとかPyTorchでDatasetを作ったので覚え書きついでに投稿します。
GANの勉強でGitHubからコードを落として勉強していたのですが、MNISTやCIFARの読み込みしかしてなかったため、自分の持っているデータセットで実行したくなってDatasetを自作しました。
(記事投稿の練習もかねての記事なのであしからず。。。)
#環境
- 実行環境:PyCharm
- python:3.6(Anaconda)
- torch:1.3.1
- torchvision:0.4.2
Datasetの必須条件
-
PyTorchのDatasetの継承
学習モデルに渡す際に、DataLoaderにこのDataset継承クラスのオブジェクトを渡すため -
__getitem__と__len__のメソッド
__getitem__はデータとラベルをタプルで返すメソッド
__len__はそのまま意味で、データ数を返すメソッド
なので、基本的な構成としてはこうなります。
class MyDataset(torch.utils.data.Dataset):
def __init__(self, imageSize, dir_path, transform=None):
pass
def __len__(self):
pass
def __getitem__(self, idx):
pass
クラスへの引数としてはデータへのPath以外に、画像の入力サイズと前処理のためのtransformを渡しました。
#コンストラクタの定義
クラス生成時に自動で呼び出されるコンストラクタでは、以下の処理を行います。
- transformの定義
- すべてのデータパスの読み込み
- クラス名の定義及び辞書化
def __init__(self, imageSize, dir_path, transform=None):
self.transform = transforms.Compose([
transforms.Resize(imageSize), # 画像のリサイズ
transforms.ToTensor(), # Tensor化
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 標準化
])
# ここに入力データとラベルを入れる
self.image_paths = [str(p) for p in Path(dir_path).glob("**/*.png")]
self.data_num = len(self.image_paths) # ここが__len__の返り値になる
self.classes = ['carpet', 'ceramic', 'cloth', 'dirt', 'drywall', 'glass', 'grass', 'gravel', 'leaf', 'metal']
self.class_to_idx = {'carpet':0, 'ceramic':1, 'cloth':2, 'dirt':3, 'drywall':4, 'glass':5, 'grass':6,'gravel':7, 'leaf':8, 'metal':9}
マテリアル多分類のデータを持っていたので、それを使用しました。
#__getitem__の定義
__getitem__は学習時にデータとその正解ラベルを読み込むためのメソッドなので、コンストラクタで読み込んだ情報を使って実装していきます。
def __getitem__(self, idx):
p = self.image_paths[idx]
image = Image.open(p)
if self.transform:
out_data = self.transform(image)
out_label = p.split("\\")
out_label = self.class_to_idx[out_label[3]]
return out_data, out_label
画像データはコンストラクタで読み込んでもいいと思いますが、データ数が多いとメモリが心配だったのでその都度読み込むことにしました。
クラスラベルにおいてもわざわざ辞書化するというちょっとめんどくさい方法を使ってます。
DataLoaderへの引き渡し
実際にコード中で読み込むときは、以下のように使えば学習に使えます。
(DataLoaderの引数shuffleは、dataの参照の仕方をランダムにする)
data_set = MyDataset(32, dir_path=root_data)
dataloader = torch.utils.data.DataLoader(data_set, batch_size=100, shuffle=True)
まとめソースコード
import torch.utils.data
import torchvision.transforms as transforms
from pathlib import Path
from PIL import Image
class MyDataset(torch.utils.data.Dataset):
def __init__(self, imageSize, dir_path, transform=None):
self.transform = transforms.Compose([
transforms.Resize(imageSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
self.image_paths = [str(p) for p in Path(dir_path).glob("**/*.png")]
self.data_num = len(self.image_paths)
self.classes = ['carpet', 'ceramic', 'cloth', 'dirt', 'drywall', 'glass', 'grass', 'gravel', 'leaf', 'metal']
self.class_to_idx = {'carpet':0, 'ceramic':1, 'cloth':2, 'dirt':3, 'drywall':4, 'glass':5, 'grass':6,'gravel':7, 'leaf':8, 'metal':9}
def __len__(self):
return self.data_num
def __getitem__(self, idx):
p = self.image_paths[idx]
image = Image.open(p)
if self.transform:
out_data = self.transform(image)
out_label = p.split("\\")
out_label = self.class_to_idx[out_label[3]]
return out_data, out_label
if __name__ == "__main__":
root_data = 'データへのPath'
data_set = MyDataset(32, dir_path=root_data)
dataloader = torch.utils.data.DataLoader(data_set, batch_size=100, shuffle=True)
参考サイト
以下のサイトを見ながら実装しました。
ありがとうございました。
pyTorchのtransforms,Datasets,Dataloaderの説明と自作Datasetの作成と使用
PyTorch: DatasetとDataLoader (画像処理タスク編)