Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
125
Help us understand the problem. What is going on with this article?
@tatsuya11bbs

Pytorchでモデル構築するとき、torchsummaryがマジ使える件について

More than 1 year has passed since last update.

はじめに

自分でモデルを構築していて、いつも全結合層につなぐ前に「あれ、インプットの特徴量っていくつだ?」ってなります。よくprint(model)と打つとモデルの構造は理解できるが、FeatureMapのサイズまでは確認出来ません。そこで便利なのがtorchsummaryというものです。

torchsummaryは何者か?

簡単に言うと、特徴マップのサイズを確認できるものです。

どのようにtorchsummaryを使うか

まずはモデルを作ります

今回は以下の簡単なモデルを作りました。
クラス分類するまでは書いていません。

畳み込み➡︎BN➡︎ReLU➡︎pooling➡︎
畳み込み➡︎BN➡︎ReLU➡︎pooling➡︎
畳み込み➡︎GlobalAveragePooling


import torch
import torch.nn as nn

class SimpleCNN(nn.Module):     
    def __init__(self):
        super(SimpleCNN,self).__init__()

        self.conv1 = nn.Conv2d(3,16,kernel_size=3,stride=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d((2,2))
        self.conv2 = nn.Conv2d(16,32,kernel_size=3,stride=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32,64,kernel_size=3,stride=1)
        self.gap = nn.AdaptiveMaxPool2d(1)

    def forward(self,x):  
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x) 
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv3(x)
        x = self.gap(x)

        return x

torchsummaryのインストール

pip install torchsummary

torchsummary使い方

from torchsummary import summary
model = SimpleCNN()  
summary(model,(3,224,224)) # summary(model,(channels,H,W))

今回は、画像の入力サイズを224x224を想定して試しています。
他の解像度で試したいときは、HWの値を変更してください。

summaryの出力


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 16, 222, 222]             448
       BatchNorm2d-2         [-1, 16, 222, 222]              32
              ReLU-3         [-1, 16, 222, 222]               0
         MaxPool2d-4         [-1, 16, 111, 111]               0
            Conv2d-5         [-1, 32, 109, 109]           4,640
       BatchNorm2d-6         [-1, 32, 109, 109]              64
              ReLU-7         [-1, 32, 109, 109]               0
         MaxPool2d-8           [-1, 32, 54, 54]               0
            Conv2d-9           [-1, 64, 52, 52]          18,496
AdaptiveMaxPool2d-10             [-1, 64, 1, 1]               0
================================================================
Total params: 23,680
Trainable params: 23,680
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 30.29
Params size (MB): 0.09
Estimated Total Size (MB): 30.95
----------------------------------------------------------------

OutputのShapeが確認で出来るのはかなり便利です。
パラメータの数もカウントしてくれるのもありがたいです。

終わりに

torchsummary便利なので、ぜひ使って見てください。

125
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
tatsuya11bbs
Machine Learning / Python / Scala / Image Processing / Keras / Tensorflow / Pytorch / .

Comments

No comments
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
125
Help us understand the problem. What is going on with this article?