torch.utils.data.Datasetの様式をしたデータセットを、torch.utils.data.Subsetやtorch.utils.data.ConcatDatasetを用いて、任意データを抜き取ったり、データセットを連結して、複雑に抜取・連結されたデータセットがある時、元のデータセットでの画像番号を取得する。
グローバル変数のように番号を格納して用いるので、番号取得の部分だけに関しては並列処理は不可とはなるものの、この部分はまず並列化しないはず。手軽な方法で、突貫で作成したもの。もっとよい方法があればなと。
Libs
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
Datasetの抜取・連結
任意データを抜き取ったり、データセットを連結して、複雑に抜取・連結されたデータセットをまず作成。
# sec: データセット
ds = torchvision.datasets.MNIST(root="trains/pytorch-mnist", train=True, download=True,
transform=torchvision.transforms.ToTensor())
print("サイズ:", len(ds), "先頭10個の真値:", [ds[i][1] for i in range(10)])
# sec: Subset, ConcatDatasetを使用
ds_1 = torch.utils.data.Subset(ds, [0, 2, 5]) # 任意位置の抜き取り
print("サイズ:", len(ds_1), "真値:", [ds_1[i][1] for i in range(3)])
ds_2 = torch.utils.data.ConcatDataset([ds_1, ds_1, ds_1]) # データセットを連結
print("サイズ:", len(ds_2), "真値:", [ds_2[i][1] for i in range(9)])
サイズ: 60000 先頭10個の真値: [5, 0, 4, 1, 9, 2, 1, 3, 1, 4]
サイズ: 3 真値: [5, 4, 2]
サイズ: 9 真値: [5, 4, 2, 5, 4, 2, 5, 4, 2]
Datasetの画像番号を取得
以下コードで、複雑に抜取・連結されたデータセットがある時、元のデータセットでの画像番号を取得する。
# sec: データセット内の画像番号を追跡
class AccessPrintDataset(torch.utils.data.Dataset):
def __init__(self, ds):
self.dataset = ds
self.cur_index = None # 現在アクセスされた番号
self.if_print = False
def __getitem__(self, index):
self.cur_index = index # 現在アクセスされた番号
if self.if_print:
print("accessed:", index)
return self.dataset[index]
def __len__(self):
return len(self.dataset)
# sec: データセットを組み替え
ds_a = AccessPrintDataset(ds) # 初めにラップしておく
ds_1 = torch.utils.data.Subset(ds_a, [0, 2, 5]) # 任意位置の抜き取り
ds_2 = torch.utils.data.ConcatDataset([ds_1, ds_1, ds_1]) # データセットを連結
# sec: アクセスした画像番号を表示
ds_a.if_print = True
print("アクセスした画像を表示:")
dummy = [ds_2[i] for i in range(len(ds_2))] # データセットにアクセス
ds_a.if_print = False
# sec: アクセスした画像番号のリストを取得
print("\nアクセスした画像番号のリストを取得:")
i_list = [ds_a.cur_index if ds_2[i] else None for i in range(len(ds_2))]
print("番号:", i_list)
print("真値:", [ds[i][1] for i in i_list])
アクセスした画像を表示:
accessed: 0
accessed: 2
accessed: 5
accessed: 0
accessed: 2
accessed: 5
accessed: 0
accessed: 2
accessed: 5
アクセスした画像番号のリストを取得:
番号: [0, 2, 5, 0, 2, 5, 0, 2, 5]
真値: [5, 4, 2, 5, 4, 2, 5, 4, 2]