LoginSignup
258
221

More than 5 years have passed since last update.

畳み込みニューラルネットワークをKeras風に定義するとアーキテクチャの図をパワーポイントで保存してくれるツールを作った

Last updated at Posted at 2018-01-03

(2018/01/05 追記)ちょうどpython-pptxを調べていたので、pptx形式で図を保存できるようにし、タイトルも修正しました。

はじめに

論文やスライドで、畳み込みニューラルネットワークのアーキテクチャを良い感じに表示したいときがありますよね?スライドだとオリジナル論文の図の引用でも良いかなという気がしますが、論文の図としては使いたくありません。
ということでKerasのSequentialモデルのような記法でモデルを定義すると、そのアーキテクチャを良い感じに図示してくれるツールを作りました。言ってしまえばテキストを出力しているだけのツールなので依存ライブラリとかもありません。
https://github.com/yu4u/convnet-drawer

ここまで実装するつもりはなかったので綺麗に設計できていませんが、バグ報告や追加機能要望welcomeです!

経緯

元々は、全然違う論文用の図の作成方法を検討していたところ、@gou_koutaki 先生からSVGで出力すればいいじゃんとコメントを頂きました。

Pythonにはsvgwriteというライブラリがあり、SVGが簡単に出力できそうだったのですが、このくらいであれば自分で書いちゃえるかなと色々遊んでいたところ、気づいたら畳み込みニューラルネットワークの図示ツールを作成していました…
最初はただのBOXを描いていたのが、気づけば畳み込み層を作って、プーリング層も実装したら全結合層も欲しくなり、実際のモデルに合わせるためにpaddingも対応していました…何をしているんだろう…

ちなみに、畳み込みニューラルネットワークを図示するような素晴らしいツールとして下記があるのですが、モデルの書き方が直感的ではないのと、個人的には特徴マップをボリュームで表現したかったので利用したことはありませんでした。
https://github.com/gwding/draw_convnet

利用方法

KerasのSequentialモデルの記法のように、畳み込み層やプーリング層を重ねていくだけです。例えばAlexNetは下記のようにモデルを記述します。

from convnet_drawer import Model, Conv2D, MaxPooling2D, Flatten, Dense

model = Model(input_shape=(227, 227, 3))
model.add(Conv2D(96, (11, 11), (4, 4)))
model.add(MaxPooling2D((3, 3), strides=(2, 2)))
model.add(Conv2D(256, (5, 5), padding="same"))
model.add(MaxPooling2D((3, 3), strides=(2, 2)))
model.add(Conv2D(384, (3, 3), padding="same"))
model.add(Conv2D(384, (3, 3), padding="same"))
model.add(Conv2D(256, (3, 3), padding="same"))
model.add(MaxPooling2D((3, 3), strides=(2, 2)))
model.add(Flatten())
model.add(Dense(4096))
model.add(Dense(4096))
model.add(Dense(1000))
model.save_fig("example.svg")

上記のスクリプトを実行すると、下記のようなSVGフォーマットの画像が出力されます。ベクタ画像なので拡大しても綺麗です。

image.png

python-pptx対応

元々matplotlib等での出力も想定していたため、モデルをビルドする (model.build()) と、model.feature_mapsmodel.layersに、それぞれ特徴マップとレイヤを図示するための情報が格納されるようになっています。この情報は、下記のようなLineクラスとTextクラスのみで構成されているため、どんな形式へも簡単にエクスポートすることができます(当初SVGを想定していたためデフォルト値が一般的ではない表現ですが)。

class Line:
    def __init__(self, x1, y1, x2, y2, color="black", width=1, dasharray="none"):
        self.x1, self.y1 = x1, y1
        self.x2, self.y2 = x2, y2
        self.color = color
        self.width = width
        self.dasharray = dasharray

class Text:
    def __init__(self, x, y, body, color="black", size=20):
        self.x = x
        self.y = y
        self.body = body
        self.color = color
        self.size = size

ということで、このモデルを入力として、python-pptxを利用してpptxファイルを出力するような関数を作成します。
具体的には、Lineはadd_connectorで、Textはadd_textboxでオブジェクトを作ってあげます。

python-pptxのドキュメントを見ればそこまで難しくはないのですが、線のオブジェクトがMSO_CONNECTORであることと、線の色や破線といったプロパティを変更するためのlineattributeがこのMSO_CONNECTORでは対応しておらず、ちょっとアドホックなコードを書く必要があるところがハマりどころかもしれません。

lineattributeについては下記にIssueがあります。
https://github.com/scanny/python-pptx/issues/312

