はじめに
Moving MNIST はシーケンスの予測・再構築を評価するためのテストセットです。手書き数字が画面内を動き回るように作られた画像の集まりです。
左がデータセットで、右が予測の例です。こちらのサイトから引用しました。
データは10,000種類あり、それぞれ20フレームあります。画像のサイズは64 x 64で、二つの数字が映っています。
データの読み込み
こちらのサイトからダウンロードできます。
!curl -o mnist_test_seq.npy http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy
データを読み込みます。
import numpy as np
path="./mnist_test_seq.npy"
data = np.load(path)
print(data.shape) # (20, 10000, 64, 64)
ipywidgetsを使ってちゃんと動画になっていることを確認してみます。
%matplotlib inline
import matplotlib.pyplot as plt
from ipywidgets import interact
def f(k):
plt.imshow(data[k][0], 'gray')
plt.show()
interact(f, k=(0,19,1) )
このようなセットが10,000セットあります。
10フレームを使用して10フレームを予測する場合のDatasetの作り方です。
import numpy as np
from torch.utils.data import Dataset
class MovingMnistDataset(Dataset):
def __init__(self, path="./mnist_test_seq.npy"):
self.data = np.load(path)
# (t, N, H, W) -> (N, t, C, H, W)
self.data = self.data.transpose(1, 0, 2, 3)[:, :, None, ...]
def __len__(self):
return len(self.data)
def __getitem__(self, i):
return self.data[i, :10, ...].astype(np.int32), self.data[i, 10:, ...].astype(np.int32)
dataset = MovingMnistDataset()