LoginSignup
45
17

More than 3 years have passed since last update.

Pytorchのcollate_fnを使ってみる

Posted at

Pytorchのcollate_fnはDataloaderの引数です。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

今回はその挙動と使い方を確認していきたいと思います。

collate_fnとは

datasetで定義された__getitem__がバッチの形になるとき、まずはそれぞれの要素(画像、ターゲットなど)がリストで固められます。
Pytrochの公式に書かれている通りcollate_fnはそれを操作し、最終的にはtorch.Tensorにする関数です。

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

デフォルトでは、torch.stack()でTensorにするのみなのですが、自作のcollate_fnを使うことで高度なバッチを作ることができます。

collate_fnを自作

デフォルトの挙動は以下とほとんど同じとなっています。(returnの数は__getitem__によって変わりますが)
batchを引数にとり、stackして返しています。

def collate_fn(batch):
    images, targets= list(zip(*batch))
    images = torch.stack(images)
    targets = torch.stack(targets)
    return images, targets

自作のcollate_fnはこの中身を変えればいいわけです。

今回は物体検出のバッチを作成します。
物体検出は基本的に物体の矩形とそのラベルを入力としますが、1枚の画像に複数の矩形があることがあるので、バッチにするときどの画像がどの矩形かを結びつける必要があり、インデックスを付ける必要があります。

[[label, xc, yx, w, h],
 [                   ],
 [                   ],...]

# これを下に変える

[[0, label xc, yx, w, h],
 [0,                   ],
 [1,                   ],...]

実装自体はそれほど難しくはありません。

def batch_idx_fn(batch):
    images, bboxes = list(zip(*batch))
    targets = []
    for idx, bbox in enumerate(bboxes):
        target = np.zeros((len(bbox), 6))
        target[:, 1:] = bbox
        target[:, 0] = idx
        targets.append(target)
    images = torch.stack(images)
    targets = torch.Tensor(np.concatenate(targets)) # [[batch_idx, label, xc, yx, w, h], ...]
    return images, targets

実際に使ってみると以下のようになります。

test_data_loader = torch.utils.data.DataLoader(
                       test_dataset, 
                       batch_size=1, 
                       shuffle=False, 
                       collate_fn=batch_idx_fn
                       )
print(iter(test_data_loader).next()[0])
# [[0.0000, 0.0000, 0.6001, 0.5726, 0.1583, 0.1119],
# [0.0000, 9.0000, 0.0568, 0.5476, 0.1150, 0.1143],
# [1.0000, 5.0000, 0.8316, 0.4113, 0.1080, 0.3452],
# [1.0000, 0.0000, 0.3476, 0.6494, 0.1840, 0.1548],
# [2.0000, 2.0000, 0.8276, 0.6763, 0.1720, 0.3240],
# [2.0000, 4.0000, 0.1626, 0.0496, 0.0900, 0.0880],
# [2.0000, 5.0000, 0.2476, 0.2736, 0.1400, 0.5413],
# [2.0000, 5.0000, 0.5786, 0.4523, 0.4210, 0.5480],
# [3.0000, 0.0000, 0.4636, 0.4618, 0.0400, 0.1024],
# [3.0000, 0.0000, 0.5706, 0.5061, 0.0380, 0.0683]]

おわりに

今回紹介したでインデックスつけるとき以外に、
バッチごとにtargetが変わるときや、
targetがstackできないような数値データではないとき、
同じDatasetを少しだけ変えて使いまわしたいときに使えると思います。

45
17
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
45
17