0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

自分でバッチサンプラーをつくる (PyTorch)

Posted at

自分でバッチサンプラーをつくるだけです。
この記事内のコードは以下のノートにまとめてあります。
https://gist.github.com/CookieBox26/58d97abbe8657a1e217bf3e020e4f93d

参考文献

  1. https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
    • DataLoader に渡す batch_sampler は「returns a batch of indices な Iterable」であればよいようです。なので私はイテレータ型で実装します。
  2. https://utokyo-ipp.github.io/4/4-2.html#for文とイテラブルとイテレータ
    • 「for文によって繰り返すことができるオブジェクトのことを総称して、イテラブル (iterable) と呼びます」とあります。
  3. https://docs.python.org/ja/3.8/library/stdtypes.html#typeiter
    • イテレータ型は自身を返す __iter__() と 次のアイテムを返す __next__() というメソッドを実装すればよいようです。

自分でバッチサンプラーをつくらない場合

自分でバッチサンプラーをつくらなくても DataLoader() にバッチサイズを指定すればバッチを切り出してくれます。 shuffle=True とすればシャッフルされたバッチを切り出してくれます。1エポック回し終わっても、再び次のエポックを回すことができます。2エポック目のシャッフルのされ方はちゃんと1エポック目とは変わっています。

from torch.utils.data import DataLoader
import pandas as pd

class MyDataset:
    """
    DataLoader に渡すためのデータセット型を定義します.
    https://pytorch.org/docs/stable/data.html#map-style-datasets
    """
    def __init__(self, df):
        self.df = df
        self.n_sample = len(df)
    def __getitem__(self, batch_idx):
        return self.df.loc[batch_idx, :].values
    def __len__(self):
        return self.n_sample

# 以下のダミーデータをバッチに切り出してみます.
df = pd.DataFrame({
    'a': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'b': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'c': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
})
dataset = MyDataset(df)
batch_size = 4
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
for i_epoch in range(2):
    print(f'===== epoch {i_epoch} =====')
    for i_batch, data in enumerate(dataloader):
         print(f'----- batch {i_batch} -----')
         print(data)
===== epoch 0 =====
----- batch 0 -----
tensor([[1, 1, 1],
        [5, 5, 5],
        [4, 4, 4],
        [6, 6, 6]])
----- batch 1 -----
tensor([[8, 8, 8],
        [7, 7, 7],
        [2, 2, 2],
        [0, 0, 0]])
----- batch 2 -----
tensor([[10, 10, 10],
        [ 3,  3,  3],
        [ 9,  9,  9]])
===== epoch 1 =====
----- batch 0 -----
tensor([[5, 5, 5],
        [7, 7, 7],
        [2, 2, 2],
        [1, 1, 1]])
----- batch 1 -----
tensor([[ 9,  9,  9],
        [ 8,  8,  8],
        [ 3,  3,  3],
        [10, 10, 10]])
----- batch 2 -----
tensor([[4, 4, 4],
        [0, 0, 0],
        [6, 6, 6]])

自分でバッチサンプラーをつくる (リスト)

自分でバッチサンプラーをつくらなくてもバッチを切り出してくれますが、どうしても自分でバッチサンプラーをつくりたいとします。バッチサンプラーは DataLoader()batch_sampler に渡すものですが、「そのバッチのインデックス列を返すようなイテラブル」であればよいことがわかります [1] (イテラブル: for 文で回せるもの [2] )。

なので極論、「i バッチ目に切り出したいインデックス列」のリストであればこの要件を満たします (以下)。ただし、これだと毎エポック切り出されるバッチが固定になってしまうことがわかります。

dataloader = DataLoader(
    dataset=dataset,
    batch_sampler=[[0, 1, 4, 8], [2, 3, 6, 7], [5, 9, 10]])
for i_epoch in range(2):
    print(f'===== epoch {i_epoch} =====')
    for i_batch, data in enumerate(dataloader):
         print(f'----- batch {i_batch} -----')
         print(data)
===== epoch 0 =====
----- batch 0 -----
tensor([[0, 0, 0],
        [1, 1, 1],
        [4, 4, 4],
        [8, 8, 8]])
