はじめに
qiitaへの2回目の投稿です。(article2)
前回に引き続き、私がnnablaを使っていた中で「こういう情報がqiitaとかにあったらよかったのに」と思いながらなんとか気合いでnnablaのreferenceとdir()
(pythonの標準関数。引数のメンバ変数・関数を返してくれる)で見つけてきたことについてまとめます。
#1. 要件
・OS: macOS Catalina (バージョン 10.15.1)
・python: 3.5.4
・nnabla: 1.3.0
#2. ネットワークの構築
サンプルのネットワークを下記で定義します。(ここまでは前回同様)
import nnabla as nn
import nnabla.functions as F
# [define network]
x = nn.Variable()
y = F.add_scalar(x, 0.5) # <-- (1)とおく
y = F.mul_scalar(y, 2.0)
単純に $y=(x+0.5)\times2$ という形になってます。
この時、変数yは上の式の結果をもつ変数となっていて、途中の計算であるF.add_scalar(x, 0.5)
の結果を(1)と呼びます。
3. 中間層を削除
前述の(1)を削除し、単に$y=x\times2$とする方法を説明します。
これは、nnablaのreferenceにあったnnabla.Variable
のメンバ変数rewire_on
を使用します。これはreferenceにも(比較的)わかりやすい説明があります。下記で実践しました。
h1 = y.parent.inputs[0] # = (1)
x.rewire_on(h1)
動作確認は下記で行いました。
# [check func for visit]
def get_func_name(f):
print(f.name)
print('--- before ---')
y.visit(get_func_name)
print('')
# [rewire_on]
h1 = y.parent.inputs[0] # = (1)
x.rewire_on(h1)
print('--- after ---')
y.visit(get_func_name)
print('')
出力
--- before ---
AddScalar
MulScalar
--- after ---
MulScalar
###解説
- まず(1)を上記の
h1
として取得しました。これについての詳細な説明は前回を参照してください。 -
x.rewire_on(h1)
でh1
が消えます。正確には、計算グラフ上でh1
がx
で置き換えられるようです。記述はコードからわかりますが、計算グラフを、rewire_on
するノードを始点、終点として入力に繋がる部分グラフと出力へ伝わる部分グラフで分けた際に[入力側の終点].rewire_on([出力側の始点])
となります。(言葉だとわかりづらいからそのうち図を載せたい。) - 確認方法として、前回使用した
visit
を用いて、get_func_name(f)
によって全レイヤーのレイヤー名を表示させています。その結果、rewire_on
の後で(1)の計算をしているAddScalar
が消えていることから、所望の動作が確認できました。
4. 次回予告(?)
rewire_on
を使用すれば、中間層を削除するだけでなく、新たなレイヤーを挟み込みこともできます。これもなんとか自分でできるようになりましたが、「qiitaとかに誰か書いといて欲しかった」と思ったので次回にでも書いてみようかと思います。