動機
Pytorch で Seq2Seq のようなモデルを学習する時、学習データの入力順番が大事。LSTM のような RNN 学習機にデータを順番に入力する方法を記述。
読み込んだデータのパスが出力できる ImageFolder
Pytorch の datasets.ImageFolder がデータのパスを出力することができません。こちらに参照し、datasets.ImageFolder を継承して__getitem__関数を書き直した上でパス出力ができるようになりました。
class ImageFolderWithPaths(datasets.ImageFolder):
"""Custom dataset that includes image file paths. Extends
torchvision.datasets.ImageFolder
"""
# override the __getitem__ method. this is the method dataloader calls
def __getitem__(self, index):
# this is what ImageFolder normally returns
original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
# the image file path
path = self.imgs[index][0]
# make a new tuple that includes original and the path
tuple_with_path = (original_tuple + (path,))
return tuple_with_path
データフォルダの順番
一般の方法で読込
data_dir = './pregnant'
data_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])
image_datasets = {x: ImageFolderWithPaths(os.path.join(data_dir, x),
transform=data_transforms) for x in ['all']}
data_loaders = {x: torch.utils.data.DataLoader(image_datasets[x],
batch_size=batch_size, shuffle=False) for x in ['all']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['all']}
for inputs, _, paths in data_loaders['all']:
print(paths)
break
出力結果
'0.jpg',
'1.jpg',
'10.jpg',
'100.jpg',
'1000.jpg',
'1001.jpg',
'1002.jpg',
'1003.jpg',
'1004.jpg',
'1005.jpg',
'1006.jpg',
'1007.jpg'
...
この出力結果が予想と違うので使えません。
ファイルの名前を指定して読込
こちらを参照してファイルの名前を指定する方法で読み込みます。
dataloader を使わずに、普通のループで読み込んで順番が保証できます。そして torch.stack() を利用して dataloader の batch 単位読込がシミュレーションできます。
from PIL import Image
data_iter = iter(data_loaders['all'])
# 本格
for i in range(1488 - batch_size):
imgs = []
for ii in range(i, i + batch_size):
path = os.path.join('{}.jpg'.format(ii))
print(path)
img = data_transforms(Image.open(path))
imgs.append(img)
print(len(imgs))
imgs = torch.stack(imgs)
print(imgs.size())
break
# 比較用
for inputs, _, paths in data_loaders['all']:
print(inputs.size())
break
出力が以下となります:
0.jpg
1.jpg
2.jpg
3.jpg
4.jpg
5.jpg
6.jpg
7.jpg
8.jpg
...
36
torch.Size([36, 3, 224, 224])
torch.Size([36, 3, 224, 224])
dataloader と同じようなシーケンス読込ができました。