LoginSignup
15
7

More than 3 years have passed since last update.

PyTorchにおけるcollate_fnのデフォルト挙動のメモ

Last updated at Posted at 2020-05-23

collate_fnとは?

PyTorchでデータセットを読み込むとき、DataLoaderを使われている方も多いと思います。(DataLoaderの使い方については良い記事がたくさんあります。例えばこちらの記事などがわかりやすいです。)

collate_fnDataLoaderインスタンスを作成する際にコンストラクタに与える引数のひとつで、データセットから取り出した個々のデータをミニバッチにまとめる役割があります。
より具体的には、公式のドキュメントに記載されているようにcollate_fnデータセットから取り出されたデータのリストを入力とします。そして、collate_fnの戻り値がDataLoaderから出力されることになります。

そのため、自作データセットからDataLoaderによってデータを読み込む場合は、以下の例のようにcollate_fnを自作することで対応できます。

def simple_collate_fn(list_of_data):
    # ここではそれぞれのデータがD次元のベクトルであると想定しています。
    tensors = [torch.FloatTensor(data) for data in list_of_data]
    # 新しく追加した次元にミニバッチにまとめてN x Dの行列にします。(Nはデータ数)
    batched_tensor = tensor.stack(tensors, dim=0)
    # この戻り値が
    # for batched_tensor in dataloader:
    # のようにDataLoaderから出力されます。
    return batched_tensor

collate_fnのデフォルト挙動

実装をシンプルにするためにも、collate_fnを与えないデフォルトの挙動で対応できる場合には自作のcollate_fnを実装することは避けたいです。

調べてみると、collate_fnはデフォルトでもかなり高機能であり、単にtorch.stack(*, dim=0)のようにテンソルを結合するだけのものでは無いようなので今回は備忘録としてこのデフォルト機能をまとめてみたいと思います。

公式ドキュメント

実はcollate_fnのデフォルトの挙動については公式のドキュメントにしっかり記載されています。

  • It always prepends a new dimension as the batch dimension.
  • It automatically converts NumPy arrays and Python numerical values into PyTorch Tensors.
  • It preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values can not be converted into Tensors). Same for list s, tuple s, namedtuple s, etc.

つまり以下のような機能があるようです。

  • テンソルの最初に新しい次元を追加して、その次元に沿ってバッチ中のデータを結合する
  • numpyのarrayやpythonの数値は自動的にPyTorchのテンソルに変換してくれる
  • データの構造(例えばdict, list, tuple, namedtupleなど)を保持しつつ、その要素をテンソルにバッチ化する

特に3つ目の機能の存在についてはまったく初耳だったので驚きでした。(今まで複数のデータベクトルをそれぞれバッチ化する単純なcollate_fnを実装していた自分が恥ずかしい…)

実装を見てみる

とはいえ実際に実装を見てみないと細かい挙動が理解できないので、実際の実装を見て見たいと思います。

実際読むのが一番早いかと思いますが、今後再び調べる際にまたいちいち実装を読まなくて良いようにざっくりまとめておきます。

バージョン1.5時点の情報です。

型による場合分け

デフォルトのcollate_fnであるdefault_collateは再帰的な処理になっていて、引数であるbatchの最初の要素の型によって処理が場合分けされています。

elem = batch[0]
elem_type = type(elem)

以下ではelemの型による具体的な処理をまとめていきます。

torch.Tensor

batchtorch.Tensorの場合は、単純に次元を最初に1つ増やして結合しています。

return torch.stack(batch, 0)

numpyの型

numpyのndarrayの場合は、テンソル化してからtorch.Tensorの場合と同様に結合しています。

return default_collate([torch.as_tensor(b) for b in batch])

一方numpyのスカラーである場合は現在のbatchはベクトルということになるので、そのままテンソル化しています。

return torch.as_tensor(batch)

float, int, str

この場合もbatchはベクトルということになるので、それぞれ以下のようにテンソル化ないしリストのまま返されます。

# float
return torch.tensor(batch, dtype=torch.float64)
# int
return torch.tensor(batch)
# str
return batch

dictなどのcollections.abc.Mappingを継承したクラス

以下のようにキーごとにそれぞれバッチ化処理をしたものがもとのキーの値となって返されます。

return {key: default_collate([d[key] for d in batch]) for key in elem}

namedtuple

この場合も、もとのnamedtupleと同じ属性名を保持しつつ、属性ごとにバッチ化処理を行います。

return elem_type(*(default_collate(samples) for samples in zip(*batch)))

listなどのcollections.abc.Sequenceを継承したクラス

以下のように、要素ごとにバッチ化処理を行います。

transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]

具体例

例えば以下のように辞書や文字列を含んでいて複雑な構造のデータセットをデフォルトのcollate_fnで読み込んでみます。

import numpy as np
from torch.utils.data import DataLoader

if __name__=="__main__":
    complex_dataset = [
        [0, "Bob", {"height": 172.5, "feature": np.array([1,2,3])}],
        [1, "Tom", {"height": 153.1, "feature": np.array([3,2,1])}]
    ]
    dataloader = DataLoader(complex_dataset, batch_size=2)
    for batch in dataloader:
        print(batch)

すると、以下のように無事バッチ化されることが確認できます。

[
    tensor([0, 1]),
    ('Bob', 'Tom'),
    {
        'height': tensor([172.5000, 153.1000], dtype=torch.float64),
        'feature': tensor([[1, 2, 3],[3, 2, 1]])
    }
]

ところでpythonのfloatはデフォルトだとtorch.float64に変換されてしまうんですね。普通はnumpy.ndarrayでベクトルとかテンソルを表現するので問題は無いのかと思いますが、知らないと罠にはまりそうです。

15
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
7