collate_fnとは?
PyTorchでデータセットを読み込むとき、DataLoaderを使われている方も多いと思います。(DataLoaderの使い方については良い記事がたくさんあります。例えばこちらの記事などがわかりやすいです。)
collate_fn
はDataLoader
インスタンスを作成する際にコンストラクタに与える引数のひとつで、データセットから取り出した個々のデータをミニバッチにまとめる役割があります。
より具体的には、公式のドキュメントに記載されているようにcollate_fn
はデータセットから取り出されたデータのリストを入力とします。そして、collate_fn
の戻り値がDataLoader
から出力されることになります。
そのため、自作データセットからDataLoader
によってデータを読み込む場合は、以下の例のようにcollate_fn
を自作することで対応できます。
def simple_collate_fn(list_of_data):
# ここではそれぞれのデータがD次元のベクトルであると想定しています。
tensors = [torch.FloatTensor(data) for data in list_of_data]
# 新しく追加した次元にミニバッチにまとめてN x Dの行列にします。(Nはデータ数)
batched_tensor = tensor.stack(tensors, dim=0)
# この戻り値が
# for batched_tensor in dataloader:
# のようにDataLoaderから出力されます。
return batched_tensor
collate_fnのデフォルト挙動
実装をシンプルにするためにも、collate_fn
を与えないデフォルトの挙動で対応できる場合には自作のcollate_fn
を実装することは避けたいです。
調べてみると、collate_fn
はデフォルトでもかなり高機能であり、単にtorch.stack(*, dim=0)
のようにテンソルを結合するだけのものでは無いようなので今回は備忘録としてこのデフォルト機能をまとめてみたいと思います。
公式ドキュメント
実はcollate_fn
のデフォルトの挙動については公式のドキュメントにしっかり記載されています。
- It always prepends a new dimension as the batch dimension.
- It automatically converts NumPy arrays and Python numerical values into PyTorch Tensors.
- It preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values can not be converted into Tensors). Same for list s, tuple s, namedtuple s, etc.
つまり以下のような機能があるようです。
- テンソルの最初に新しい次元を追加して、その次元に沿ってバッチ中のデータを結合する
- numpyのarrayやpythonの数値は自動的にPyTorchのテンソルに変換してくれる
- データの構造(例えば
dict
,list
,tuple
,namedtuple
など)を保持しつつ、その要素をテンソルにバッチ化する
特に3つ目の機能の存在についてはまったく初耳だったので驚きでした。(今まで複数のデータベクトルをそれぞれバッチ化する単純なcollate_fn
を実装していた自分が恥ずかしい…)
実装を見てみる
とはいえ実際に実装を見てみないと細かい挙動が理解できないので、実際の実装を見て見たいと思います。
実際読むのが一番早いかと思いますが、今後再び調べる際にまたいちいち実装を読まなくて良いようにざっくりまとめておきます。
バージョン1.5時点の情報です。
型による場合分け
デフォルトのcollate_fn
であるdefault_collate
は再帰的な処理になっていて、引数であるbatch
の最初の要素の型によって処理が場合分けされています。
elem = batch[0]
elem_type = type(elem)
以下ではelem
の型による具体的な処理をまとめていきます。
torch.Tensor
batch
がtorch.Tensor
の場合は、単純に次元を最初に1つ増やして結合しています。
return torch.stack(batch, 0)
numpy
の型
numpyのndarray
の場合は、テンソル化してからtorch.Tensor
の場合と同様に結合しています。
return default_collate([torch.as_tensor(b) for b in batch])
一方numpyのスカラーである場合は現在のbatch
はベクトルということになるので、そのままテンソル化しています。
return torch.as_tensor(batch)
float
, int
, str
この場合もbatch
はベクトルということになるので、それぞれ以下のようにテンソル化ないしリストのまま返されます。
# float
return torch.tensor(batch, dtype=torch.float64)
# int
return torch.tensor(batch)
# str
return batch
dict
などのcollections.abc.Mapping
を継承したクラス
以下のようにキーごとにそれぞれバッチ化処理をしたものがもとのキーの値となって返されます。
return {key: default_collate([d[key] for d in batch]) for key in elem}
namedtuple
この場合も、もとのnamedtuple
と同じ属性名を保持しつつ、属性ごとにバッチ化処理を行います。
return elem_type(*(default_collate(samples) for samples in zip(*batch)))
list
などのcollections.abc.Sequence
を継承したクラス
以下のように、要素ごとにバッチ化処理を行います。
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
具体例
例えば以下のように辞書や文字列を含んでいて複雑な構造のデータセットをデフォルトのcollate_fn
で読み込んでみます。
import numpy as np
from torch.utils.data import DataLoader
if __name__=="__main__":
complex_dataset = [
[0, "Bob", {"height": 172.5, "feature": np.array([1,2,3])}],
[1, "Tom", {"height": 153.1, "feature": np.array([3,2,1])}]
]
dataloader = DataLoader(complex_dataset, batch_size=2)
for batch in dataloader:
print(batch)
すると、以下のように無事バッチ化されることが確認できます。
[
tensor([0, 1]),
('Bob', 'Tom'),
{
'height': tensor([172.5000, 153.1000], dtype=torch.float64),
'feature': tensor([[1, 2, 3],[3, 2, 1]])
}
]
ところでpythonのfloat
はデフォルトだとtorch.float64
に変換されてしまうんですね。普通はnumpy.ndarray
でベクトルとかテンソルを表現するので問題は無いのかと思いますが、知らないと罠にはまりそうです。