Help us understand the problem. What is going on with this article?

Pytorchでモデル構築するとき、torchsummaryがマジ使える件について

はじめに

自分でモデルを構築していて、いつも全結合層につなぐ前に「あれ、インプットの特徴量っていくつだ?」ってなります。よくprint(model)と打つとモデルの構造は理解できるが、FeatureMapのサイズまでは確認出来ません。そこで便利なのがtorchsummaryというものです。

torchsummaryは何者か?

簡単に言うと、特徴マップのサイズを確認できるものです。

どのようにtorchsummaryを使うか

まずはモデルを作ります

今回は以下の簡単なモデルを作りました。
クラス分類するまでは書いていません。

畳み込み➡︎BN➡︎ReLU➡︎pooling➡︎
畳み込み➡︎BN➡︎ReLU➡︎pooling➡︎
畳み込み➡︎GlobalAveragePooling

import torch
import torch.nn as nn

class SimpleCNN(nn.Module):     
    def __init__(self):
        super(SimpleCNN,self).__init__()

        self.conv1 = nn.Conv2d(3,16,kernel_size=3,stride=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d((2,2))
        self.conv2 = nn.Conv2d(16,32,kernel_size=3,stride=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32,64,kernel_size=3,stride=1)
        self.gap = nn.AdaptiveMaxPool2d(1)

    def forward(self,x):  
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x) 
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv3(x)
        x = self.gap(x)

        return x

torchsummaryのインストール

pip install torchsummary

torchsummary使い方

from torchsummary import summary
model = SimpleCNN()  
summary(model,(3,224,224)) # summary(model,(channels,H,W))

今回は、画像の入力サイズを224x224を想定して試しています。
他の解像度で試したいときは、HWの値を変更してください。

summaryの出力

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 16, 222, 222]             448
       BatchNorm2d-2         [-1, 16, 222, 222]              32
              ReLU-3         [-1, 16, 222, 222]               0
         MaxPool2d-4         [-1, 16, 111, 111]               0
            Conv2d-5         [-1, 32, 109, 109]           4,640
       BatchNorm2d-6         [-1, 32, 109, 109]              64
              ReLU-7         [-1, 32, 109, 109]               0
         MaxPool2d-8           [-1, 32, 54, 54]               0
            Conv2d-9           [-1, 64, 52, 52]          18,496
AdaptiveMaxPool2d-10             [-1, 64, 1, 1]               0
================================================================
Total params: 23,680
Trainable params: 23,680
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 30.29
Params size (MB): 0.09
Estimated Total Size (MB): 30.95
----------------------------------------------------------------

OutputのShapeが確認で出来るのはかなり便利です。
パラメータの数もカウントしてくれるのもありがたいです。

終わりに

torchsummary便利なので、ぜひ使って見てください。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした