torch.nn.functional.conv2d
torch.nn.functional.conv2d
は畳み込み処理を行う関数です。torch.nn.functional.conv2d
を用いた畳み込み処理は下記のように実装することができます。
import torch
import torch.nn.functional as F
filters = torch.randn(8, 4, 3, 3)
inputs = torch.randn(1, 4, 5, 5)
outputs = F.conv2d(input=inputs, weight=filters, padding=1)
print(filters.shape)
print(inputs.shape)
print(outputs.shape)
・実行結果
torch.Size([8, 4, 3, 3])
torch.Size([1, 4, 5, 5])
torch.Size([1, 8, 5, 5])
この実行結果は、「$5 \times 5$の4チャネル・1サンプルのinputs
に対し、8個のフィルタによって$3 \times 3$の畳み込みを実行することで$5 \times 5$の8チャネル・1サンプルのoutputs
が得られる」と解釈すると良いです。
同様に「$5 \times 5$の5チャネル・2サンプルのinputs
に対し、12個のフィルタによって$3 \times 3$の畳み込みを実行してみます。
filters = torch.randn(12, 5, 3, 3)
inputs = torch.randn(2, 5, 5, 5)
outputs = F.conv2d(input=inputs, weight=filters, padding=1)
print(filters.shape)
print(inputs.shape)
print(outputs.shape)
・実行結果
torch.Size([12, 5, 3, 3])
torch.Size([2, 5, 5, 5])
torch.Size([2, 12, 5, 5])
実行結果より、「$5 \times 5$の12チャネル・2サンプルのoutputs
が得られた」ことが確認できます。注意しておく必要があるのが、チャネル数に対応するfilters
とinputs
の2つ目の数字は統一する必要があるということです。
filters = torch.randn(12, 5, 3, 3)
inputs = torch.randn(2, 6, 5, 5)
outputs = F.conv2d(input=inputs, weight=filters, padding=1)
print(filters.shape)
print(inputs.shape)
print(outputs.shape)
・実行結果
RuntimeError: Given groups=1, weight of size [12, 5, 3, 3], expected input[2, 6, 5, 5] to have 5 channels, but got 6 channels instead
チャネル数をfilters
とinputs
で変えて実行した場合、上記のようにRuntimeError
が出力されることは確認しておくと良いです。
torch.nn.Conv2d
torch.nn.Conv2d
は畳み込み処理を実装したクラスであり、逆伝播処理が用意されているのでDeepLearningの実装にあたって基本的に用いられます。
conv1 = torch.nn.Conv2d(16, 32, 3)
conv2 = torch.nn.Conv2d(16, 32, 3, padding=1)
conv3 = torch.nn.Conv2d(16, 32, 3, padding=1, stride=2)
input = torch.randn(20, 16, 50, 100)
output1 = conv1(input)
output2 = conv2(input)
output3 = conv3(input)
print(output1.shape)
print(output2.shape)
print(output3.shape)
・実行結果
torch.Size([20, 32, 48, 98])
torch.Size([20, 32, 50, 100])
torch.Size([20, 32, 25, 50])
上記の実行結果は下記のように解釈すると良いです。
-
padding
が0(特に指定しない場合は0となる)の場合、「(フィルタのカーネルサイズ-1)」分だけ画像のサイズが縦横に小さくなる(output1
) -
padding
に1
が指定されると画像の周辺に1ピクセル追加されるのでカーネルのサイズが3の場合は画像のサイズが小さくならない(output2
) -
stride
が2(デフォルトでは1)の場合、画像のサイズは縦横共に半分となる(output3
)
フィルタのサイズやpadding
、stride
は画像の縦と横で別々に指定することもできるので合わせて抑えておくと良いです。
conv1 = torch.nn.Conv2d(16, 32, 3, padding=1, stride=2)
conv2 = torch.nn.Conv2d(16, 32, (5, 3), stride=(2, 1), padding=(2, 1))
conv3 = torch.nn.Conv2d(16, 32, (5, 3), stride=(2, 1), padding=(2, 1), dilation=(1, 2))
conv4 = torch.nn.Conv2d(16, 32, (5, 3), padding=(2, 1), dilation=(2, 2))
conv5 = torch.nn.Conv2d(16, 32, (5, 3), padding=(2, 1), dilation=(3, 3))
input = torch.randn(20, 16, 100, 50)
output1 = conv1(input)
output2 = conv2(input)
output3 = conv3(input)
output4 = conv4(input)
output5 = conv5(input)
print(output1.shape)
print(output2.shape)
print(output3.shape)
print(output4.shape)
print(output5.shape)
・実行結果
torch.Size([20, 32, 50, 25])
torch.Size([20, 32, 50, 50])
torch.Size([20, 32, 50, 48])
torch.Size([20, 32, 96, 48])
torch.Size([20, 32, 92, 46])
上記の結果は下記のように解釈できます。
-
padding=1
、stride=2
で実行した場合は縦横のサイズがどちらも半分となる(output1
) - $5 \times 3$のフィルタを用いて
stride=(2, 1)
、padding=(2, 1)
で実行した場合、縦のみのサイズが半分となる(output2
) - 2.に対し
dilation=(1, 2)
を引数に与えると横方向に拡張(dilation)処理が行われ、横のサイズが2.の結果から-2となる(output3
) -
stride
が1かつdilation=(2, 2)
の場合、縦方向は-4、横方向は-2となる。この現象は縦方向のカーネルのサイズが5であるので、隙間が4つできることに起因する(output4
) -
stride
が1かつdilation=(3, 3)
の場合、縦方向は-8、横方向は-4となる。この現象は縦方向のカーネルのサイズが5であるので、隙間が8つできることに起因する(output5
)