Python
機械学習
python3
TensorFlow

tf.custom_gradientで自作逆伝播

はじめに

tf.custom_gradientを使えば簡単に自作の逆伝播ができることを知ったので紹介します。

逆伝播

機械学習民には説明が不要ですが、deep neural networkはいくつもの関数(層)をつなぎ合わせて推論する構造をしています。この層を訓練によっていい感じに調整していくのが深層学習です。
訓練時には、次のように、出力側から入力側へパラメータの更新方向を教えていきます。これを逆伝播(back propagation)といいます。
backprop.png

通常、逆伝播の計算はほぼワンパターンでライブラリがいい感じにやってくれます。
この記事は、TensorFlowで逆伝播をいじって遊ぶ方法を紹介します。

tf.custom_gradient

自作の逆伝播を簡単に行うAPIとしてtf.custom_gradientが提供されています。これは次のようにデコレーターとして使います。

import tensorflow as tf


@tf.custom_gradient
def forward(x):
    y = x  # f(x) = xという関数の逆伝播を改造します

    def backward(w):
        return -w  # 通常wが逆伝播されるところを-wにする

    return y, backward

この例では、$f(x)=x$という関数の逆伝播を変更しています。
$w$が出力側から伝わってきた更新方向とします。$f$の部分を逆伝播させると、通常は$wf'(x)$が新しい更新方向です。$f(x)=x$の場合では$w$です。例では、この更新方向を$-w$に変更しています。

forwardの返り値は2つの値からなるtupleです。それぞれ、

  • $f(x)$に相当する値
  • 自作の逆伝播関数(つまり、$w$に対して$f$の部分を逆伝播させた新しい更新方向を返す関数)

です。
custom.png

実例

では、実際にこの自作逆伝播の効果を見てみましょう。
forward(x)をGradientDescentOptimizerでminimizeしてみましょう。通常であれば、xの値が減っていくはずです。

x = tf.Variable(1.0, tf.float32)
y = forward(x)
train = tf.train.GradientDescentOptimizer(0.1).minimize(y)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(10):
        print(i, sess.run(x))
        sess.run(train)

結果は次のようになります。

0 1.0
1 1.1
2 1.2
3 1.3000001
4 1.4000001
5 1.5000001
6 1.6000001
7 1.7000002
8 1.8000002
9 1.9000002

minimizeしているにも関わらず、値が増えています。逆向きに更新するようにした効果が確認できました。

備考

他の方法

https://qiita.com/jack_ama/items/8792ff2dcecfc90e029fで紹介されているようなtf.RegisterGradientでオペレーションを変更する方法は、通常のオペレーションを書き換える強引な方法で普通は避けた方がいいでしょう。

実行環境

  • python3.6.6
  • tensorflow1.11.0

リンク