Pytorchのcollate_fnはDataloaderの引数です。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)
今回はその挙動と使い方を確認していきたいと思います。
collate_fnとは
datasetで定義された__getitem__がバッチの形になるとき、まずはそれぞれの要素(画像、ターゲットなど)がリストで固められます。
Pytrochの公式に書かれている通りcollate_fnはそれを操作し、最終的にはtorch.Tensorにする関数です。
dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])
デフォルトでは、torch.stack()でTensorにするのみなのですが、自作のcollate_fnを使うことで高度なバッチを作ることができます。
collate_fnを自作
デフォルトの挙動は以下とほとんど同じとなっています。(returnの数は__getitem__によって変わりますが)
batchを引数にとり、stackして返しています。
def collate_fn(batch):
    images, targets= list(zip(*batch))
    images = torch.stack(images)
    targets = torch.stack(targets)
    return images, targets
自作のcollate_fnはこの中身を変えればいいわけです。
今回は物体検出のバッチを作成します。
物体検出は基本的に物体の矩形とそのラベルを入力としますが、1枚の画像に複数の矩形があることがあるので、バッチにするときどの画像がどの矩形かを結びつける必要があり、インデックスを付ける必要があります。
[[label, xc, yx, w, h],
 [                   ],
 [                   ],...]
# これを下に変える
[[0, label xc, yx, w, h],
 [0,                   ],
 [1,                   ],...]
実装自体はそれほど難しくはありません。
def batch_idx_fn(batch):
    images, bboxes = list(zip(*batch))
    targets = []
    for idx, bbox in enumerate(bboxes):
        target = np.zeros((len(bbox), 6))
        target[:, 1:] = bbox
        target[:, 0] = idx
        targets.append(target)
    images = torch.stack(images)
    targets = torch.Tensor(np.concatenate(targets)) # [[batch_idx, label, xc, yx, w, h], ...]
    return images, targets
実際に使ってみると以下のようになります。
test_data_loader = torch.utils.data.DataLoader(
                       test_dataset, 
                       batch_size=1, 
                       shuffle=False, 
                       collate_fn=batch_idx_fn
                       )
print(iter(test_data_loader).next()[0])
# [[0.0000, 0.0000, 0.6001, 0.5726, 0.1583, 0.1119],
# [0.0000, 9.0000, 0.0568, 0.5476, 0.1150, 0.1143],
# [1.0000, 5.0000, 0.8316, 0.4113, 0.1080, 0.3452],
# [1.0000, 0.0000, 0.3476, 0.6494, 0.1840, 0.1548],
# [2.0000, 2.0000, 0.8276, 0.6763, 0.1720, 0.3240],
# [2.0000, 4.0000, 0.1626, 0.0496, 0.0900, 0.0880],
# [2.0000, 5.0000, 0.2476, 0.2736, 0.1400, 0.5413],
# [2.0000, 5.0000, 0.5786, 0.4523, 0.4210, 0.5480],
# [3.0000, 0.0000, 0.4636, 0.4618, 0.0400, 0.1024],
# [3.0000, 0.0000, 0.5706, 0.5061, 0.0380, 0.0683]]
おわりに
今回紹介したでインデックスつけるとき以外に、
バッチごとにtargetが変わるときや、
targetがstackできないような数値データではないとき、
同じDatasetを少しだけ変えて使いまわしたいときに使えると思います。