前回まで
前回まででPyTorchのDataLoaderとDataSetの動きを理解してきました。
今回は、それを応用して、自分でdatasetを自作してみましょう。
多分にこちらのソースを参考にさせていただきました。
datasetを自作してみよう
前回までの内容でちょっと凝ったことができる気がしてきました。
datasetを自作することで、うまいことデータを返せるようにしてみましょう。
MNISTのデータをペアで返すサンプルを作る
最近流行りのMetric Learningなどでは、画像をペアで作る必要があります。いろいろな方法が提案されていますが、とりあえずちょこっと試すのに良いコードが少ないように感じています。そこで今回は、例題としてdatasetを自作することで手軽にペアを扱えるようにしてみましょう。
PairMnistDatasetクラスを作る
まずはクラスを作ります。TorchのDataSetを継承しておきます。
その上で、コンストラクタではMNISTのdatasetを受け取るようにします。
MetricLearningのPositivePairと、NegativePairは、下記のような関係です。
名称 | 内容 |
---|---|
Positive Pair | 同一ラベル |
Negative Pair | 非同一ラベル |
TrainingデータはShuffleしたいので、コンストラクタではラベルの位置関係を作るだけにしておき、Testデータは先にPairのパターンを作っておけばいいので、インデックスのリストを作成します。
from torch.utils.data import Dataset
class PairMnistDataset(Dataset):
def __init__(self, mnist_dataset, train=True):
self.train = train
self.dataset = mnist_dataset
self.transform = mnist_dataset.transform
if self.train:
self.train_data = self.dataset.train_data
self.train_labels = self.dataset.train_labels
self.train_label_set = set(self.train_labels.numpy())
self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
for label in self.train_label_set}
else:
self.test_data = self.dataset.test_data
self.test_labels = self.dataset.test_labels
self.test_label_set = set(self.test_labels.numpy())
self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
for label in self.test_label_set}
# シャッフルしないので、先にペアを決めておく
positive_pairs = [[i,
np.random.choice(self.label_to_indices[self.test_labels[i].item()]),
1]
for i in range(0, len(self.test_data), 2)]
negative_pairs = [[i,
np.random.choice(self.label_to_indices[np.random.choice(list(self.test_label_set - set([self.test_labels[i].item()])))]),
0]
for i in range(1, len(self.test_data), 2)]
self.test_pairs = positive_pairs + negative_pairs
__getitem__
を作る
前回の記事で勉強した__getitem__
を作っていきましょう。
indexが渡されたときに、どんなデータをreturnするかを記述すればいいだけです。
def __getitem__(self, index):
if self.train:
target = np.random.randint(0, 2)
# img1,label1は先に決めてしまう
img1, label1 = self.train_data[index], self.train_labels[index].item()
if target == 1:
# positive pair
# ラベルが同じとなるindexを選んでくる処理
siamese_index = index
while siamese_index == index:
siamese_index = np.random.choice(self.label_to_indices[label1])
else:
# negative pair
# labelが異なるindexを選んでくる処理
siamese_label = np.random.choice(list(self.train_label_set - set([label1])))
siamese_index = np.random.choice(self.label_to_indices[siamese_label])
img2 = self.train_data[siamese_index]
else:
img1 = self.test_data[self.test_pairs[index][0]]
img2 = self.test_data[self.test_pairs[index][1]]
target = self.test_pairs[index][2]
img1 = Image.fromarray(img1.numpy(), mode='L')
img2 = Image.fromarray(img2.numpy(), mode='L')
if self.transform:
img1 = self.transform(img1)
img2 = self.transform(img2)
return (img1, img2), target # metric learningのラベルは同一か否か
def __len__(self):
return len(self.dataset)
mainでdatasetとdataloaderを呼んでみる
あとはここまで作ったものを呼んであげるだけです。
ここまでコードも長く、複雑に見えますが、うまく使いこなすとデータのロードがスムーズにできると思います。
def main():
# 最初はいつものやつ
train_dataset = datasets.MNIST(
'~/dataset/MNIST', # 適宜変更
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
test_dataset = datasets.MNIST(
'~/dataset/MNIST', # 適宜変更
train=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
# 自作したdatasetとdataloader
pair_train_dataset = PairMnistDataset(train_dataset, train=True)
pair_train_loader = torch.utils.data.DataLoader(
pair_train_dataset,
batch_size=16
)
pair_test_dataset = PairMnistDataset(test_dataset, train=False)
pair_test_loader = torch.utils.data.DataLoader(
pair_test_dataset,
batch_size=16
)
# 例えばこんなふうに呼べる
for (data1, data2), label in pair_train_loader:
print(data1.shape)
print(data2.shape)
print(label)
結果表示はこちら。ちゃんとペアで返ってきていて、それぞれ同一ラベルか否か、のフラグも返ってきています。
このデータを使えば、気軽にMetricLearningができそうです。
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 28, 28])
tensor([1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1])
まとめ
前回、今回でかなり長くなってしまいましたが、PyTorchのDataLoaderとDataSetの理解に関する記事でした。最近流行りのMetricLearningも、こんなふうにデータを読み出してはいかがでしょうか。