46
28

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

torchsummaryよりtorchinfoがいいよという話

Last updated at Posted at 2020-08-18

なにこれ

  • torchsummaryとtorch-summaryの話
  • 結論:torchsummaryを使っていた人はtorchinfoに変えよう。

追記(2021.2.5)

  • 名前がtorch-summaryからtorchinfoに変わりました。
  • タイトル、結論、記事末尾のリンクだけ修正しました。

環境

tensorflow: 2.3.0
pytorch: 1.6.0
python: 3.7.2
torchsummary: 1.5.1
torch-summary: 1.4.1

summaryがほしいよね

書いたモデルをデバグする際に、さっと可視化できると非常に便利ですが、PyTorchにはtf.kerasのmodel.summary()がなく、print(model)することになります。

keras_summary.py
# https://keras.io/api/models/model/

import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)


if __name__ == "__main__":
    model.summary()
output
$ python keras_summary.py

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 3)]               0         
_________________________________________________________________
dense (Dense)                (None, 4)                 16        
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 25        
=================================================================
Total params: 41
Trainable params: 41
Non-trainable params: 0
_________________________________________________________________
pytorch_print.py

import torch.nn as nn


model = nn.Sequential(
    nn.Linear(3, 4),
    nn.ReLU(),
    nn.Linear(4, 5),
    nn.Softmax(dim=1)
)


if __name__ == "__main__":
    print(model)
output

$ python pytorch_print.py

Model(
  (dense1): Linear(in_features=3, out_features=4, bias=True)
  (relu): ReLU()
  (dense2): Linear(in_features=4, out_features=5, bias=True)
  (softmax): Softmax(dim=1)
)

そこで現れたtorchsummary

issueの流れで颯爽と現れ、この味気ないモデル表示をkerasっぽくして情報をリッチにできるのがtorchsummaryです。

installation
$ pip install torchsummary
pytorch_summary.py

import torch.nn as nn
from torchsummary import summary


model = nn.Sequential(
    nn.Linear(3, 4),
    nn.ReLU(),
    nn.Linear(4, 5),
    nn.Softmax(dim=1)
)


if __name__ == "__main__":
    model = model.to("cuda")
    summary(model, (1,3))
output

$ python pytorch_summary.py

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                 [-1, 1, 4]              16
              ReLU-2                 [-1, 1, 4]               0
            Linear-3                 [-1, 1, 5]              25
           Softmax-4                 [-1, 1, 5]               0
================================================================
Total params: 41
Trainable params: 41
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------

しかしこのtorchsummary、開発が止まっている模様。
pypiからインストールするとコードが古く、これをしないとmultiple inputsに対応できませんでした。

torch-summaryが更に情報をリッチに

  • torchsummaryがmodelをユーザーがto("cuda")しなければならなかった点を解消
  • 実際のコードを書き換える必要がない
  • 親子関係が見やすくなる
installation
# torchsummaryが機能しなくなります。
$ pip install torch-summary
pytorch_torch-summry
# pipでインストールするだけで出力が変わる
$ python pytorch_summary.py

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Sequential: 1-1                        [-1, 1, 5]                --
|    └─Linear: 2-1                       [-1, 1, 4]                16
|    └─ReLU: 2-2                         [-1, 1, 4]                --
|    └─Linear: 2-3                       [-1, 1, 5]                25
|    └─Softmax: 2-4                      [-1, 1, 5]                --
==========================================================================================
Total params: 41
Trainable params: 41
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================

この他にも、複数インプットがあるときに構造が見やすくなります。

torchsummary_multiple.py

import torch.nn as nn
from torchsummary import summary


class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(3, 4),
            nn.ReLU(),
            nn.Linear(4, 5),
            nn.Softmax(dim=1)
        )

    def forward(self, x, y):
        x = self.model(x)
        y = self.model(y)
        return x, y


if __name__ == "__main__":
    model = Model().cuda()
    summary(model, [(1, 3), (1, 3)])
output
$ python torchsummary_multiple.py

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                 [-1, 1, 4]              16
              ReLU-2                 [-1, 1, 4]               0
            Linear-3                 [-1, 1, 5]              25
           Softmax-4                 [-1, 1, 5]               0
            Linear-5                 [-1, 1, 4]              16
              ReLU-6                 [-1, 1, 4]               0
            Linear-7                 [-1, 1, 5]              25
           Softmax-8                 [-1, 1, 5]               0
================================================================
torch-summary_multiple.py

import torch.nn as nn
from torchsummary import summary


class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(3, 4),
            nn.ReLU(),
            nn.Linear(4, 5),
            nn.Softmax(dim=1)
        )

    def forward(self, x, y):
        x = self.model(x)
        y = self.model(y)
        return x, y

if __name__ == "__main__":
    model = Model()
    summary(model, [(1, 3), (1, 3)])
output
$ python torch-summary_multiple.py

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Sequential: 1-1                        [-1, 1, 5]                --
|    └─Linear: 2-1                       [-1, 1, 4]                16
|    └─ReLU: 2-2                         [-1, 1, 4]                --
|    └─Linear: 2-3                       [-1, 1, 5]                25
|    └─Softmax: 2-4                      [-1, 1, 5]                --
├─Sequential: 1-2                        [-1, 1, 5]                (recursive)
|    └─Linear: 2-5                       [-1, 1, 4]                (recursive)
|    └─ReLU: 2-6                         [-1, 1, 4]                --
|    └─Linear: 2-7                       [-1, 1, 5]                (recursive)
|    └─Softmax: 2-8                      [-1, 1, 5]                --
==========================================================================================

実はtorchsummaryXもある

torch-summaryは、torchsummaryとtorchsummaryXの後継と自称しています。
torchsummaryXから乗り換える際は下の2つの変更が必要かと思います。

  • from torchsummaryX => from torchsummary
  • summary()の引数inputは、input.shapeに変更

まとめ

46
28
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
46
28

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?