Help us understand the problem. What is going on with this article?

[Keras] モデルのヘシアンと勾配を計算する

はじめに

この記事ではKerasでヘシアン(ヘッセ行列)を計算する方法について解説します。
ヘシアンは自然勾配を利用した効率的な学習や、モデルの圧縮(Optimal brain damage等)などに利用されています。
計算量の観点からヘシアンがそのまま利用されることは稀ですが、小さいモデルでは現実的な時間で求まるので、今回はその方法について説明していきます。
ソースコードはこちらからどうぞ

環境

version
Python 3.6.9
Keras 2.2.5
TensorFlow 1.14.0

ヘシアンを計算する対象のモデル

今回ヘシアンを計算する対象は、非常に小さい全結合層のモデルです。

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
Dense1 (Dense)               (None, 2)                 4         
_________________________________________________________________
Dense2 (Dense)               (None, 2)                 4         
=================================================================
Total params: 8
Trainable params: 8
Non-trainable params: 0
_________________________________________________________________

簡単のためにバイアスは利用していません。活性化関数はReLUで、最終層はクラス分類問題を想定して、Softmaxを利用しています。
今回は学習をしないので、Modelset_weights()メソッドで重みの値を適当にセットしておきます。

def build_model() -> Model:
    model = Sequential([
        Dense(2, use_bias=False, activation="relu", name="Dense1", input_shape=(2,)),
        Dense(2, use_bias=False, activation="softmax", name="Dense2")
    ])
    model.summary()
    model.set_weights([
        np.array([[1, 2], [3, 4]]),  # Dense1の重み
        np.array([[6, 5], [7, 8]]),  # Dense2の重み
    ])
    return model

ヘシアンの計算

必要なimportは以下の通りです。

import keras.backend as K
import numpy as np
import tensorflow as tf
from keras import Model, Sequential
from keras.layers import Dense

ヘシアンを計算するためにまず、Lossを返すテンソルを作成します。
Modelに設定されたOptimizerからLossを返すテンソルを取得する方法がわからなかったので(そもそも出来ないかも)、Kerasのバックエンド関数を活用して自分で作っていきます。

model = build_model()
y_true = K.placeholder((None, 2,))  # one-hotなラベルを入れるPlaceholder
loss = K.categorical_crossentropy(y_true, model.output)  # model.outputはSoftmax後の値

Kerasのバックエンド関数ではヘシアンを計算出来ないので、やむを得ずTensorFlowの関数を使います。
大多数の人はバックエンドにTensorFlowを利用していると思われるので、問題は無いはず。
また、tf.hessians()は4次元のテンソルを返すので、見やすいように2次元テンソルに変形します。

hessian = tf.hessians(loss, model.get_layer("Dense1").kernel)[0]
s = hessian.shape
hessian = K.reshape(hessian, [s[0] * s[1], s[2] * s[3]])  # 4次元テンソルを2次元に整形

最後にeval()メソッドにfeed_dictとしてモデルのインプットとラベルを渡してあげて、ヘシアンの計算をします。

inputs = np.array([[1, 2]])
labels = np.array([[0, 1]])

with K.get_session():
    print(hessian.eval({model.input: inputs, y_true: labels}))
結果
[[ 0.04517667 -0.04517668  0.09035334 -0.09035337]
 [-0.04517668  0.04517663 -0.09035337  0.09035325]
 [ 0.09035334 -0.09035337  0.18070668 -0.18070674]
 [-0.09035337  0.09035325 -0.18070674  0.1807065 ]]

勾配の計算

同様の手順でモデルの勾配も計算できます。

gradient = K.gradients(loss, model.get_layer("Dense1").kernel)[0]

with K.get_session():
    print(gradient.eval({model.input: inputs, y_true: labels}))
結果
[[ 0.0474259  -0.04742587]
 [ 0.09485179 -0.09485173]]

おわりに

コードを見れば分かることですが、かなりシンプルに計算が出来ました。
ですが、Kerasのバックエンド関数が絡んでくる様になると、GitHubのコードを読んで内部を理解したりとかなり骨が折れます。
placeholderfeed_dictを利用したテンソルの評価や、Modelからレイヤを取得する方法、モデルの重みの取得方法や設定方法など、それなりに参考になるかと思います。

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away