PyTorchの初学者です。cifar10を用いたPyTorchの公式チュートリアルにあるview
関数の引数計算の方法についてメモを残します。この関数は畳み込み層を全結合層につなげるときに用いられる関数です。
view
関数の引数について
以下のコードはCNNのモデルを定義するものです。チュートリアル内のコードから抜粋しています。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5) #in_channels,out_channels,kernel_size
self.pool = nn.MaxPool2d(2, 2) #2x2のカーネル
self.conv2 = nn.Conv2d(6, 16, 5) #in_channels,out_channels,kernel_size
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5) #ベクトル化
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def forward(self,x)
において、x = x.view(-1, 16 * 5 * 5)
の第2引数の計算方法についてメモします。
view
関数の第二引数の計算法について
第二引数は畳み込み計算とプーリング計算を行うことで計算できます。畳み込み計算による出力画像の大きさは、以下の公式を使って計算できます。これはTorchの公式ページに載っています。
高さ方向の次元は
$H_{out}=\lfloor\frac{H_{in}+2\times \mathrm{padding} -\mathrm{dilation} \times (\mathrm{kernel_size}-1)-1}{\mathrm{stride}}+1\rfloor$
幅方向の次元は
$W_{out}=\lfloor\frac{W_{in}+2\times \mathrm{padding} - \mathrm{dilation} \times (\mathrm{kernel_size}-1)-1}{\mathrm{stride}}+1\rfloor$
ちなみに、$\mathrm{padding}$, $\mathrm{dilation}$,$\mathrm{stride}$ はデフォルト値がそれぞれ0, 1, 1と設定されています。
このチュートリアルで用いているcifar10は、画像サイズが$3\times 32\times 32$(チャンネル,縦、横)となっています。
最初にconv1
で畳み込み計算をするとき、上の公式を使って出力される特徴マップの大きさを計算できます。
$H_{out} = \frac{32+2\times 0 -1\times (5-1)-1}{1} + 1 = 28$
同様に、 $W_{out}$ も28となります。
よって、conv1
の出力サイズは $28\times 28$ とわかりました。
この結果に足してプーリング処理を実行すると、各軸方向のサイズが半分になります。
よって、pool
の出力サイズは $14\times 14$ になります。
次にconv2
について考えてみます。先と同様に公式を使って計算します。
$H_{out} = \frac{14+2\times 0 -1\times (5-1)-1}{1} + 1 = 10$
$W_{out}$ も同様に10となります。
conv2
の出力サイズが$10\times 10$とわかります。これにプーリング処理を施すと各軸方向のサイズが半分になるので、pool
の出力サイズは$5\times 5$になります。
conv2
のチャンネル方向の値は16なので、ベクトル化するべき画像サイズは$16\times 5\times 5$ になり、view
関数の第二引数と一致します。
計算を楽にするために
モデルを組むごとにいちいち計算するのは大変です。
view
関数の第2引数はベクトル化したときの要素数で、第1引数はバッチサイズです。そのため、view(batch_size, -1)
と書けば自分でいちいち計算せずにpython側で計算してくれます。
ちなみにこの書き方は、view
関数への入力次元が[batch_size, channel, height, width]
である場合、[batch_size, channel*height*width]
へと変形せよという命令になります。
例
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(batch_size, -1) #第一引数を指定、第二引数は-1
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
引数計算用の関数
モデルを組む度に計算するのは面倒です。
PyTorchのNeural Networksのページでは、便利な関数が定義されています。
def num_flat_features(self, x):
size = x.size()[1:] # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features
使い方の例。
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
まとめ
PyTorchでCNNモデルを組むときに使うview
関数の引数の計算法について調べました。
参考になれば幸いです。