8
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Layer-wise Relevance Propagation を Chainer で実装する

Posted at

#1. Layer-wise Relevance Propagation(LRP)

元論文は これ で、詳しいことは以下参照。

Qiita:ディープラーニングの判断根拠を理解する手法

私はLRPを

一度データを順伝搬させて、出力層から各層の出力と重みを元に貢献度を求めていく手法

だと理解しています。

 

#2. Chainer

国産のニューラルネットワーク用のフレームワークです。

一度、順伝搬させるとデータが数珠つなぎ的に残るので、個人的には使いやすくて好きです。

Chainer: A flexible framework for neural networks

 

#3. 実装 (chainer v2.0.0)

Chainerは一度、順伝搬させるとデータと生成された層の種類を参照できるので、それを利用して実装しますバグなどがあるかもしれません。

下記のコードは最低限、線形結合と畳込み、プーリングのみに対応しています。
また、入力zのshapeは(データ数, 出力ニューロン数)を想定しています。

import chainer
import numpy as np

def LRP(z, epsilon=0):
    creator = z.creator
    var = z
    # relevance value
    r = np.zeros(z.data.shape)
    for i, d in enumerate(z.data):
        r[i, d.argmax()] = d.max()

    while(creator is not None):
        x = creator.inputs[0].data
        y = creator.outputs[0]().data

        if len(creator.inputs) > 1:
            w = creator.inputs[1].data
        if creator.label == "LinearFunction":
            _y = y + epsilon*np.sign(y)
            r = x.reshape(r.shape[0], -1) * (np.dot(r/_y, w))
        elif creator.label == "Convolution2DFunction":
            _y = y + epsilon*np.sign(y)
            r = x * chainer.functions.deconvolution_2d(r.reshape(y.shape)/_y,
                                                       w).data
        elif creator.label == "MaxPooling2D":
            r = chainer.functions.unpooling_2d(
                r.reshape(y.shape),
                ksize=creator.kh,
                stride=creator.sy,
                outsize=x.shape[2:]).data

        var = creator.inputs[0]
        creator = var.creator
    return r

一応、出力層のデータが消えないように以下のhookを順伝搬の際に追加しました。

from chainer.function import FunctionHook

class RetainOutputHook(FunctionHook):
    def forward_postprocess(self, function, in_data):
        function.retain_outputs([0])

''' Example
with RetainOutputHook():
    z = model.predict(x)
'''

4.出力例

下記のネットワーク(n_units=100)にMNISTを学習させて、入力に対するLRPを可視化しました。

class CNN(chainer.Chain):
    def __init__(self, n_units):
        super(CNN, self).__init__()
        with self.init_scope():
            self.conv1 = chainer.links.Convolution2D(in_channels=1, out_channels=n_units//2, ksize=3, stride=1)
            self.conv2 = chainer.links.Convolution2D(in_channels=None, out_channels=n_units, ksize=3, stride=1)
            self.l = chainer.links.Linear(None, 10)

    def __call__(self, x):
        x = chainer.functions.relu(self.conv1(x))
        x = chainer.functions.max_pooling_2d(x, ksize=2, stride=2)
        x = chainer.functions.relu(self.conv2(x))
        x = chainer.functions.max_pooling_2d(x, ksize=2, stride=2)
        return self.l(x)

左:入力画像 右:LRP
lrp.png

8
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?