pytorchのtorchvision.datasets
のオブジェクトには和(加算)が定義されているのを知ったので,簡単な使い方を書いてみます.
準備
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, STL10
from torch.utils.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = torch.cuda.is_available()
cudnn.benchmark = True
print('Use CUDA:', use_cuda)
transform = transforms.Compose([
transforms.Resize(64),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5)),
])
足してみる
まずは組み込みのCIFAR10
とSLT10
を足してみます.
train_CIFAR10 = CIFAR10(root='./data/',
train=True,
transform=transform,
)
train_STL10 = STL10(root='./data',
split='train',
transform=transform,
)
足したデータセットオブジェクトadd_dataset
のDataloaderである
train_dataset
を作成します.
add_dataset = train_CIFAR10 + train_STL10
batch_size = 2
train_loader = DataLoader(add_dataset,
batch_size=batch_size,
shuffle=True)
ではこのtrain_dataset
で学習ループを回す時のようにサンプルを取得してみます.
for i, (data, label) in enumerate(train_loader):
print(data.shape, label)
if i > 20:
break
torch.Size([2, 3, 64, 64]) tensor([1, 6])
torch.Size([2, 3, 64, 64]) tensor([6, 4])
torch.Size([2, 3, 64, 64]) tensor([4, 4])
torch.Size([2, 3, 64, 64]) tensor([5, 3])
torch.Size([2, 3, 64, 64]) tensor([9, 5])
torch.Size([2, 3, 64, 64]) tensor([1, 9])
torch.Size([2, 3, 64, 64]) tensor([8, 3])
torch.Size([2, 3, 64, 64]) tensor([2, 8])
torch.Size([2, 3, 64, 64]) tensor([9, 3])
torch.Size([2, 3, 64, 64]) tensor([7, 1])
torch.Size([2, 3, 64, 64]) tensor([8, 8])
torch.Size([2, 3, 64, 64]) tensor([4, 8])
torch.Size([2, 3, 64, 64]) tensor([9, 3])
torch.Size([2, 3, 64, 64]) tensor([3, 2])
torch.Size([2, 3, 64, 64]) tensor([4, 1])
torch.Size([2, 3, 64, 64]) tensor([8, 1])
torch.Size([2, 3, 64, 64]) tensor([1, 0])
torch.Size([2, 3, 64, 64]) tensor([2, 5])
torch.Size([2, 3, 64, 64]) tensor([8, 1])
torch.Size([2, 3, 64, 64]) tensor([8, 5])
torch.Size([2, 3, 64, 64]) tensor([0, 0])
torch.Size([2, 3, 64, 64]) tensor([3, 8])
これで,足してできた新しいデータセットオブジェクトでもdataloaderが通常通り動くことが分かりました.
ちなみにtransform
でリサイズしているので,2つのデータセットの画像サイズが異なっていても,同じサイズのtensorが取得されています.
しかしこれでは,もとの2つのデータセットのどちらから画像がサンプルされているのか,分かりません.そのために,新しいクラスを作成してみます.
クラスを継承してみる
CIFAR10
とSTL10
から継承してサブクラスを作成します.
これは単純に,返り値のタプル(画像,ラベル)に加えて,データセット名の文字列を返すものです.
class MyCIFAR10(torchvision.datasets.CIFAR10):
def __getitem__(self, index):
img, target = super().__getitem__(index)
return img, target, 'CIFAR10'
class MySTL10(torchvision.datasets.STL10):
def __getitem__(self, index):
img, target = super().__getitem__(index)
return img, target, 'STL10'
ではデータセットオブジェクトを生成して加算してみます.
train_MyCIFAR10 = MyCIFAR10(root='./data/',
train=True,
transform=transform,
)
train_MySTL10 = MySTL10(root='./data',
split='train',
transform=transform
)
確認のためにデータ数を出力すると確かに元の通りです.
In [ ]: len(train_MyCIFAR10), len(train_MySTL10)
Out[ ]: (50000, 5000)
では加算してデータローダーを作成します.
add_dataset = train_MyCIFAR10 + train_MySTL10
batch_size = 2
train_loader = DataLoader(add_dataset,
batch_size=batch_size,
shuffle=True)
ではこのtrain_dataset
で学習ループを回す時のようにサンプルを取得してみます.
for i, (data, label, dataset_name) in enumerate(train_loader):
print(data.shape, label, dataset_name)
if i > 40:
break
torch.Size([2, 3, 64, 64]) tensor([4, 9]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([7, 9]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([3, 6]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([2, 9]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([7, 4]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([2, 2]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([4, 6]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([5, 8]) ('STL10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([3, 4]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([1, 9]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([2, 7]) ('STL10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([6, 4]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([2, 3]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([7, 8]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([2, 7]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([9, 4]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([9, 5]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([8, 2]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([0, 1]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([1, 4]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([0, 1]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([9, 7]) ('CIFAR10', 'STL10')
torch.Size([2, 3, 64, 64]) tensor([7, 4]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([3, 9]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([8, 1]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([7, 1]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([4, 2]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([1, 3]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([9, 0]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([3, 5]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([1, 1]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([7, 9]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([6, 0]) ('CIFAR10', 'STL10')
torch.Size([2, 3, 64, 64]) tensor([6, 6]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([3, 5]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([5, 1]) ('CIFAR10', 'STL10')
torch.Size([2, 3, 64, 64]) tensor([3, 4]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([7, 6]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([5, 0]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([8, 8]) ('STL10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([4, 6]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([9, 7]) ('CIFAR10', 'CIFAR10')
これで,足してできた新しいデータセットオブジェクトでは,サンプル,ラベル,データセット名文字列が取得できることが分かりました.
CIFAR10はSTL10に比べて10倍大きいので,サンプルされる数もその分多くなっています.ちなみにデータローダーをshuffle=False
とすると,最初はCIFAR10だけサンプルされます.
用途
複数のデータセットからサンプルするような場合(Multi-Domain Learningなど)には有用でしょう.