はじめに
気がつけばあまり理解せずに使っていたPyTorchのDataLoaderとDataSetです。
少し凝ったことがしたくなったら参考にしていただければ幸いです。
後編はこちら。
PyTorchのExampleの確認
PyTorchを使っていれば、当然DataLoaderを見たことがあると思います。
誰もが機械学習で使うMNISTのPyTorchのExampleでもこんな記述があります。
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('~/dataset/MNIST',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=256,
shuffle=True)
あるいはQiitaなどで検索するとこんな書き方も見ると思います。
train_dataset = datasets.MNIST(
'~/dataset/MNIST',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=256,
shuffle=True)
そしてその後、For文で回すことでバッチでデータを取得し、学習する様子を見ると思います。
for epoch in epochs:
for img, label in train_loader:
# この中で学習の処理を記載する
口語的に書けば、DataLoaderというのは、あるルールに従って、DataSetで記載の通りにデータを運んできてくれる便利なヤツです。
例えば上記の例では、MNISTのデータがNormalizeされた状態で256個(ミニバッチ)ずつimgとlabelにいい感じに入ってきてくれます。
ではどうやってそれが実現されているのか、中身を見に行きましょう。
torch.utils.data.DataLoaderを見てみる
DataLoaderの実装を見てみます。
クラスになっているのがすぐにわかります。
class DataLoader(object):
r"""
Data loader. Combines a dataset and a sampler, and provides an iterable over
the given dataset.
"""
# 省略
詳細は割愛しますが、イテレータとしての情報は少し深く見ていくと下記のような実装が見つかります。
def __next__(self):
index = self._next_index() # may raise StopIteration
data = self.dataset_fetcher.fetch(index) # may raise StopIteration
if self.pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
この__next__
が呼ばれると、dataが返される仕組みです。
そしてこのdataは、datasetにindexを渡して作られているようです。
この段階ではindexがどう作られるか、datasetがどう呼ばれているか、はそこまで神経質にならなくても大丈夫ですが、せっかくなのでもう一歩奥を見に行きましょう。
class _MapDatasetFetcher(_BaseDatasetFetcher):
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
datasetにindexが渡されていますね。このようにクラスのインスタンスを呼んでいるということは、datasetでは__getitem__
が呼ばれているはずです。(こちらが詳しいです。ではこれを踏まえてdatasetを見に行ってみましょう。
datasets.MNISTを見てみる
MNISTの定義を見に行くとすぐにClassであることがわかります。
class MNIST(VisionDataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
"""
# 省略
では__getitem__
を見に行きましょう。
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], int(self.targets[index])
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
indexが渡され、それに対応するデータをreturnしている様子がわかりやすく書かれています。なるほどこうやってMNISTのデータを返しているのです。
少し処理を見てみると、PILのImage.fromarrayなんかも書いてあります。つまりこの__getitem__
を工夫して書いてあげれば、自在なデータをリターンすることが可能だということです。
torch.utils.data.DataLoaderをもう1回見てみる
だけどまだわからないことがあります。どんなふうにindexが作られているのでしょうか。そのヒントはここにあります。
if sampler is None: # give default samplers
if self.dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
self.sampler = sampler
@property
def _index_sampler(self):
# The actual sampler used for generating indices for `_DatasetFetcher`
# (see _utils/fetch.py) to read data at each time. This would be
# `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
# We can't change `.sampler` and `.batch_sampler` attributes for BC
# reasons.
if self._auto_collation:
return self.batch_sampler
else:
return self.sampler
indexはsamplerを通じて作られているようです。
samplerはデフォルトではshuffleという引数のTrue,Falseによって切り替わっています。例えばshuffle=Falseのときの実装を見てみましょう。
class SequentialSampler(Sampler):
r"""Samples elements sequentially, always in the same order.
Arguments:
data_source (Dataset): dataset to sample from
"""
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
ここでいうdata_sourceは、datasetのことです。ここまで来て大体全容が掴めたように見えます。つまり、datasetの長さだけ繰り返すということです。逆に言えばdatasetの中で__len__
という特殊メソッドを用意しておく必要がありそうですね。
datasets.MNISTをもう一度見てみる
ではdatasets.MNISTの__len__
を確認しておきましょう。
def __len__(self):
return len(self.data)
dataの長さを返していますね。MNISTでのdataは60000x28x28のサイズなので、60000が返ることになります。だいぶすっきりしてきました。
次回に続く
記事が長くなって来たので、前編はここまで。
後編ではdatasetの自作を行います。