----- batch 1 -----
tensor([[2, 2, 2],
        [3, 3, 3],
        [6, 6, 6],
        [7, 7, 7]])
----- batch 2 -----
tensor([[ 5,  5,  5],
        [ 9,  9,  9],
        [10, 10, 10]])
===== epoch 1 =====
----- batch 0 -----
tensor([[0, 0, 0],
        [1, 1, 1],
        [4, 4, 4],
        [8, 8, 8]])
----- batch 1 -----
tensor([[2, 2, 2],
        [3, 3, 3],
        [6, 6, 6],
        [7, 7, 7]])
----- batch 2 -----
tensor([[ 5,  5,  5],
        [ 9,  9,  9],
        [10, 10, 10]])

自分でバッチサンプラーをつくる (イテレータ)

リストだと融通が利かないので自分でイテレータ型を実装することにします。イテレータ型というのは、自身を返す __iter__() と 次のアイテムを返す __next__() というメソッドを実装すればよいようです [3]。

まずはデータを順番にバッチに切り出していくような基本のバッチサンプラーをつくってみます。ポイントとしては、

  • __next__() が次のバッチのインデックス列を返すようにします。
  • __iter__() で現在何バッチ目まで切り出したかをリセットします。イテレータが for 文にかけられるとよびだされるのがこの __iter__() なので、ここでリセットしないと 2 エポック目以降が回りません。

そうすると、1エポック目も2エポック目も想定通りに回せます (以下)。

import numpy as np

class MyBatchSampler:
    """
    データを全くシャッフルせず順にバッチに切り出していく基本のバッチサンプラーです.
    """
    def __init__(self, n_sample, batch_size):
        self.n_sample = n_sample
        self.batch_size = batch_size
        self.n_batch = int(np.ceil(self.n_sample / batch_size))
    def __iter__(self):
        # このバッチサンプラーが使用されるとき,現在何バッチ目かをリセットします.
        self.i_batch = -1
        return self
    def _get_i_batch(self, i_batch):
        # i_batch 番目のバッチに所属するサンプルインデックスのリストを返します.
        indices = [i_batch * self.batch_size + i for i in range(self.batch_size)]
        if i_batch == self.n_batch - 1:
            indices = [i for i in indices if i <= self.n_sample - 1]
        return indices
    def __next__(self):
        self.i_batch += 1
        if self.i_batch >= self.n_batch:
            # n_batch 番目に達したらもうバッチを切り出せません.
            raise StopIteration()
        return self._get_i_batch(self.i_batch)

dataloader = DataLoader(
    dataset=dataset,
    batch_sampler=MyBatchSampler(dataset.n_sample, batch_size))
for i_epoch in range(2):
    print(f'===== epoch {i_epoch} =====')
    for i_batch, data in enumerate(dataloader):
         print(f'----- batch {i_batch} -----')
         print(data)
