8
4

More than 3 years have passed since last update.

【PyTorch】 AdaptiveMaxPool2d, AdaptiveAvgPool2dについて

Last updated at Posted at 2020-07-07

前置き

PyTorchにあるAdaptive系のプーリング。

任意の入力サイズに対して、出力サイズを指定してプーリングを行う。
どのような動きになっているのか、ソースコードを見てみた。

カーネルの求め方

カーネルを以下式で求める。

start_index = floor(output_index * input_size / output_size)
# output_index ... 0~(output_size-1)

floor ... 切り捨て

end_index = ceil((output_index + 1) * input_size / output_size)
# output_index ... 0~(output_size-1)

ceil ... 切り上げ

input_size=(10, 10), output_size=(6,6) の場合

output_index start_index end_index カーネルの範囲
0 0 2 0~1
1 1 4 1~3
2 3 5 3~4
3 5 7 5~6
4 6 9 6~8
5 8 10 8~9

AdaptiveMaxPoolの場合

image.png

import torch
import torch.nn as nn

x = torch.arange(100).view(1, 10, 10).float()

#tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
#         [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
#         [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],
#         [30., 31., 32., 33., 34., 35., 36., 37., 38., 39.],
#         [40., 41., 42., 43., 44., 45., 46., 47., 48., 49.],
#         [50., 51., 52., 53., 54., 55., 56., 57., 58., 59.],
#         [60., 61., 62., 63., 64., 65., 66., 67., 68., 69.],
#         [70., 71., 72., 73., 74., 75., 76., 77., 78., 79.],
#         [80., 81., 82., 83., 84., 85., 86., 87., 88., 89.],
#         [90., 91., 92., 93., 94., 95., 96., 97., 98., 99.]]])

nn.AdaptiveMaxPool2d((6,6))(x)

#tensor([[[11., 13., 14., 16., 18., 19.],
#         [31., 33., 34., 36., 38., 39.],
#         [41., 43., 44., 46., 48., 49.],
#         [61., 63., 64., 66., 68., 69.],
#         [81., 83., 84., 86., 88., 89.],
#         [91., 93., 94., 96., 98., 99.]]])
8
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
4