LoginSignup
10
1

More than 5 years have passed since last update.

Pytorch でシーケンスデータを順番で読込

Last updated at Posted at 2018-09-10

動機

Pytorch で Seq2Seq のようなモデルを学習する時、学習データの入力順番が大事。LSTM のような RNN 学習機にデータを順番に入力する方法を記述。

読み込んだデータのパスが出力できる ImageFolder

Pytorch の datasets.ImageFolder がデータのパスを出力することができません。こちらに参照し、datasets.ImageFolder を継承してgetitem関数を書き直した上でパス出力ができるようになりました。

class ImageFolderWithPaths(datasets.ImageFolder):
    """Custom dataset that includes image file paths. Extends
    torchvision.datasets.ImageFolder
    """

    # override the __getitem__ method. this is the method dataloader calls
    def __getitem__(self, index):
        # this is what ImageFolder normally returns 
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # the image file path
        path = self.imgs[index][0]
        # make a new tuple that includes original and the path
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

データフォルダの順番

以下のような順番でシーケンスデータを読み込みたい。
image.png

一般の方法で読込

data_dir = './pregnant'

data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])

image_datasets = {x: ImageFolderWithPaths(os.path.join(data_dir, x), 
                                          transform=data_transforms) for x in ['all']}

data_loaders = {x: torch.utils.data.DataLoader(image_datasets[x], 
                                               batch_size=batch_size, shuffle=False) for x in ['all']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['all']}

for inputs, _, paths in data_loaders['all']:
    print(paths)
    break

出力結果

'0.jpg', 
'1.jpg', 
'10.jpg', 
'100.jpg', 
'1000.jpg', 
'1001.jpg',
'1002.jpg', 
'1003.jpg', 
'1004.jpg', 
'1005.jpg', 
'1006.jpg', 
'1007.jpg'
...

この出力結果が予想と違うので使えません。

ファイルの名前を指定して読込

こちらを参照してファイルの名前を指定する方法で読み込みます。
dataloader を使わずに、普通のループで読み込んで順番が保証できます。そして torch.stack() を利用して dataloader の batch 単位読込がシミュレーションできます。

from PIL import Image
data_iter = iter(data_loaders['all'])

# 本格
for i in range(1488 - batch_size):  
    imgs = []
    for ii in range(i, i + batch_size):
        path = os.path.join('{}.jpg'.format(ii)) 
        print(path)
        img = data_transforms(Image.open(path))
        imgs.append(img)
    print(len(imgs))
    imgs = torch.stack(imgs)
    print(imgs.size())
    break

# 比較用
for inputs, _, paths in data_loaders['all']:
    print(inputs.size())
    break

出力が以下となります:

0.jpg
1.jpg
2.jpg
3.jpg
4.jpg
5.jpg
6.jpg
7.jpg
8.jpg
...
36
torch.Size([36, 3, 224, 224])
torch.Size([36, 3, 224, 224])

dataloader と同じようなシーケンス読込ができました。

10
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
1