===== epoch 0 =====
----- batch 0 -----
tensor([[0, 0, 0],
        [1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])
----- batch 1 -----
tensor([[4, 4, 4],
        [5, 5, 5],
        [6, 6, 6],
        [7, 7, 7]])
----- batch 2 -----
tensor([[ 8,  8,  8],
        [ 9,  9,  9],
        [10, 10, 10]])
===== epoch 1 =====
----- batch 0 -----
tensor([[0, 0, 0],
        [1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])
----- batch 1 -----
tensor([[4, 4, 4],
        [5, 5, 5],
        [6, 6, 6],
        [7, 7, 7]])
----- batch 2 -----
tensor([[ 8,  8,  8],
        [ 9,  9,  9],
        [10, 10, 10]])

シャッフルされたバッチを切り出せるようにする

データを順番にバッチに切り出していくだけでは使い道が少ないので、シャッフルできるようにします。さっきの基本のバッチサンプラーを継承して、例えば以下のように実装すれば、エポックごとに異なるシャッフルの仕方でバッチを切り出せます。

class MyBatchSamplerShuffle(MyBatchSampler):
    """
    データを全てシャッフルするバッチサンプラーです.
    """
    def __init__(self, n_sample, batch_size):
        super().__init__(n_sample, batch_size)
        self.sample_ids_shuffled = [i for i in range(self.n_sample)]
    def __iter__(self):
        # このバッチサンプラーが使用されるとき,サンプルインデックスの列をかきまぜます.
        # つまり,エポックごとにかきまぜます.
        np.random.shuffle(self.sample_ids_shuffled)
        return super().__iter__()
    def __next__(self):
        # 基底クラスの出力を得てから, かきまぜたサンプルインデックスにマッピングして返します.
        indices = super().__next__()
        return [self.sample_ids_shuffled[i] for i in indices]

dataloader = DataLoader(
    dataset=dataset,
    batch_sampler=MyBatchSamplerShuffle(dataset.n_sample, batch_size))
for i_epoch in range(2):
    print(f'===== epoch {i_epoch} =====')
    for i_batch, data in enumerate(dataloader):
         print(f'----- batch {i_batch} -----')
         print(data)
===== epoch 0 =====
----- batch 0 -----
tensor([[4, 4, 4],
        [3, 3, 3],
        [7, 7, 7],
        [6, 6, 6]])
----- batch 1 -----
tensor([[8, 8, 8],
        [9, 9, 9],
        [2, 2, 2],
        [5, 5, 5]])
----- batch 2 -----
tensor([[ 0,  0,  0],
        [10, 10, 10],
        [ 1,  1,  1]])
===== epoch 1 =====
----- batch 0 -----
tensor([[9, 9, 9],
        [4, 4, 4],
        [5, 5, 5],
        [0, 0, 0]])
----- batch 1 -----
tensor([[ 8,  8,  8],
        [ 1,  1,  1],
        [ 6,  6,  6],
        [10, 10, 10]])
----- batch 2 -----
tensor([[2, 2, 2],
        [3, 3, 3],
        [7, 7, 7]])

エポックを回すごとにバッチサイズを小さくする

シャッフルできるだけでは自作した価値がまだないので、変わったバッチサンプラーをつくってみます。さっきのシャッフルバッチサンプラーを継承して以下のように実装すれば、何エポック目かに応じてバッチサイズが exp(-i_epoch) 倍になります。

class MyBatchSamplerDecaying(MyBatchSamplerShuffle):
    """
    エポックごとにバッチサイズが小さくなっていくバッチサンプラーです.
    """
    def __init__(self, n_sample, batch_size):
        super().__init__(n_sample, batch_size)
        self.batch_size_org = batch_size
        self.i_epoch = -1
    def __iter__(self):
        self.i_epoch += 1
        self.batch_size = int(np.ceil(self.batch_size_org * np.exp(-self.i_epoch)))
        print('只今のバッチサイズ:', self.batch_size)
        self.n_batch = int(np.ceil(self.n_sample / self.batch_size))
        return super().__iter__()

dataloader = DataLoader(
    dataset=dataset,
    batch_sampler=MyBatchSamplerDecaying(dataset.n_sample, batch_size))
for i_epoch in range(2):
    print(f'===== epoch {i_epoch} =====')
    for i_batch, data in enumerate(dataloader):
         print(f'----- batch {i_batch} -----')
         print(data)
===== epoch 0 =====
只今のバッチサイズ: 4
----- batch 0 -----
tensor([[ 9,  9,  9],
        [ 3,  3,  3],
        [ 5,  5,  5],
        [10, 10, 10]])
----- batch 1 -----
tensor([[1, 1, 1],
        [8, 8, 8],
        [4, 4, 4],
        [7, 7, 7]])
----- batch 2 -----
tensor([[6, 6, 6],
        [0, 0, 0],
        [2, 2, 2]])
===== epoch 1 =====
只今のバッチサイズ: 2
----- batch 0 -----
tensor([[5, 5, 5],
        [2, 2, 2]])
----- batch 1 -----
tensor([[9, 9, 9],
        [8, 8, 8]])
----- batch 2 -----
tensor([[4, 4, 4],
        [3, 3, 3]])
----- batch 3 -----
tensor([[6, 6, 6],
        [0, 0, 0]])
----- batch 4 -----
tensor([[7, 7, 7],
        [1, 1, 1]])
----- batch 5 -----
tensor([[10, 10, 10]])
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?