LoginSignup
32
31

More than 5 years have passed since last update.

CNNの図をPythonで描く

Last updated at Posted at 2017-04-23

ConvNet Drawer(2018/1/4追記)

以下のツールが公開されたため、こちらの使用を推奨します。
畳み込みニューラルネットワークをKeras風に定義するとアーキテクチャを図示してくれるツールを作った

KerasのSequentialモデルのような記法でモデルを定義すると、そのアーキテクチャを良い感じに図示してくれるツールを作りました。言ってしまえばテキストを出力しているだけのツールなので依存ライブラリとかもありません。
https://github.com/yu4u/convnet-drawer

概要

Python+pydot+Graphviz を使い、CNNアーキテクチャの図を描きます。
https://github.com/jettan/tikz_cnn を見て、似たような図をTeXではなくPythonで描きたいと思ったのが動機です。

準備

pydotplusとgraphvizをインストールします。
condaを使用していますが、pipでも大丈夫だと思います(未検証)。

conda install -c conda-forge pydotplus
conda install graphviz

適当なdotファイルを用意しておき、pydotplusで読み込み、画像保存・画像表示します。
(Jupyter上で画像表示しています。適宜編集して下さい。)

drawCNN.py
import pydotplus
from IPython.display import Image

graph = pydotplus.graphviz.graph_from_dot_file('dot/pytorchainer.dot')
graph.write_png('img/pytorchainer.png')
Image(graph.create_png())
pytorchainer.dot
digraph G {

Python [shape=box]
Torch

Chainer -> "Chainer v2"
Chainer -> ChainerMN

Python -> PyTorch
Torch -> PyTorch
Chainer -> PyTorch

PyTorch -> PyTorChainer
"Chainer v2" -> PyTorChainer
ChainerMN -> PyTorChainer

この図はフィクションです。 [shape=plaintext]

}

pytorchainer.png
これで、Pythonからdotファイルを描画する準備は完了です。

なお、dot言語やPyDotPlusの仕様は以下をご参照下さい。
Graphvizとdot言語でグラフを描く方法のまとめ
PyDotPlus API Reference

CNN描画

それでは、CNNアーキテクチャの図を描いていきます。と言っても、dot言語で書いたレイヤー(と矢印)をひたすら追加するだけです。以下、位置調整用のマジックナンバーが乱舞していますが、ご容赦下さい。

drawCNN.py
class CNNDot():
    def __init__(self):
        self.layer_id = 0
        self.arrow_id = 0

    def get_layer_str(self, size, channels, xoffset=0.0, yoffset=0.0, fillcolor='white', caption=''):
        width = size * 0.5
        height = size
        x = xoffset
        y = height * 0.5 + yoffset
        x_caption = x - width * 0.25
        y_caption = -y - 0.7

        layer_str = """
          layer{} [
              shape=polygon, sides=4, skew=-2, orientation=90,
              label="", style=filled, fixedsize=true, fillcolor="{}",
              width={}, height={}, pos="{},{}!"
          ]
        """.format(self.layer_id, fillcolor, width, height, x, y)

        if caption != '':
            layer_str += """
              layer_caption{} [
                  shape=plaintext, label="{}", fixedsize=true, fontsize=24,
                  pos="{},{}!"
              ]
            """.format(self.layer_id, caption, x_caption, y_caption)

        self.layer_id += 1
        return layer_str

    def get_arrow_str(self, xmin, ymin, xmax, ymax):
        arrow_str = """
            arrow{0}_tail [
                shape=none, label="", fixedsize=true, width=0, height=0,
                pos="{1},{2}!"
            ]
            arrow{0}_head [
                shape=none, label="", fixedsize=true, width=0, height=0,
                pos="{3},{4}!"
            ]
            arrow{0}_tail -> arrow{0}_head
        """.format(self.arrow_id, xmin, ymin, xmax, ymax)
        self.arrow_id += 1
        return arrow_str

cnndot = CNNDot()
# layers
graph_data_main = cnndot.get_layer_str(3.0, 0, -1.00, fillcolor='gray') # input
graph_data_main += cnndot.get_layer_str(3.0, 0, 0.00, caption='conv') # encoder begin
graph_data_main += cnndot.get_layer_str(3.0, 0, 0.50)
graph_data_main += cnndot.get_layer_str(2.5, 0, 1.25, caption='conv')
graph_data_main += cnndot.get_layer_str(2.5, 0, 1.75)
graph_data_main += cnndot.get_layer_str(2.0, 0, 2.50, caption='conv')
graph_data_main += cnndot.get_layer_str(2.0, 0, 3.00)
graph_data_main += cnndot.get_layer_str(1.5, 0, 3.75, caption='conv')
graph_data_main += cnndot.get_layer_str(1.5, 0, 4.25)
graph_data_main += cnndot.get_layer_str(1.0, 0, 5.00, caption='conv')
graph_data_main += cnndot.get_layer_str(1.0, 0, 5.50)
graph_data_main += cnndot.get_layer_str(1.0, 0, 6.25, caption='deconv') # decoder begin
graph_data_main += cnndot.get_layer_str(1.0, 0, 6.75)
graph_data_main += cnndot.get_layer_str(1.5, 0, 7.50, caption='deconv')
graph_data_main += cnndot.get_layer_str(1.5, 0, 8.00)
graph_data_main += cnndot.get_layer_str(2.0, 0, 8.75)
graph_data_main += cnndot.get_layer_str(2.0, 0, 9.25)
graph_data_main += cnndot.get_layer_str(2.5, 0, 10.00)
graph_data_main += cnndot.get_layer_str(2.5, 0, 10.50)
graph_data_main += cnndot.get_layer_str(3.0, 0, 11.25)
graph_data_main += cnndot.get_layer_str(3.0, 0, 11.75)
graph_data_main += cnndot.get_layer_str(3.0, 0, 12.75, fillcolor='#FF8080') # output

# arrows
graph_data_main += cnndot.get_arrow_str(0.50, 3.0*1.2, 11.25-0.22, 3.0*1.2)
graph_data_main += cnndot.get_arrow_str(1.75, 2.5*1.2, 10.00-0.20, 2.5*1.2)
graph_data_main += cnndot.get_arrow_str(3.00, 2.0*1.2, 8.75-0.18, 2.0*1.2)
graph_data_main += cnndot.get_arrow_str(4.25, 1.5*1.2, 7.50-0.16, 1.5*1.2)
graph_data_main += cnndot.get_arrow_str(5.50, 1.0*1.2, 6.25-0.14, 1.0*1.2)

graph_data_setting = 'graph[ layout = neato, size="16,8"]'
graph_data = 'digraph G {{ \n{}\n{}\n }}'.format(graph_data_setting, graph_data_main)
graph = pydotplus.graphviz.graph_from_dot_data(graph_data)

# save and show image
graph.write_png('img/encoder-decoder.png')
Image(graph.create_png())

このコードの場合、以下のような図が表示されます。(各層が薄いのは仕様です。側面を張り込めば直方体も描ける筈。)

encoder-decoder.png

所感

  • 今回のようなシンプルな図であれば良いが、dotで描画を凝るのは厳しいように感じる。
  • フレームワークと連携して自動表示できると良い。

InceptionV3描画(2017/4/30追記)

上記のコードとは異なりますが、Kerasのモデル(InceptionV3)を描画してみました。
直方体はsvgwriteで描画して貼り付けています。

inceptionv3.png

32
31
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
32
31