はじめに
Pytorchを使った画像AIを構築する際、torch.utils.data.DataLoader
によりbatch_size
ごとにまとめられた画像データの呼び出しを行う。
ただ、学習中にどの画像がどれだけの回数使われたのかなど、調べたい場合などがある。
そこで今回、torch.utils.data.DataLoader
から画像であるtorch.tensor
ではなく、クラスのラベル並びに画像のパスを出力する方法について書く。
方法
クラスのラベル並びに画像のパスを取り出すtorch.utils.data.Dataset
を用意し、torch.utils.data.DataLoader
に入れてあげればOK。
今回、Dataset
には、それぞれの画像のパスとラベルが2次元のlist
で格納されているものを渡す場合を考える。
list_data = [['path/to/img_0', 'label_0'],['path/to/img_1', 'label_1'], ...]
Dataset
を実装
Dataset
は以下のように実装する。torch.utils.data.Dataset
のクラスを引き継ぎ、画像のパスとラベルを返す。
ポイントは__getitem__
メソッドの返り値で、何を返すように設定するかというところ。
画像を取り出して学習などに使う場合、PIL.Image.open
から画像を開き、resize
などの処理した画像を返り値として渡すが、それを行うと重くなり動作が遅くなる。以下のDataset
だと画像を扱わない分動作が軽くなるため、高速に色々と調べることが可能。
import torch.utils.data as data
class MyDataset_path_label(data.Dataset):
def __init__(self, list_file, phase='train'):
self.list_file = list_file
self.phase = phase
def __len__(self):
# ファイル数を返す
return len(self.list_file)
def __getitem__(self, index):
# 画像のパスを取得
path_image = self.list_file[index][0]
# ラベルを取得
label_class = self.list_file[index][1]
return path_image, label_class
DataLoader
を実装
DataLoader
は以下のように実装する。上述のMyDataset_path_label
はutils.py
に保存しているものと仮定している。
import torch.utils.data as data
from utils import MyDataset_path_label
dataset = MyDataset_path_label(list_file=list_data)
dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
shuffle
はTrue
でも良いです。調べたい対象の環境に合わせて変更します。
結果の確認
batch_size=1
の場合を表示。
無事取り出すことができた!
for path, label in dataloader:
print('{},{}'.foramt(path, label)
### 出力
# 'path/to/img_0','label_0'
# 'path/to/img_1','label_1'
# ...
最後に
今回はtorch.utils.data.DataLoader
から画像ではなく画像のパスとラベルを取り出す方法について書いてみた。
学習の再、iteration
ごとにどの画像やラベルのものが使われているのか、調査するのに便利。
パスやラベル以外にも、様々な属性を出力できるようにアレンジすると、色々な情報を調べることができるため便利。