29
Help us understand the problem. What are the problem?

More than 1 year has passed since last update.

posted at

updated at

torchsummaryよりtorchinfoがいいよという話

なにこれ

  • torchsummaryとtorch-summaryの話
  • 結論:torchsummaryを使っていた人はtorchinfoに変えよう。

追記(2021.2.5)

  • 名前がtorch-summaryからtorchinfoに変わりました。
  • タイトル、結論、記事末尾のリンクだけ修正しました。

環境

tensorflow: 2.3.0
pytorch: 1.6.0
python: 3.7.2
torchsummary: 1.5.1
torch-summary: 1.4.1

summaryがほしいよね

書いたモデルをデバグする際に、さっと可視化できると非常に便利ですが、PyTorchにはtf.kerasのmodel.summary()がなく、print(model)することになります。

keras_summary.py
# https://keras.io/api/models/model/

import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)


if __name__ == "__main__":
    model.summary()
output
$ python keras_summary.py

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 3)]               0         
_________________________________________________________________
dense (Dense)                (None, 4)                 16        
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 25        
=================================================================
Total params: 41
Trainable params: 41
Non-trainable params: 0
_________________________________________________________________
pytorch_print.py

import torch.nn as nn


model = nn.Sequential(
    nn.Linear(3, 4),
    nn.ReLU(),
    nn.Linear(4, 5),
    nn.Softmax(dim=1)
)


if __name__ == "__main__":
    print(model)
output

$ python pytorch_print.py

Model(
  (dense1): Linear(in_features=3, out_features=4, bias=True)
  (relu): ReLU()
  (dense2): Linear(in_features=4, out_features=5, bias=True)
  (softmax): Softmax(dim=1)
)

そこで現れたtorchsummary

issueの流れで颯爽と現れ、この味気ないモデル表示をkerasっぽくして情報をリッチにできるのがtorchsummaryです。

installation
$ pip install torchsummary
pytorch_summary.py

import torch.nn as nn
from torchsummary import summary


model = nn.Sequential(
    nn.Linear(3, 4),
    nn.ReLU(),
    nn.Linear(4, 5),
    nn.Softmax(dim=1)
)


if __name__ == "__main__":
    model = model.to("cuda")
    summary(model, (1,3))
output

$ python pytorch_summary.py

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                 [-1, 1, 4]              16
              ReLU-2                 [-1, 1, 4]               0
            Linear-3                 [-1, 1, 5]              25
           Softmax-4                 [-1, 1, 5]               0
================================================================
Total params: 41
Trainable params: 41
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------

しかしこのtorchsummary、開発が止まっている模様。
pypiからインストールするとコードが古く、これをしないとmultiple inputsに対応できませんでした。

torch-summaryが更に情報をリッチに

  • torchsummaryがmodelをユーザーがto("cuda")しなければならなかった点を解消
  • 実際のコードを書き換える必要がない
  • 親子関係が見やすくなる
installation
# torchsummaryが機能しなくなります。
$ pip install torch-summary
pytorch_torch-summry
# pipでインストールするだけで出力が変わる
$ python pytorch_summary.py

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Sequential: 1-1                        [-1, 1, 5]                --
|    └─Linear: 2-1                       [-1, 1, 4]                16
|    └─ReLU: 2-2                         [-1, 1, 4]                --
|    └─Linear: 2-3                       [-1, 1, 5]                25
|    └─Softmax: 2-4                      [-1, 1, 5]                --
==========================================================================================
Total params: 41
Trainable params: 41
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================

この他にも、複数インプットがあるときに構造が見やすくなります。

torchsummary_multiple.py

import torch.nn as nn
from torchsummary import summary


class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(3, 4),
            nn.ReLU(),
            nn.Linear(4, 5),
            nn.Softmax(dim=1)
        )

    def forward(self, x, y):
        x = self.model(x)
        y = self.model(y)
        return x, y


if __name__ == "__main__":
    model = Model().cuda()
    summary(model, [(1, 3), (1, 3)])
output
$ python torchsummary_multiple.py

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                 [-1, 1, 4]              16
              ReLU-2                 [-1, 1, 4]               0
            Linear-3                 [-1, 1, 5]              25
           Softmax-4                 [-1, 1, 5]               0
            Linear-5                 [-1, 1, 4]              16
              ReLU-6                 [-1, 1, 4]               0
            Linear-7                 [-1, 1, 5]              25
           Softmax-8                 [-1, 1, 5]               0
================================================================
torch-summary_multiple.py

import torch.nn as nn
from torchsummary import summary


class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(3, 4),
            nn.ReLU(),
            nn.Linear(4, 5),
            nn.Softmax(dim=1)
        )

    def forward(self, x, y):
        x = self.model(x)
        y = self.model(y)
        return x, y

if __name__ == "__main__":
    model = Model()
    summary(model, [(1, 3), (1, 3)])
output
$ python torch-summary_multiple.py

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Sequential: 1-1                        [-1, 1, 5]                --
|    └─Linear: 2-1                       [-1, 1, 4]                16
|    └─ReLU: 2-2                         [-1, 1, 4]                --
|    └─Linear: 2-3                       [-1, 1, 5]                25
|    └─Softmax: 2-4                      [-1, 1, 5]                --
├─Sequential: 1-2                        [-1, 1, 5]                (recursive)
|    └─Linear: 2-5                       [-1, 1, 4]                (recursive)
|    └─ReLU: 2-6                         [-1, 1, 4]                --
|    └─Linear: 2-7                       [-1, 1, 5]                (recursive)
|    └─Softmax: 2-8                      [-1, 1, 5]                --
==========================================================================================

実はtorchsummaryXもある

torch-summaryは、torchsummaryとtorchsummaryXの後継と自称しています。
torchsummaryXから乗り換える際は下の2つの変更が必要かと思います。
- from torchsummaryX => from torchsummary
- summary()の引数inputは、input.shapeに変更

まとめ

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Sign upLogin
29
Help us understand the problem. What are the problem?