LoginSignup
2
2

More than 1 year has passed since last update.

[NNabla]構築済みネットワークの中間層を削除する方法

Last updated at Posted at 2020-01-03

はじめに

 qiitaへの2回目の投稿です。(article2)
 前回に引き続き、私がnnablaを使っていた中で「こういう情報がqiitaとかにあったらよかったのに」と思いながらなんとか気合いでnnablaのreferencedir()(pythonの標準関数。引数のメンバ変数・関数を返してくれる)で見つけてきたことについてまとめます。

1. 要件

・OS: macOS Catalina (バージョン 10.15.1)
・python: 3.5.4
・nnabla: 1.3.0

2. ネットワークの構築

 サンプルのネットワークを下記で定義します。(ここまでは前回同様)

article2_rewire_on.py
import nnabla as nn
import nnabla.functions as F

# [define network]
x = nn.Variable()
y = F.add_scalar(x, 0.5)  # <-- (1)とおく
y = F.mul_scalar(y, 2.0)

単純に $y=(x+0.5)\times2$ という形になってます。
この時、変数yは上の式の結果をもつ変数となっていて、途中の計算であるF.add_scalar(x, 0.5)の結果を(1)と呼びます。

3. 中間層を削除

 前述の(1)を削除し、単に$y=x\times2$とする方法を説明します。
これは、nnablaのreferenceにあったnnabla.Variableのメンバ変数rewire_onを使用します。これはreferenceにも(比較的)わかりやすい説明があります。下記で実践しました。

article2_rewire_on.py
h1 = y.parent.inputs[0]   # = (1)
x.rewire_on(h1)

動作確認は下記で行いました。

article2_rewire_on.py
# [check func for visit]
def get_func_name(f):
    print(f.name)
print('--- before ---')
y.visit(get_func_name)
print('')

# [rewire_on]
h1 = y.parent.inputs[0]   # = (1)
x.rewire_on(h1)

print('--- after ---')
y.visit(get_func_name)
print('')

出力

--- before ---
AddScalar
MulScalar

--- after ---
MulScalar

解説

  • まず(1)を上記のh1として取得しました。これについての詳細な説明は前回を参照してください。
  • x.rewire_on(h1)h1が消えます。正確には、計算グラフ上でh1xで置き換えられるようです。記述はコードからわかりますが、計算グラフを、rewire_onするノードを始点、終点として入力に繋がる部分グラフと出力へ伝わる部分グラフで分けた際に[入力側の終点].rewire_on([出力側の始点])となります。(言葉だとわかりづらいからそのうち図を載せたい。)
  • 確認方法として、前回使用したvisitを用いて、get_func_name(f)によって全レイヤーのレイヤー名を表示させています。その結果、rewire_onの後で(1)の計算をしているAddScalarが消えていることから、所望の動作が確認できました。

4. 次回予告(?)

 rewire_onを使用すれば、中間層を削除するだけでなく、新たなレイヤーを挟み込みこともできます。これもなんとか自分でできるようになりましたが、「qiitaとかに誰か書いといて欲しかった」と思ったので次回にでも書いてみようかと思います。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2