LoginSignup
5

More than 3 years have passed since last update.

Julia / Flux のモデルを木構造として表示する

Last updated at Posted at 2019-12-02

やりたいこと

modeltree.png

最終的にこの図を作りたい。
ちなみに、ノードが重なっているところは現状どうしようもなさそうなので、あまり気にしないでほしい。

木構造のプロットは GraphRecipes.jl で簡単にできる。
Julia の AST や Type tree の表示例もあるので、軽く目を通しておくと役に立つかもしれない。

本記事では、モデルに対応する木構造の生成を主に説明していく。

本記事のコード
https://gist.github.com/Lirimy/6d614e6c073defa50cdf8b019fbadbed

利用するパッケージ

Julia v1.3.0

Flux v0.10.0
Plots v0.28.1
GraphRecipes v0.5.0 #master
AbstractTrees v0.2.1

GraphRecipes v0.4.0 (安定版)では木構造の表示ができない。
最新版のインストールが必要である。

]add GraphRecipes#master

using Flux
using Plots
using GraphRecipes, AbstractTrees

Flux におけるモデル・レイヤーの詳細

Flux のモデルやレイヤーを扱うわけだが、下準備としてそれらの Flux における実装を確認しておこう。
https://github.com/FluxML/Flux.jl/blob/master/src/layers/basic.jl

Chain
struct Chain{T<:Tuple}
    layers::T
    Chain(xs...) = new{typeof(xs)}(xs)
end

モデルは Chain という構造体であり、 layers というタプルをラップしている。
このタプルには Dense などのレイヤーや関数が含まれる。
3行目は、ここでは全く気にしなくてもよい。1

Dense
struct Dense{F,S,T}
    W::S
    b::T
    σ::F
end

レイヤーはいくつか種類があるが、ここでは Dense だけ見ておこう。
単純に、 $ y = \sigma (Wx + b) $ のパラメータをまとめているだけである。
$\sigma$ はいわゆる活性化関数をあらわす変数で、

  • identity (default)
  • relu
  • sigmoid

などが利用できる。

Chain や Dense では Function-like object が定義されており、入力を与えたときの出力を求めるのに利用される。

木構造の生成

この例を参考にして、 AbstractTrees.jl で木構造を作っていく。
https://github.com/JuliaPlots/GraphRecipes.jl#abstracttrees-trees

例と同じように、ノードをペア型で表現することにしよう。
ノードから子ノードを得る関数 AbstractTrees.children を定義しておけば、あとはパッケージの方でうまくやってもらえる。

function AbstractTrees.children(pair::Pair)
    obj = pair.second
    props = propertynames(obj)
    [prop => getproperty(obj, prop) for prop in props]
end

Chain や Dense は構造体なので、含まれる property を走査して、propertyname => property の配列を返している。
Chain の要素 layers はタプルだが、構造体と同じ処理で済むようだ。
3要素を持つタプルに対して propertynames すると、 (1, 2, 3) を返す。

Pretty-printing

AbstractTrees.printnode でノードの表示方法を設定する。
AbstractTrees.print_tree はノードを再帰的に調査して、木構造全体を表示してくれる。

function AbstractTrees.printnode(io::IO, pair::Pair)
    name, obj = pair
    if obj isa Function || obj isa Dense
        print(io, obj)
    else
        print(io, name)
    end
end

# https://github.com/FluxML/model-zoo/blob/master/vision/mnist/mlp.jl
model = Chain(
    Dense(28^2, 32, relu),
    Dense(32, 10),
    softmax)

tree = :Chain => model
print_tree(tree)
Result
Chain
└─ layers
   ├─ Dense(784, 32, relu)
   │  ├─ W
   │  ├─ b
   │  └─ relu
   ├─ Dense(32, 10)
   │  ├─ W
   │  ├─ b
   │  └─ identity
   └─ softmax

木構造の図示

前節の AbstractTrees.printnode が定義されていれば、 GraphRecipes.jl の TreePlot で木構造を図として表示できる。
下のコードで、冒頭に示した図が出力される。

plt = plot(TreePlot(tree),
    nodesize = 0.12,
    nodeshape=:ellipse,
    nodecolor=:lightskyblue1,
    fontsize=12,
    curves=false,
    linewidth=3.0,
    linecolor=:darkseagreen3,
    background_color=:lightyellow,
    background_color_outside=:white,
    title="MNIST multi-layer-perceptron model")
plt |> display
#savefig(plt, "modeltree.png")

Plots.jl オプション一覧
https://docs.juliaplots.org/latest/attributes/

色名一覧
http://juliagraphics.github.io/Colors.jl/stable/namedcolors/


  1. Chain は本来タプルによって初期化されるが、複数引数が渡された場合に、それらをタプルとして改めて初期化する内部コンストラクタである 

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5