6
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorch Advent Calendar 2020Advent Calendar 2020

Day 25

【PyTorch】Dataset & DataLorderで画像ペアを扱う

Last updated at Posted at 2020-12-25

#TL;DR

#画像がこのようなディレクトリに存在するとき
.
└─ my_img_path
   |
   ├── hogehoge
   |     |-img_A.png
   |     |-img_B.png
   |
   ├── fugafuga
   |     |-img_A.png
   |     |-img_B.png
   ...
   |
   └── piyopiyo
         |-img_A.png
         |-img_B.png

from torch.utils.data import DataLoader, Dataset
from PIL import Image
import glob
import os

#データセットの作成
class PairImgs(Dataset):
    # torch.utils.data の Dataset継承
    """
    self.img_paths   画像ペアの入ってるフォルダの一つ上のディレクトリへのパス
    self.imgs_list   画像ペアの入ってるフォルダすべてのList
    self.transform   指定したtransform
    """

    def __init__(self, img_dir, transform):
        self.img_paths = img_dir # 画像ペアを入れたディレクトリの親へのパスを指定する
        self.imgs_list = glob.glob(os.path.join(self.img_paths, "*"))
        self.transform = transform
    
    def __getitem__(self, index):
        # indexで指定したディレクトリ以下のファイルを返す
        
        # 画像をPILとして読み込む
        img_A = Image.open(os.path.join(self.imgs_list[index], "img_A.png"))  #画像の名前は任意に変更すること
        img_B = Image.open(os.path.join(self.imgs_list[index], "img_B.png"))

        if self.transform is not None:
            # 前処理がある場合は行う. 普通はtransforms.ToTensor()でTensorに変換してしまう
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)

        return img_A , img_B
    
    def __len__(self):  # 画像ペアの入ってるフォルダの数 = データセット数を返す
        return len(glob.glob(self.img_paths + "*"))
# ** 利用時 **
# データセットを作成する
data_set = PairImgs("./my_img_path", transform=transforms.ToTensor())

# (必要なら)学習用とテスト用にデータセットを分割する
train_size = int( len(data_set) * 0.8 ) # 教師データのサイズ 全体の80%とする
test_size = n_samples - train_size  # テスト用データのサイズ

train_data, test_data = torch.utils.data.random_split(
        data_set,
        [train_size, test_size ],
        generator=torch.Generator().manual_seed(0)  # 乱数シードの固定
    )

# 作成したデータセットをデータローダ―に読み込ませる
batch_size = 32    # データローダ―のバッチサイズ 任意に変更のこと
# 学習データはシャッフルON テストデータはシャッフルなし
train_loader = DataLoader(train_data, batch_size=batch_size, 
                          shuffle=True, num_workers=2)
test_loader = DataLoader(test_data, batch_size=batch_size,
                         shuffle=False, num_workers=2)

#はじめに
PyTorch Advent Calendar 2020 Advent Calendar 2020の枠が空いてたので急遽魔がさして書きました。

解説

Pytorchで画像ペアを扱おうと考えた。この場合の画像ペアとは2枚の画像に関係がある(たとえば入力データと教師データ)ものであり、その2枚が一緒に取得できる必要があるものである。具体例としてはSRCNNなどのImage-to-Image 変換の際に利用する。
基本的には @mathlive さんのpyTorchのtransforms,Datasets,Dataloaderの説明と自作Datasetの作成と使用 を読んで作った。

Datasetの作成

class PairImgs(Dataset):
    # torch.utils.data の Dataset継承
    """
    self.img_paths   画像ペアの入ってるフォルダの一つ上のディレクトリへのパス
    self.imgs_list   画像ペアの入ってるフォルダすべてのList
    self.transform   指定したtransform
    """

    def __init__(self, img_dir, transform):
        self.img_paths = img_dir # 画像ペアを入れたディレクトリの親へのパスを指定する
        self.imgs_list = glob.glob(os.path.join(self.img_paths, "*"))
        self.transform = transform

---  ---

コンストラクタに引数として**「画像ペアが入っているフォルダの親ディレクトリ」**を渡す。今回は"./my_img_path"となる。
コンストラクタ内の処理としてglob.glob("./my_img_path/*")を実行して今回の対象となるフォルダをリストとして保存する。

img_pathsはこれ以降利用しないため保存しなくても良いはずだが、__len__len(imgs_list)で取得しようとしたらエラーになったため残している。何か書き方を間違えているかもしれない。

---  ---

def __getitem__(self, idx):
        # indexで指定したディレクトリ以下のファイルを返す

        # 画像をPILとして読み込む
        img_A = Image.open(os.path.join(self.imgs_list[idx], "img_A.png"))  #画像の名前は任意に変更すること
        img_B = Image.open(os.path.join(self.imgs_list[idx], "img_B.png"))

        if self.transform is not None:
            # 前処理がある場合は行う. 普通はtransforms.ToTensor()でTensorに変換してしまう
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)

        return img_A , img_B

---  ---

torch.utils.dataDataset にアクセスする際、配列のインデックス値が引数として送られてくるので、第二引数のidxがそれを受け取る。つまり対象の画像ペアが入っているフォルダはself.img_list[index]となるから、それに画像の名前("img_A.png")を付与してPIL.Image.open()で画像をPILとして読み込む。

---  ---

    def __len__(self):  # 画像ペアの入ってるフォルダの数 = データセット数を返す
        return len(glob.glob(os.path.join(self.img_paths, "*")))

前述の通り、本来ならばlen(imgs_list)で問題ない筈だが、手元の環境ではエラーが出たので再度globを実行してお茶を濁している。

DatasetをDataLoaderに入れて利用

# ** 利用時 **
# データセットを作成する
data_set = PairImgs("./my_img_path", transform=transforms.ToTensor())

---  ---

transformの詳細は参考リンクを参照のこと。

---  ---

# (必要なら)学習用とテスト用にデータセットを分割する
train_size = int( len(data_set) * 0.8 ) # 教師データのサイズ 全体の80%とする
test_size = n_samples - train_size  # テスト用データのサイズ

train_data, test_data = torch.utils.data.random_split(
        data_set,
        [train_size, test_size],
        generator=torch.Generator().manual_seed(0)  # 乱数シードの固定
    )

---  ---

個人的な考えだが、データセットの分割はDataLoaderの作成直前に行った方が楽である。
今回は訓練データとテストデータに分割したが、検証データにも分割したい場合はtorch.utils.data.random_split()の第二引数を [train_size, test_size, valid_size] のような3要素のリストとし、戻り値を3つの変数で受け取れば良い。

---  ---

# 作成したデータセットをデータローダ―に読み込ませる
batch_size = 32    # データローダ―のバッチサイズ 任意に変更のこと
# 学習データはシャッフルON テストデータはシャッフルなし
train_loader = DataLoader(train_data, batch_size=batch_size, 
                          shuffle=True, num_workers=2)
test_loader = DataLoader(test_data, batch_size=batch_size,
                         shuffle=False, num_workers=2)

DataLoaderに上記で分割したDatasetを指定して終わり。
後はfor文などで1バッチずつ取り出しながら学習を進めていけばよい。

#参考リンク
pyTorchのtransforms,Datasets,Dataloaderの説明と自作Datasetの作成と使用

6
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?