(2018/08/21追記)上記の line attributeがv0.6.12から対応され、connector.ln = connector.get_or_add_ln() のコードが不要になりました。

テキストに関しては、テキストボックスはサイズ0でテキストの位置を調整するためだけのものにしており、テキストの水平方向のcenterかつbottomの位置が指定できるようにアライメントを定義しています。

その他は実際のコードを見ていただいたほうが何となくわかると思います。

import os
from convnet_drawer import *
from pptx import Presentation
from pptx.shapes.connector import Connector
from pptx.enum.shapes import MSO_CONNECTOR
from pptx.enum.dml import MSO_LINE
from pptx.enum.text import MSO_ANCHOR, PP_ALIGN
from pptx.util import Pt
from pptx.dml.color import RGBColor
from pptx.dml.line import LineFormat


# lineプロパティがない対応
def get_or_add_ln(self):
    return self._element.spPr.get_or_add_ln()


# lineプロパティがない対応
Connector.get_or_add_ln = get_or_add_ln


class MyPresentation:
    def __init__(self):
        self.presentation = Presentation(os.path.join(os.path.dirname(__file__), "template.pptx"))
        self.slide_layout = self.presentation.slide_layouts[6]
        self.slide = self.presentation.slides.add_slide(self.slide_layout)
        self.shapes = self.slide.shapes

    def add_line(self, x1, y1, x2, y2, color, width, dasharray):
        connector = self.shapes.add_connector(MSO_CONNECTOR.STRAIGHT, Pt(x1), Pt(y1), Pt(x2), Pt(y2))
        connector.ln = connector.get_or_add_ln()
        line = LineFormat(connector)
        line.width = width
        line.fill.solid()

        if color == "black":
            line.fill.fore_color.rgb = RGBColor(0, 0, 0)
        elif color == "blue":
            line.fill.fore_color.rgb = RGBColor(0, 0, 255)

        if dasharray != "none":
            line.dash_style = MSO_LINE.DASH

    def add_text(self, x, y, body, color, size):
        textbox = self.shapes.add_textbox(Pt(x), Pt(y), Pt(0), Pt(0))
        textbox.text = body
        text_frame = textbox.text_frame
        text_frame.vertical_anchor = MSO_ANCHOR.BOTTOM
        p = text_frame.paragraphs[0]
        font = p.font
        font.name = 'arial'
        font.size = Pt(size)
        p.alignment = PP_ALIGN.CENTER

    def save_pptx(self, filename):
        self.presentation.save(filename)


def save_model_to_pptx(model, filename):
    model.build()
    presentation = MyPresentation()

    for feature_map in model.feature_maps + model.layers:
        for obj in feature_map.objects:
            if isinstance(obj, Line):
                presentation.add_line(obj.x1, obj.y1, obj.x2, obj.y2, obj.color, obj.width, obj.dasharray)
            elif isinstance(obj, Text):
                presentation.add_text(obj.x, obj.y, obj.body, obj.color, obj.size)

    presentation.save_pptx(filename)

対応レイヤ

一般的なレイヤしか対応していません。Deconvや(Deconvに対応しました)(Sequentialなので当然ですが)ResNetのskip connectionのようなものはありません。そもそも最近のモデルは深すぎて図示しても何が何やらになりそうですね。

Conv2D

Conv2D(filters, kernel_size, strides=(1, 1), padding="valid")
畳み込み層です。filtersにフィルタ数を、kernel_sizeにフィルタのカーネルサイズ(タプル)を、stridesにストライドのサイズ(タプル)を入力します。padding"valid""same"のみ対応しています(実装上は"same"でなければ"valid"になっちゃいます)。
(例)Conv2D(96, (11, 11), (4, 4)))

MaxPooling2D, AveragePooling2D

MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid")
プーリング層です。pool_sizeにプーリングのカーネルサイズ(タプル)を、stridesにストライドのサイズ(タプル)を、paddingにパディングのタイプを入力します。stridesを入力しない場合、pool_sizeと同じ値がセットされます。
(例)MaxPooling2D((3, 3), strides=(2, 2))

GlobalAveragePooling2D

GlobalAveragePooling2D()
Global average poolingです。特徴マップのサイズを1×1にし、flatten(1次元に)します。この後は全結合層のみが追加できます。

Flatten

Flatten()
特徴マップをflattenし、1次元にします。この後は全結合層のみが追加できます。

Dense

Dense(units)
全結合層です。unitsに出力次元数を入力します。
(例)Dense(4096)

実行例

LeNet

image.png

AlexNet
image.png

ZFNet
image.png

VGG16
image.png

AutoEncoder
autoencoder.png

258
221
6

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
258
221