3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

[NNabla]構築済みネットワークの中間層の出力(variable)を取得する方法

Last updated at Posted at 2020-01-03

#はじめに
これは私のqiita初投稿です。(article1)
sonyが公開したDeep Learning用ライブラリnnablaについての記事です。私がnnablaを使っていた中で「こういう情報がqiitaとかにあったらよかったのに」と思いながらなんとか気合いでnnablaのreferencedir()(pythonの標準関数。引数のメンバ変数・関数を返してくれる)で見つけてきたことについてまとめます。
#1. 要件
・OS: macOS Catalina (バージョン 10.15.1)
・python: 3.5.4
・nnabla: 1.3.0
#2. ネットワークの構築
サンプルのネットワークを下記で定義します。

article1_make_network.py
import nnabla as nn
import nnabla.functions as F

# [define network]
x = nn.Variable()
y = F.add_scalar(x, 0.5)  # <-- (1)とおく
y = F.mul_scalar(y, 2.0)

単純に $y=(x+0.5)\times2$ という形になってます。
この時、変数yは上の式の結果をもつ変数となっていて、途中の計算であるF.add_scalar(x, 0.5)の結果を(1)と呼びます。(1)は最後の行でyが上書きされて見えなくなってます。この時の(1)の変数取得方法を解説します。
#3. 中間層の変数の取得(簡易版)
nnabla referenceによると、yにはparentというメンバ変数があり、名前とリファレンスの説明からしてyを出力したレイヤーを取得できるようです。また、parentは何を持っているかdir(y.parent)でみると、inputsoutputsというメンバ変数がありました。これらを用いれば、前述程度の浅いネットワークであれば下記によって(1)の出力を得られます。

article1_make_network.py
# [get middle variable]
h1 = y.parent.inputs[0]

動作確認は下記で行いました。

article1_make_network.py
# [set value & forward]
x.d.fill(0)
y.forward()
print('y.d = {}'.format(y.d))
print('h1.d = {}'.format(h1.d))

出力

y.d = 1.0
h1.d = 0.5

###解説

  • y.parent.inputsでyを出力したレイヤーの入力変数(variable)をlistとして取得できます。F.mul_scalarは入力変数を1つしかとらないので上記の記述で(1)の出力を取得できます。(scalarはレイヤーのパラメータという扱い(convolutionでいうstrideなどと同じ)で、ネットワーク上の変数ではないのでカウントされません。)レイヤーがconvolutionの時は、inputs[1]でweight, inputs[2]でbiasとなっていました。
  • さらに入力に近いレイヤーの出力を取得したい場合は、次のように取得できます。
h = y.parent.inputs[0].parent.inputs[0]....

しかし、これでは多くのレイヤーを持つネットワークについて中間層を取得する際は現実的ではないため、下記で私のよく使用する一般的な方法を紹介します。

#4. 中間層の変数の取得(一般)
nnabla.Variablereferenceを見ると、ある関数funcを与え、レイヤーを順方向にnnabla._function.Function型のオブジェクトfを引数としてなんか処理func(f)をしてくれそうなvisitというメンバ関数があります(元の文は下記に記載)。これを使用すれば前述のように各レイヤーの出力を取得できます。
(先ほどはnnabla.function.Functionで、visitで扱われるのがnnabla._function.Functionで若干違う気もしましたが、やってみるとう所望の動作をしました。)

visit(self, f)
Visit functions recursively in forward order.

Parameters: f (function) – Function object which takes nnabla._function.Function object as an argument.

下記が具体的なコードになります。
単に関数を定義してvisitに入力するだけではその関数内でしか取得した変数を保持してくれないので、classにしてメンバ関数__call__(self, f)を実行してメンバ変数self.middle_vars_dictに保持させていく形にしています。

article1_make_network.py
from collections import OrderedDict
class get_middle_variables:
    def __init__(self):
        self.middle_vars_dict = OrderedDict()
        self.middle_layer_count_dict = OrderedDict()
    def __call__(self, f):
        if f.name in self.middle_layer_count_dict:
            self.middle_layer_count_dict[f.name] += 1
        else:
            self.middle_layer_count_dict[f.name] = 1
        key = f.name + '_{}'.format(self.middle_layer_count_dict[f.name])
        self.middle_vars_dict[key] = f.outputs[0]

GET_MIDDLE_VARIABLES_CLASS = get_middle_variables()
y.visit(GET_MIDDLE_VARIABLES_CLASS)
middle_vars = GET_MIDDLE_VARIABLES_CLASS.middle_vars_dict

動作確認は下記で行いました。

article1_make_network.py
for key in middle_vars:
    print('{} : {}, .d = {}'.format(key, middle_vars[key], middle_vars[key].d))

出力

AddScalar_1 : <Variable((), need_grad=False) at 0x119fe41d8>, .d = 0.5
MulScalar_1 : <Variable((), need_grad=False) at 0x119fe4188>, .d = -1.0

###解説

  • y.visit(GET_MIDDLE_VARIABLES_CLASS)の部分で、中間層の変数を集めてきてます。
  • 今回visitに与えているのはクラスですが、関数のようにしてクラスを呼び出す(class(引数)の形)と、そのクラス内の__call__と定義されたメンバ関数が呼ばれます。
  • f.nameでそのレイヤー名を取得できるようです。
  • 同じ名前のレイヤーを複数使用してしまうと、単にf.nameをkeyにして保存していたら上書きされてしまうので、同じレイヤーでも何番目のレイヤーなのか区別できるようにself.middle_layer_count_dictでレイヤーの番号を数えて覚えておく形にしました。(今回のケースではレイヤーの被りがないので、なくても問題ないです。)
3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?