3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

複雑に抜取・連結されたデータセットの画像番号を取得する

Last updated at Posted at 2020-05-03

torch.utils.data.Datasetの様式をしたデータセットを、torch.utils.data.Subsetやtorch.utils.data.ConcatDatasetを用いて、任意データを抜き取ったり、データセットを連結して、複雑に抜取・連結されたデータセットがある時、元のデータセットでの画像番号を取得する。

グローバル変数のように番号を格納して用いるので、番号取得の部分だけに関しては並列処理は不可とはなるものの、この部分はまず並列化しないはず。手軽な方法で、突貫で作成したもの。もっとよい方法があればなと。

Libs

import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision

Datasetの抜取・連結

任意データを抜き取ったり、データセットを連結して、複雑に抜取・連結されたデータセットをまず作成。

# sec: データセット

ds = torchvision.datasets.MNIST(root="trains/pytorch-mnist", train=True, download=True, 
    transform=torchvision.transforms.ToTensor())
print("サイズ:", len(ds), "先頭10個の真値:", [ds[i][1] for i in range(10)])

# sec: Subset, ConcatDatasetを使用

ds_1 = torch.utils.data.Subset(ds, [0, 2, 5]) # 任意位置の抜き取り
print("サイズ:", len(ds_1), "真値:", [ds_1[i][1] for i in range(3)])
ds_2 = torch.utils.data.ConcatDataset([ds_1, ds_1, ds_1]) # データセットを連結
print("サイズ:", len(ds_2), "真値:", [ds_2[i][1] for i in range(9)])
サイズ: 60000 先頭10個の真値: [5, 0, 4, 1, 9, 2, 1, 3, 1, 4]
サイズ: 3 真値: [5, 4, 2]
サイズ: 9 真値: [5, 4, 2, 5, 4, 2, 5, 4, 2]

Datasetの画像番号を取得

以下コードで、複雑に抜取・連結されたデータセットがある時、元のデータセットでの画像番号を取得する。

# sec: データセット内の画像番号を追跡

class AccessPrintDataset(torch.utils.data.Dataset):
    
    def __init__(self, ds):
        self.dataset = ds
        self.cur_index = None # 現在アクセスされた番号
        self.if_print = False
    
    def __getitem__(self, index):
        
        self.cur_index = index # 現在アクセスされた番号
        if self.if_print:
            print("accessed:", index)
            
        return self.dataset[index]
    
    def __len__(self):
        return len(self.dataset)

# sec: データセットを組み替え

ds_a = AccessPrintDataset(ds) # 初めにラップしておく
ds_1 = torch.utils.data.Subset(ds_a, [0, 2, 5]) # 任意位置の抜き取り
ds_2 = torch.utils.data.ConcatDataset([ds_1, ds_1, ds_1]) # データセットを連結

# sec: アクセスした画像番号を表示

ds_a.if_print = True
print("アクセスした画像を表示:")
dummy = [ds_2[i] for i in range(len(ds_2))] # データセットにアクセス
ds_a.if_print = False

# sec: アクセスした画像番号のリストを取得

print("\nアクセスした画像番号のリストを取得:")
i_list = [ds_a.cur_index if ds_2[i] else None for i in range(len(ds_2))]
print("番号:", i_list)

print("真値:", [ds[i][1] for i in i_list])
アクセスした画像を表示:
accessed: 0
accessed: 2
accessed: 5
accessed: 0
accessed: 2
accessed: 5
accessed: 0
accessed: 2
accessed: 5

アクセスした画像番号のリストを取得:
番号: [0, 2, 5, 0, 2, 5, 0, 2, 5]
真値: [5, 4, 2, 5, 4, 2, 5, 4, 2]
3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?