13
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchのDataSetとDataLoaderを理解する(2)

Last updated at Posted at 2019-11-07

前回まで

前回まででPyTorchのDataLoaderとDataSetの動きを理解してきました。
今回は、それを応用して、自分でdatasetを自作してみましょう。
多分にこちらのソースを参考にさせていただきました。

datasetを自作してみよう

前回までの内容でちょっと凝ったことができる気がしてきました。
datasetを自作することで、うまいことデータを返せるようにしてみましょう。

MNISTのデータをペアで返すサンプルを作る

最近流行りのMetric Learningなどでは、画像をペアで作る必要があります。いろいろな方法が提案されていますが、とりあえずちょこっと試すのに良いコードが少ないように感じています。そこで今回は、例題としてdatasetを自作することで手軽にペアを扱えるようにしてみましょう。

PairMnistDatasetクラスを作る

まずはクラスを作ります。TorchのDataSetを継承しておきます。
その上で、コンストラクタではMNISTのdatasetを受け取るようにします。
MetricLearningのPositivePairと、NegativePairは、下記のような関係です。

名称 内容
Positive Pair 同一ラベル
Negative Pair 非同一ラベル

TrainingデータはShuffleしたいので、コンストラクタではラベルの位置関係を作るだけにしておき、Testデータは先にPairのパターンを作っておけばいいので、インデックスのリストを作成します。

from torch.utils.data import Dataset

class PairMnistDataset(Dataset):
    def __init__(self, mnist_dataset, train=True):
        self.train = train
        self.dataset = mnist_dataset
        self.transform = mnist_dataset.transform

        if self.train:
            self.train_data = self.dataset.train_data
            self.train_labels = self.dataset.train_labels
            self.train_label_set = set(self.train_labels.numpy())
            self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
                                     for label in self.train_label_set}
        else:
            self.test_data = self.dataset.test_data
            self.test_labels = self.dataset.test_labels
            self.test_label_set = set(self.test_labels.numpy())
            self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
                                     for label in self.test_label_set}

            # シャッフルしないので、先にペアを決めておく
            positive_pairs = [[i,
                               np.random.choice(self.label_to_indices[self.test_labels[i].item()]),
                               1]
                              for i in range(0, len(self.test_data), 2)]

            negative_pairs = [[i,
                               np.random.choice(self.label_to_indices[np.random.choice(list(self.test_label_set - set([self.test_labels[i].item()])))]),
                               0]
                              for i in range(1, len(self.test_data), 2)]

            self.test_pairs = positive_pairs + negative_pairs

__getitem__を作る

前回の記事で勉強した__getitem__を作っていきましょう。
indexが渡されたときに、どんなデータをreturnするかを記述すればいいだけです。

    def __getitem__(self, index):
        if self.train:
            target = np.random.randint(0, 2)

            # img1,label1は先に決めてしまう
            img1, label1 = self.train_data[index], self.train_labels[index].item()
            if target == 1:
                # positive pair
                # ラベルが同じとなるindexを選んでくる処理
                siamese_index = index
                while siamese_index == index:
                    siamese_index = np.random.choice(self.label_to_indices[label1])
            else:
                # negative pair
                # labelが異なるindexを選んでくる処理
                siamese_label = np.random.choice(list(self.train_label_set - set([label1])))
                siamese_index = np.random.choice(self.label_to_indices[siamese_label])

            img2 = self.train_data[siamese_index]
        else:
            img1 = self.test_data[self.test_pairs[index][0]]
            img2 = self.test_data[self.test_pairs[index][1]]
            target = self.test_pairs[index][2]

        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            
        return (img1, img2), target  # metric learningのラベルは同一か否か

    def __len__(self):
        return len(self.dataset)

mainでdatasetとdataloaderを呼んでみる

あとはここまで作ったものを呼んであげるだけです。
ここまでコードも長く、複雑に見えますが、うまく使いこなすとデータのロードがスムーズにできると思います。

def main():
    # 最初はいつものやつ
    train_dataset = datasets.MNIST(
        '~/dataset/MNIST',  # 適宜変更
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]))

    test_dataset = datasets.MNIST(
        '~/dataset/MNIST',  # 適宜変更
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]))

    # 自作したdatasetとdataloader
    pair_train_dataset = PairMnistDataset(train_dataset, train=True)
    pair_train_loader = torch.utils.data.DataLoader(
        pair_train_dataset,
        batch_size=16
    )

    pair_test_dataset = PairMnistDataset(test_dataset, train=False)
    pair_test_loader = torch.utils.data.DataLoader(
        pair_test_dataset,
        batch_size=16
    )

    # 例えばこんなふうに呼べる
    for (data1, data2), label in pair_train_loader:
        print(data1.shape)
        print(data2.shape)
        print(label)

結果表示はこちら。ちゃんとペアで返ってきていて、それぞれ同一ラベルか否か、のフラグも返ってきています。
このデータを使えば、気軽にMetricLearningができそうです。

    torch.Size([16, 1, 28, 28])
    torch.Size([16, 1, 28, 28])
    tensor([1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1])

まとめ

前回、今回でかなり長くなってしまいましたが、PyTorchのDataLoaderとDataSetの理解に関する記事でした。最近流行りのMetricLearningも、こんなふうにデータを読み出してはいかがでしょうか。

13
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?