(2018/01/05 追記)ちょうどpython-pptxを調べていたので、pptx形式で図を保存できるようにし、タイトルも修正しました。
はじめに
論文やスライドで、畳み込みニューラルネットワークのアーキテクチャを良い感じに表示したいときがありますよね?スライドだとオリジナル論文の図の引用でも良いかなという気がしますが、論文の図としては使いたくありません。
ということでKerasのSequentialモデルのような記法でモデルを定義すると、そのアーキテクチャを良い感じに図示してくれるツールを作りました。言ってしまえばテキストを出力しているだけのツールなので依存ライブラリとかもありません。
https://github.com/yu4u/convnet-drawer
ここまで実装するつもりはなかったので綺麗に設計できていませんが、バグ報告や追加機能要望welcomeです!
経緯
元々は、全然違う論文用の図の作成方法を検討していたところ、@gou_koutaki 先生からSVGで出力すればいいじゃんとコメントを頂きました。
僕も今、論文の図関係で悩んでいて調べていました。SVGという形式で割と簡単にベクターを作れるので、これを吐きだすプログラムを作ったらいいと思います。https://t.co/w5mxBHmBts 簡易ライブラリ:https://t.co/aHufTweLRf pic.twitter.com/GP7nzKuKDa
— GouKoutaki (@gou_koutaki) 2017年12月9日
Pythonにはsvgwriteというライブラリがあり、SVGが簡単に出力できそうだったのですが、このくらいであれば自分で書いちゃえるかなと色々遊んでいたところ、気づいたら畳み込みニューラルネットワークの図示ツールを作成していました…
最初はただのBOXを描いていたのが、気づけば畳み込み層を作って、プーリング層も実装したら全結合層も欲しくなり、実際のモデルに合わせるためにpaddingも対応していました…何をしているんだろう…
ちなみに、畳み込みニューラルネットワークを図示するような素晴らしいツールとして下記があるのですが、モデルの書き方が直感的ではないのと、個人的には特徴マップをボリュームで表現したかったので利用したことはありませんでした。
https://github.com/gwding/draw_convnet
利用方法
KerasのSequentialモデルの記法のように、畳み込み層やプーリング層を重ねていくだけです。例えばAlexNetは下記のようにモデルを記述します。
from convnet_drawer import Model, Conv2D, MaxPooling2D, Flatten, Dense
model = Model(input_shape=(227, 227, 3))
model.add(Conv2D(96, (11, 11), (4, 4)))
model.add(MaxPooling2D((3, 3), strides=(2, 2)))
model.add(Conv2D(256, (5, 5), padding="same"))
model.add(MaxPooling2D((3, 3), strides=(2, 2)))
model.add(Conv2D(384, (3, 3), padding="same"))
model.add(Conv2D(384, (3, 3), padding="same"))
model.add(Conv2D(256, (3, 3), padding="same"))
model.add(MaxPooling2D((3, 3), strides=(2, 2)))
model.add(Flatten())
model.add(Dense(4096))
model.add(Dense(4096))
model.add(Dense(1000))
model.save_fig("example.svg")
上記のスクリプトを実行すると、下記のようなSVGフォーマットの画像が出力されます。ベクタ画像なので拡大しても綺麗です。
python-pptx対応
元々matplotlib等での出力も想定していたため、モデルをビルドする (model.build()
) と、model.feature_maps
とmodel.layers
に、それぞれ特徴マップとレイヤを図示するための情報が格納されるようになっています。この情報は、下記のようなLineクラスとTextクラスのみで構成されているため、どんな形式へも簡単にエクスポートすることができます(当初SVGを想定していたためデフォルト値が一般的ではない表現ですが)。
class Line:
def __init__(self, x1, y1, x2, y2, color="black", width=1, dasharray="none"):
self.x1, self.y1 = x1, y1
self.x2, self.y2 = x2, y2
self.color = color
self.width = width
self.dasharray = dasharray
class Text:
def __init__(self, x, y, body, color="black", size=20):
self.x = x
self.y = y
self.body = body
self.color = color
self.size = size
ということで、このモデルを入力として、python-pptxを利用してpptxファイルを出力するような関数を作成します。
具体的には、Lineはadd_connector
で、Textはadd_textbox
でオブジェクトを作ってあげます。
python-pptxのドキュメントを見ればそこまで難しくはないのですが、線のオブジェクトがMSO_CONNECTOR
であることと、線の色や破線といったプロパティを変更するためのline
attributeがこのMSO_CONNECTOR
では対応しておらず、ちょっとアドホックなコードを書く必要があるところがハマりどころかもしれません。
line
attributeについては下記にIssueがあります。
https://github.com/scanny/python-pptx/issues/312
(2018/08/21追記)上記の line
attributeがv0.6.12から対応され、connector.ln = connector.get_or_add_ln()
のコードが不要になりました。
テキストに関しては、テキストボックスはサイズ0でテキストの位置を調整するためだけのものにしており、テキストの水平方向のcenterかつbottomの位置が指定できるようにアライメントを定義しています。
その他は実際のコードを見ていただいたほうが何となくわかると思います。
import os
from convnet_drawer import *
from pptx import Presentation
from pptx.shapes.connector import Connector
from pptx.enum.shapes import MSO_CONNECTOR
from pptx.enum.dml import MSO_LINE
from pptx.enum.text import MSO_ANCHOR, PP_ALIGN
from pptx.util import Pt
from pptx.dml.color import RGBColor
from pptx.dml.line import LineFormat
# lineプロパティがない対応
def get_or_add_ln(self):
return self._element.spPr.get_or_add_ln()
# lineプロパティがない対応
Connector.get_or_add_ln = get_or_add_ln
class MyPresentation:
def __init__(self):
self.presentation = Presentation(os.path.join(os.path.dirname(__file__), "template.pptx"))
self.slide_layout = self.presentation.slide_layouts[6]
self.slide = self.presentation.slides.add_slide(self.slide_layout)
self.shapes = self.slide.shapes
def add_line(self, x1, y1, x2, y2, color, width, dasharray):
connector = self.shapes.add_connector(MSO_CONNECTOR.STRAIGHT, Pt(x1), Pt(y1), Pt(x2), Pt(y2))
connector.ln = connector.get_or_add_ln()
line = LineFormat(connector)
line.width = width
line.fill.solid()
if color == "black":
line.fill.fore_color.rgb = RGBColor(0, 0, 0)
elif color == "blue":
line.fill.fore_color.rgb = RGBColor(0, 0, 255)
if dasharray != "none":
line.dash_style = MSO_LINE.DASH
def add_text(self, x, y, body, color, size):
textbox = self.shapes.add_textbox(Pt(x), Pt(y), Pt(0), Pt(0))
textbox.text = body
text_frame = textbox.text_frame
text_frame.vertical_anchor = MSO_ANCHOR.BOTTOM
p = text_frame.paragraphs[0]
font = p.font
font.name = 'arial'
font.size = Pt(size)
p.alignment = PP_ALIGN.CENTER
def save_pptx(self, filename):
self.presentation.save(filename)
def save_model_to_pptx(model, filename):
model.build()
presentation = MyPresentation()
for feature_map in model.feature_maps + model.layers:
for obj in feature_map.objects:
if isinstance(obj, Line):
presentation.add_line(obj.x1, obj.y1, obj.x2, obj.y2, obj.color, obj.width, obj.dasharray)
elif isinstance(obj, Text):
presentation.add_text(obj.x, obj.y, obj.body, obj.color, obj.size)
presentation.save_pptx(filename)
対応レイヤ
一般的なレイヤしか対応していません。Deconvや(Deconvに対応しました)(Sequentialなので当然ですが)ResNetのskip connectionのようなものはありません。そもそも最近のモデルは深すぎて図示しても何が何やらになりそうですね。
Conv2D
Conv2D(filters, kernel_size, strides=(1, 1), padding="valid")
畳み込み層です。filters
にフィルタ数を、kernel_size
にフィルタのカーネルサイズ(タプル)を、strides
にストライドのサイズ(タプル)を入力します。padding
は"valid"
か"same"
のみ対応しています(実装上は"same"
でなければ"valid"
になっちゃいます)。
(例)Conv2D(96, (11, 11), (4, 4)))
MaxPooling2D, AveragePooling2D
MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid")
プーリング層です。pool_size
にプーリングのカーネルサイズ(タプル)を、strides
にストライドのサイズ(タプル)を、padding
にパディングのタイプを入力します。strides
を入力しない場合、pool_size
と同じ値がセットされます。
(例)MaxPooling2D((3, 3), strides=(2, 2))
GlobalAveragePooling2D
GlobalAveragePooling2D()
Global average poolingです。特徴マップのサイズを1×1にし、flatten(1次元に)します。この後は全結合層のみが追加できます。
Flatten
Flatten()
特徴マップをflattenし、1次元にします。この後は全結合層のみが追加できます。
Dense
Dense(units)
全結合層です。units
に出力次元数を入力します。
(例)Dense(4096)
実行例
LeNet