0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

実務でよく使うPyTorchの機能まとめ②|畳み込み処理(torch.nn.functional.conv2dとtorch.nn.Conv2d)

Last updated at Posted at 2025-01-22

torch.nn.functional.conv2d

torch.nn.functional.conv2dは畳み込み処理を行う関数です。torch.nn.functional.conv2dを用いた畳み込み処理は下記のように実装することができます。

import torch
import torch.nn.functional as F


filters = torch.randn(8, 4, 3, 3)
inputs = torch.randn(1, 4, 5, 5)
outputs = F.conv2d(input=inputs, weight=filters, padding=1)

print(filters.shape)
print(inputs.shape)
print(outputs.shape)

・実行結果

torch.Size([8, 4, 3, 3])
torch.Size([1, 4, 5, 5])
torch.Size([1, 8, 5, 5])

この実行結果は、「$5 \times 5$の4チャネル・1サンプルのinputsに対し、8個のフィルタによって$3 \times 3$の畳み込みを実行することで$5 \times 5$の8チャネル・1サンプルのoutputsが得られる」と解釈すると良いです。

同様に「$5 \times 5$の5チャネル・2サンプルのinputsに対し、12個のフィルタによって$3 \times 3$の畳み込みを実行してみます。

filters = torch.randn(12, 5, 3, 3)
inputs = torch.randn(2, 5, 5, 5)
outputs = F.conv2d(input=inputs, weight=filters, padding=1)

print(filters.shape)
print(inputs.shape)
print(outputs.shape)

・実行結果

torch.Size([12, 5, 3, 3])
torch.Size([2, 5, 5, 5])
torch.Size([2, 12, 5, 5])

実行結果より、「$5 \times 5$の12チャネル・2サンプルのoutputsが得られた」ことが確認できます。注意しておく必要があるのが、チャネル数に対応するfiltersinputsの2つ目の数字は統一する必要があるということです。

filters = torch.randn(12, 5, 3, 3)
inputs = torch.randn(2, 6, 5, 5)
outputs = F.conv2d(input=inputs, weight=filters, padding=1)

print(filters.shape)
print(inputs.shape)
print(outputs.shape)

・実行結果

RuntimeError: Given groups=1, weight of size [12, 5, 3, 3], expected input[2, 6, 5, 5] to have 5 channels, but got 6 channels instead

チャネル数をfiltersinputsで変えて実行した場合、上記のようにRuntimeErrorが出力されることは確認しておくと良いです。

torch.nn.Conv2d

torch.nn.Conv2dは畳み込み処理を実装したクラスであり、逆伝播処理が用意されているのでDeepLearningの実装にあたって基本的に用いられます。

conv1 = torch.nn.Conv2d(16, 32, 3)
conv2 = torch.nn.Conv2d(16, 32, 3, padding=1)
conv3 = torch.nn.Conv2d(16, 32, 3, padding=1, stride=2)

input = torch.randn(20, 16, 50, 100)

output1 = conv1(input)
output2 = conv2(input)
output3 = conv3(input)

print(output1.shape)
print(output2.shape)
print(output3.shape)

・実行結果

torch.Size([20, 32, 48, 98])
torch.Size([20, 32, 50, 100])
torch.Size([20, 32, 25, 50])

上記の実行結果は下記のように解釈すると良いです。


  1. paddingが0(特に指定しない場合は0となる)の場合、「(フィルタのカーネルサイズ-1)」分だけ画像のサイズが縦横に小さくなる(output1)
  2. padding1が指定されると画像の周辺に1ピクセル追加されるのでカーネルのサイズが3の場合は画像のサイズが小さくならない(output2)
  3. strideが2(デフォルトでは1)の場合、画像のサイズは縦横共に半分となる(output3)

フィルタのサイズやpaddingstrideは画像の縦と横で別々に指定することもできるので合わせて抑えておくと良いです。

conv1 = torch.nn.Conv2d(16, 32, 3, padding=1, stride=2)
conv2 = torch.nn.Conv2d(16, 32, (5, 3), stride=(2, 1), padding=(2, 1))
conv3 = torch.nn.Conv2d(16, 32, (5, 3), stride=(2, 1), padding=(2, 1), dilation=(1, 2))
conv4 = torch.nn.Conv2d(16, 32, (5, 3), padding=(2, 1), dilation=(2, 2))
conv5 = torch.nn.Conv2d(16, 32, (5, 3), padding=(2, 1), dilation=(3, 3))

input = torch.randn(20, 16, 100, 50)

output1 = conv1(input)
output2 = conv2(input)
output3 = conv3(input)
output4 = conv4(input)
output5 = conv5(input)

print(output1.shape)
print(output2.shape)
print(output3.shape)
print(output4.shape)
print(output5.shape)

・実行結果

torch.Size([20, 32, 50, 25])
torch.Size([20, 32, 50, 50])
torch.Size([20, 32, 50, 48])
torch.Size([20, 32, 96, 48])
torch.Size([20, 32, 92, 46])

上記の結果は下記のように解釈できます。


  1. padding=1stride=2で実行した場合は縦横のサイズがどちらも半分となる(output1)
  2. $5 \times 3$のフィルタを用いてstride=(2, 1)padding=(2, 1)で実行した場合、縦のみのサイズが半分となる(output2)
  3. 2.に対しdilation=(1, 2)を引数に与えると横方向に拡張(dilation)処理が行われ、横のサイズが2.の結果から-2となる(output3)
  4. strideが1かつdilation=(2, 2)の場合、縦方向は-4、横方向は-2となる。この現象は縦方向のカーネルのサイズが5であるので、隙間が4つできることに起因する(output4)
  5. strideが1かつdilation=(3, 3)の場合、縦方向は-8、横方向は-4となる。この現象は縦方向のカーネルのサイズが5であるので、隙間が8つできることに起因する(output5)

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?