Pytorchのcollate_fnはDataloaderの引数です。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
今回はその挙動と使い方を確認していきたいと思います。
#collate_fnとは
datasetで定義された__getitem__がバッチの形になるとき、まずはそれぞれの要素(画像、ターゲットなど)がリストで固められます。
Pytrochの公式に書かれている通りcollate_fnはそれを操作し、最終的にはtorch.Tensorにする関数です。
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
デフォルトでは、torch.stack()でTensorにするのみなのですが、自作のcollate_fnを使うことで高度なバッチを作ることができます。
#collate_fnを自作
デフォルトの挙動は以下とほとんど同じとなっています。(returnの数は__getitem__によって変わりますが)
batchを引数にとり、stackして返しています。
def collate_fn(batch):
images, targets= list(zip(*batch))
images = torch.stack(images)
targets = torch.stack(targets)
return images, targets
自作のcollate_fnはこの中身を変えればいいわけです。
今回は物体検出のバッチを作成します。
物体検出は基本的に物体の矩形とそのラベルを入力としますが、1枚の画像に複数の矩形があることがあるので、バッチにするときどの画像がどの矩形かを結びつける必要があり、インデックスを付ける必要があります。
[[label, xc, yx, w, h],
[ ],
[ ],...]
# これを下に変える
[[0, label xc, yx, w, h],
[0, ],
[1, ],...]
実装自体はそれほど難しくはありません。
def batch_idx_fn(batch):
images, bboxes = list(zip(*batch))
targets = []
for idx, bbox in enumerate(bboxes):
target = np.zeros((len(bbox), 6))
target[:, 1:] = bbox
target[:, 0] = idx
targets.append(target)
images = torch.stack(images)
targets = torch.Tensor(np.concatenate(targets)) # [[batch_idx, label, xc, yx, w, h], ...]
return images, targets
実際に使ってみると以下のようになります。
test_data_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
collate_fn=batch_idx_fn
)
print(iter(test_data_loader).next()[0])
# [[0.0000, 0.0000, 0.6001, 0.5726, 0.1583, 0.1119],
# [0.0000, 9.0000, 0.0568, 0.5476, 0.1150, 0.1143],
# [1.0000, 5.0000, 0.8316, 0.4113, 0.1080, 0.3452],
# [1.0000, 0.0000, 0.3476, 0.6494, 0.1840, 0.1548],
# [2.0000, 2.0000, 0.8276, 0.6763, 0.1720, 0.3240],
# [2.0000, 4.0000, 0.1626, 0.0496, 0.0900, 0.0880],
# [2.0000, 5.0000, 0.2476, 0.2736, 0.1400, 0.5413],
# [2.0000, 5.0000, 0.5786, 0.4523, 0.4210, 0.5480],
# [3.0000, 0.0000, 0.4636, 0.4618, 0.0400, 0.1024],
# [3.0000, 0.0000, 0.5706, 0.5061, 0.0380, 0.0683]]
#おわりに
今回紹介したでインデックスつけるとき以外に、
バッチごとにtargetが変わるときや、
targetがstackできないような数値データではないとき、
同じDatasetを少しだけ変えて使いまわしたいときに使えると思います。