LoginSignup
4
3

More than 1 year has passed since last update.

pytorchのtorchvision.datasetオブジェクトには和が定義されている

Posted at

pytorchのtorchvision.datasetsのオブジェクトには和(加算)が定義されているのを知ったので,簡単な使い方を書いてみます.

gistはこちら

準備

import torch
import torchvision
import torchvision.transforms as transforms

from torchvision.datasets import CIFAR10, STL10
from torch.utils.data import DataLoader


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = torch.cuda.is_available()
cudnn.benchmark = True
print('Use CUDA:', use_cuda)

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), 
                         (0.5, 0.5, 0.5)),
])

足してみる

まずは組み込みのCIFAR10SLT10を足してみます.

  • CIFAR10:サイズ32x32, 学習画像枚数50000(各クラス5000)
  • STL10:サイズ96x96,学習画像枚数5000(各クラス500)
train_CIFAR10 = CIFAR10(root='./data/',
                        train=True,
                        transform=transform,
                        )
train_STL10 = STL10(root='./data',
                    split='train',
                    transform=transform,
                    )

足したデータセットオブジェクトadd_datasetのDataloaderである
train_datasetを作成します.

add_dataset = train_CIFAR10 + train_STL10

batch_size = 2

train_loader = DataLoader(add_dataset,
                          batch_size=batch_size,
                          shuffle=True)

ではこのtrain_datasetで学習ループを回す時のようにサンプルを取得してみます.

for i, (data, label) in enumerate(train_loader):
    print(data.shape, label)
    if i > 20:
        break

torch.Size([2, 3, 64, 64]) tensor([1, 6])
torch.Size([2, 3, 64, 64]) tensor([6, 4])
torch.Size([2, 3, 64, 64]) tensor([4, 4])
torch.Size([2, 3, 64, 64]) tensor([5, 3])
torch.Size([2, 3, 64, 64]) tensor([9, 5])
torch.Size([2, 3, 64, 64]) tensor([1, 9])
torch.Size([2, 3, 64, 64]) tensor([8, 3])
torch.Size([2, 3, 64, 64]) tensor([2, 8])
torch.Size([2, 3, 64, 64]) tensor([9, 3])
torch.Size([2, 3, 64, 64]) tensor([7, 1])
torch.Size([2, 3, 64, 64]) tensor([8, 8])
torch.Size([2, 3, 64, 64]) tensor([4, 8])
torch.Size([2, 3, 64, 64]) tensor([9, 3])
torch.Size([2, 3, 64, 64]) tensor([3, 2])
torch.Size([2, 3, 64, 64]) tensor([4, 1])
torch.Size([2, 3, 64, 64]) tensor([8, 1])
torch.Size([2, 3, 64, 64]) tensor([1, 0])
torch.Size([2, 3, 64, 64]) tensor([2, 5])
torch.Size([2, 3, 64, 64]) tensor([8, 1])
torch.Size([2, 3, 64, 64]) tensor([8, 5])
torch.Size([2, 3, 64, 64]) tensor([0, 0])
torch.Size([2, 3, 64, 64]) tensor([3, 8])

これで,足してできた新しいデータセットオブジェクトでもdataloaderが通常通り動くことが分かりました.
ちなみにtransformでリサイズしているので,2つのデータセットの画像サイズが異なっていても,同じサイズのtensorが取得されています.

しかしこれでは,もとの2つのデータセットのどちらから画像がサンプルされているのか,分かりません.そのために,新しいクラスを作成してみます.

クラスを継承してみる

CIFAR10STL10から継承してサブクラスを作成します.
これは単純に,返り値のタプル(画像,ラベル)に加えて,データセット名の文字列を返すものです.

getitemだけ変更する.
class MyCIFAR10(torchvision.datasets.CIFAR10):

    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        return img, target, 'CIFAR10'

class MySTL10(torchvision.datasets.STL10):

    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        return img, target, 'STL10'

ではデータセットオブジェクトを生成して加算してみます.

コンスタントはサブクラスでも同じ
train_MyCIFAR10 = MyCIFAR10(root='./data/',
                            train=True, 
                            transform=transform,
                            )
train_MySTL10 = MySTL10(root='./data', 
                        split='train',
                        transform=transform
                        )

確認のためにデータ数を出力すると確かに元の通りです.

In [ ]: len(train_MyCIFAR10), len(train_MySTL10)
Out[ ]: (50000, 5000)

では加算してデータローダーを作成します.

add_dataset = train_MyCIFAR10 + train_MySTL10

batch_size = 2

train_loader = DataLoader(add_dataset,
                          batch_size=batch_size,
                          shuffle=True)

ではこのtrain_datasetで学習ループを回す時のようにサンプルを取得してみます.

for i, (data, label, dataset_name) in enumerate(train_loader):
    print(data.shape, label, dataset_name)
    if i > 40:
        break

torch.Size([2, 3, 64, 64]) tensor([4, 9]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([7, 9]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([3, 6]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([2, 9]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([7, 4]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([2, 2]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([4, 6]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([5, 8]) ('STL10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([3, 4]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([1, 9]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([2, 7]) ('STL10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([6, 4]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([2, 3]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([7, 8]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([2, 7]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([9, 4]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([9, 5]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([8, 2]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([0, 1]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([1, 4]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([0, 1]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([9, 7]) ('CIFAR10', 'STL10')
torch.Size([2, 3, 64, 64]) tensor([7, 4]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([3, 9]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([8, 1]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([7, 1]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([4, 2]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([1, 3]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([9, 0]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([3, 5]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([1, 1]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([7, 9]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([6, 0]) ('CIFAR10', 'STL10')
torch.Size([2, 3, 64, 64]) tensor([6, 6]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([3, 5]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([5, 1]) ('CIFAR10', 'STL10')
torch.Size([2, 3, 64, 64]) tensor([3, 4]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([7, 6]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([5, 0]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([8, 8]) ('STL10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([4, 6]) ('CIFAR10', 'CIFAR10')
torch.Size([2, 3, 64, 64]) tensor([9, 7]) ('CIFAR10', 'CIFAR10')

これで,足してできた新しいデータセットオブジェクトでは,サンプル,ラベル,データセット名文字列が取得できることが分かりました.

CIFAR10はSTL10に比べて10倍大きいので,サンプルされる数もその分多くなっています.ちなみにデータローダーをshuffle=Falseとすると,最初はCIFAR10だけサンプルされます.

用途

複数のデータセットからサンプルするような場合(Multi-Domain Learningなど)には有用でしょう.

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3