やりたいこと
最終的にこの図を作りたい。
ちなみに、ノードが重なっているところは現状どうしようもなさそうなので、あまり気にしないでほしい。
木構造のプロットは GraphRecipes.jl で簡単にできる。
Julia の AST や Type tree の表示例もあるので、軽く目を通しておくと役に立つかもしれない。
本記事では、モデルに対応する木構造の生成を主に説明していく。
本記事のコード
https://gist.github.com/Lirimy/6d614e6c073defa50cdf8b019fbadbed
利用するパッケージ
Julia v1.3.0
Flux v0.10.0
Plots v0.28.1
GraphRecipes v0.5.0 #master
AbstractTrees v0.2.1
GraphRecipes v0.4.0 (安定版)では木構造の表示ができない。
最新版のインストールが必要である。
]add GraphRecipes#master
using Flux
using Plots
using GraphRecipes, AbstractTrees
Flux におけるモデル・レイヤーの詳細
Flux のモデルやレイヤーを扱うわけだが、下準備としてそれらの Flux における実装を確認しておこう。
https://github.com/FluxML/Flux.jl/blob/master/src/layers/basic.jl
struct Chain{T<:Tuple}
layers::T
Chain(xs...) = new{typeof(xs)}(xs)
end
モデルは Chain
という構造体であり、 layers
というタプルをラップしている。
このタプルには Dense
などのレイヤーや関数が含まれる。
3行目は、ここでは全く気にしなくてもよい。1
struct Dense{F,S,T}
W::S
b::T
σ::F
end
レイヤーはいくつか種類があるが、ここでは Dense
だけ見ておこう。
単純に、 $ y = \sigma (Wx + b) $ のパラメータをまとめているだけである。
$\sigma$ はいわゆる活性化関数をあらわす変数で、
- identity (default)
- relu
- sigmoid
などが利用できる。
Chain や Dense では Function-like object が定義されており、入力を与えたときの出力を求めるのに利用される。
木構造の生成
この例を参考にして、 AbstractTrees.jl で木構造を作っていく。
https://github.com/JuliaPlots/GraphRecipes.jl#abstracttrees-trees
例と同じように、ノードをペア型で表現することにしよう。
ノードから子ノードを得る関数 AbstractTrees.children
を定義しておけば、あとはパッケージの方でうまくやってもらえる。
function AbstractTrees.children(pair::Pair)
obj = pair.second
props = propertynames(obj)
[prop => getproperty(obj, prop) for prop in props]
end
Chain や Dense は構造体なので、含まれる property を走査して、propertyname => property の配列を返している。
Chain の要素 layers はタプルだが、構造体と同じ処理で済むようだ。
3要素を持つタプルに対して propertynames すると、 (1, 2, 3)
を返す。
Pretty-printing
AbstractTrees.printnode
でノードの表示方法を設定する。
AbstractTrees.print_tree
はノードを再帰的に調査して、木構造全体を表示してくれる。
function AbstractTrees.printnode(io::IO, pair::Pair)
name, obj = pair
if obj isa Function || obj isa Dense
print(io, obj)
else
print(io, name)
end
end
# https://github.com/FluxML/model-zoo/blob/master/vision/mnist/mlp.jl
model = Chain(
Dense(28^2, 32, relu),
Dense(32, 10),
softmax)
tree = :Chain => model
print_tree(tree)
Chain
└─ layers
├─ Dense(784, 32, relu)
│ ├─ W
│ ├─ b
│ └─ relu
├─ Dense(32, 10)
│ ├─ W
│ ├─ b
│ └─ identity
└─ softmax
木構造の図示
前節の AbstractTrees.printnode
が定義されていれば、 GraphRecipes.jl の TreePlot
で木構造を図として表示できる。
下のコードで、冒頭に示した図が出力される。
plt = plot(TreePlot(tree),
nodesize = 0.12,
nodeshape=:ellipse,
nodecolor=:lightskyblue1,
fontsize=12,
curves=false,
linewidth=3.0,
linecolor=:darkseagreen3,
background_color=:lightyellow,
background_color_outside=:white,
title="MNIST multi-layer-perceptron model")
plt |> display
#savefig(plt, "modeltree.png")
Plots.jl オプション一覧
https://docs.juliaplots.org/latest/attributes/
色名一覧
http://juliagraphics.github.io/Colors.jl/stable/namedcolors/
-
Chain は本来タプルによって初期化されるが、複数引数が渡された場合に、それらをタプルとして改めて初期化する内部コンストラクタである ↩