畳み込みの挙動を確認。
Oh = {(Ih - Fh + 2D)/S} + 1
(Google Colaboratoryで学ぶ! あたらしい人工知能技術の教科書 機械学習・深層学習・強化学習で学ぶAIの基礎技術 p223より)
Fh,Fw フィルタ 高幅
Oh,Ow 出力 高幅
Ih,Iw 入力 高幅
D パディング幅
S ストライド幅
import torch.nn as nn
class CnnTest(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=64,
kernel_size=3,
padding=(1, 1),
padding_mode='zeros'
)
def forward(self, x):
output = self.conv1(x)
return output
x = torch.rand(1, 3, 224, 224)
A = CnnTest()
output = A(x)
print(output.shape)
出力は↓(式通り)
torch.Size([1, 64, 224, 224])