#はじめに
自分でモデルを構築していて、いつも全結合層につなぐ前に「あれ、インプットの特徴量っていくつだ?」ってなります。よくprint(model)
と打つとモデルの構造は理解できるが、FeatureMapのサイズまでは確認出来ません。そこで便利なのがtorchsummary
というものです。
#torchsummaryは何者か?
簡単に言うと、特徴マップのサイズを確認できる
ものです。
#どのようにtorchsummaryを使うか
###まずはモデルを作ります
今回は以下の簡単なモデルを作りました。
クラス分類するまでは書いていません。
畳み込み➡︎BN➡︎ReLU➡︎pooling➡︎
畳み込み➡︎BN➡︎ReLU➡︎pooling➡︎
畳み込み➡︎GlobalAveragePooling
import torch
import torch.nn as nn
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN,self).__init__()
self.conv1 = nn.Conv2d(3,16,kernel_size=3,stride=1)
self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d((2,2))
self.conv2 = nn.Conv2d(16,32,kernel_size=3,stride=1)
self.bn2 = nn.BatchNorm2d(32)
self.conv3 = nn.Conv2d(32,64,kernel_size=3,stride=1)
self.gap = nn.AdaptiveMaxPool2d(1)
def forward(self,x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.conv3(x)
x = self.gap(x)
return x
###torchsummaryのインストール
pip install torchsummary
###torchsummary使い方
from torchsummary import summary
model = SimpleCNN()
summary(model,(3,224,224)) # summary(model,(channels,H,W))
今回は、画像の入力サイズを224x224を想定して試しています。
他の解像度で試したいときは、H
とW
の値を変更してください。
###summaryの出力
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 16, 222, 222] 448
BatchNorm2d-2 [-1, 16, 222, 222] 32
ReLU-3 [-1, 16, 222, 222] 0
MaxPool2d-4 [-1, 16, 111, 111] 0
Conv2d-5 [-1, 32, 109, 109] 4,640
BatchNorm2d-6 [-1, 32, 109, 109] 64
ReLU-7 [-1, 32, 109, 109] 0
MaxPool2d-8 [-1, 32, 54, 54] 0
Conv2d-9 [-1, 64, 52, 52] 18,496
AdaptiveMaxPool2d-10 [-1, 64, 1, 1] 0
================================================================
Total params: 23,680
Trainable params: 23,680
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 30.29
Params size (MB): 0.09
Estimated Total Size (MB): 30.95
----------------------------------------------------------------
OutputのShapeが確認で出来るのはかなり便利です。
パラメータの数もカウントしてくれるのもありがたいです。
#終わりに
torchsummary便利なので、ぜひ使って見てください。