はじめに
Pytorchを使った画像AIを構築する際、torch.utils.data.DataLoaderによりbatch_sizeごとにまとめられた画像データの呼び出しを行う。
ただ、学習中にどの画像がどれだけの回数使われたのかなど、調べたい場合などがある。
そこで今回、torch.utils.data.DataLoaderから画像であるtorch.tensorではなく、クラスのラベル並びに画像のパスを出力する方法について書く。
方法
クラスのラベル並びに画像のパスを取り出すtorch.utils.data.Datasetを用意し、torch.utils.data.DataLoaderに入れてあげればOK。
今回、Datasetには、それぞれの画像のパスとラベルが2次元のlistで格納されているものを渡す場合を考える。
list_data = [['path/to/img_0', 'label_0'],['path/to/img_1', 'label_1'], ...]
Datasetを実装
Datasetは以下のように実装する。torch.utils.data.Datasetのクラスを引き継ぎ、画像のパスとラベルを返す。
ポイントは__getitem__メソッドの返り値で、何を返すように設定するかというところ。
画像を取り出して学習などに使う場合、PIL.Image.openから画像を開き、resizeなどの処理した画像を返り値として渡すが、それを行うと重くなり動作が遅くなる。以下のDatasetだと画像を扱わない分動作が軽くなるため、高速に色々と調べることが可能。
import torch.utils.data as data
class MyDataset_path_label(data.Dataset):
def __init__(self, list_file, phase='train'):
self.list_file = list_file
self.phase = phase
def __len__(self):
# ファイル数を返す
return len(self.list_file)
def __getitem__(self, index):
# 画像のパスを取得
path_image = self.list_file[index][0]
# ラベルを取得
label_class = self.list_file[index][1]
return path_image, label_class
DataLoaderを実装
DataLoaderは以下のように実装する。上述のMyDataset_path_labelはutils.pyに保存しているものと仮定している。
import torch.utils.data as data
from utils import MyDataset_path_label
dataset = MyDataset_path_label(list_file=list_data)
dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
shuffleはTrueでも良いです。調べたい対象の環境に合わせて変更します。
結果の確認
batch_size=1の場合を表示。
無事取り出すことができた!
for path, label in dataloader:
print('{},{}'.foramt(path, label)
### 出力
# 'path/to/img_0','label_0'
# 'path/to/img_1','label_1'
# ...
最後に
今回はtorch.utils.data.DataLoaderから画像ではなく画像のパスとラベルを取り出す方法について書いてみた。
学習の再、iterationごとにどの画像やラベルのものが使われているのか、調査するのに便利。
パスやラベル以外にも、様々な属性を出力できるようにアレンジすると、色々な情報を調べることができるため便利。