LoginSignup
226
157

More than 3 years have passed since last update.

Pytorchでモデル構築するとき、torchsummaryがマジ使える件について

Last updated at Posted at 2020-01-21

はじめに

自分でモデルを構築していて、いつも全結合層につなぐ前に「あれ、インプットの特徴量っていくつだ?」ってなります。よくprint(model)と打つとモデルの構造は理解できるが、FeatureMapのサイズまでは確認出来ません。そこで便利なのがtorchsummaryというものです。

torchsummaryは何者か?

簡単に言うと、特徴マップのサイズを確認できるものです。

どのようにtorchsummaryを使うか

まずはモデルを作ります

今回は以下の簡単なモデルを作りました。
クラス分類するまでは書いていません。

畳み込み➡︎BN➡︎ReLU➡︎pooling➡︎
畳み込み➡︎BN➡︎ReLU➡︎pooling➡︎
畳み込み➡︎GlobalAveragePooling


import torch
import torch.nn as nn

class SimpleCNN(nn.Module):     
    def __init__(self):
        super(SimpleCNN,self).__init__()

        self.conv1 = nn.Conv2d(3,16,kernel_size=3,stride=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d((2,2))
        self.conv2 = nn.Conv2d(16,32,kernel_size=3,stride=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32,64,kernel_size=3,stride=1)
        self.gap = nn.AdaptiveMaxPool2d(1)

    def forward(self,x):  
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x) 
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv3(x)
        x = self.gap(x)

        return x

torchsummaryのインストール

pip install torchsummary

torchsummary使い方

from torchsummary import summary
model = SimpleCNN()  
summary(model,(3,224,224)) # summary(model,(channels,H,W))

今回は、画像の入力サイズを224x224を想定して試しています。
他の解像度で試したいときは、HWの値を変更してください。

summaryの出力


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 16, 222, 222]             448
       BatchNorm2d-2         [-1, 16, 222, 222]              32
              ReLU-3         [-1, 16, 222, 222]               0
         MaxPool2d-4         [-1, 16, 111, 111]               0
            Conv2d-5         [-1, 32, 109, 109]           4,640
       BatchNorm2d-6         [-1, 32, 109, 109]              64
              ReLU-7         [-1, 32, 109, 109]               0
         MaxPool2d-8           [-1, 32, 54, 54]               0
            Conv2d-9           [-1, 64, 52, 52]          18,496
AdaptiveMaxPool2d-10             [-1, 64, 1, 1]               0
================================================================
Total params: 23,680
Trainable params: 23,680
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 30.29
Params size (MB): 0.09
Estimated Total Size (MB): 30.95
----------------------------------------------------------------

OutputのShapeが確認で出来るのはかなり便利です。
パラメータの数もカウントしてくれるのもありがたいです。

終わりに

torchsummary便利なので、ぜひ使って見てください。

226
157
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
226
157