0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

DeepCapsNetの実装(Keras)

Last updated at Posted at 2022-03-01

概要

DeepCapsNetをどうやって実装するのかをまとめてみました。

DeepCapsNetについて

CapsNet(Capusle Network)を深層化したNetwork。

CapsNetについては原論文と以下参照

CapsNetの実装(Keras) (拙著)

Understanding Hinton's Capsule Networks. Part 1. Intuition.

カプセルネットワークはニューラルネットワークを超えるか。 - Qiita

CapsNetについての調べ - Qiita

CapsNet (Capsule Network) の PyTorch 実装 - Qiita

CapsNetがConvNet(CNN)に勝るのかを毒キノコ画像判別で試してみた

実装

deepcaps_(1).png

上の図を参考に、入力部分から一つずつ実装していきます。

入力画像はMNIST(28, 28, 1)を想定しています。

Squash関数

各層の説明の前に、Squash関数の説明をさせてください。これはCapsNetで用いられる活性化関数です。

CapsNetにおけるカプセルの出力ベクトルの長さは、そのカプセルが表現する実体が現在の入力に存在する確率を表すようにしたいです。そこで、短いベクトルはほぼゼロの長さに縮小され、長いベクトルは1よりわずかに下の長さに縮小されるように、このSquash関数を用います。

$$
v_j=\frac{|s_j|^2}{1+|s_j|^2}\frac{s_j}{|s_j|}
$$

この出力ベクトルの長さが、カプセルと対応する特徴量が存在する確率として解釈されます。

def squash(vectors, axis=-1):
    s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
    scale = s_squared_norm / (1 + s_squared_norm) / K.sqrt(s_squared_norm + K.epsilon())
    return scale * vectors

Squash3D関数

DeepCapsNetで用いるSquash関数です。CapsNetにおける入力ベクトル$s_j$は2次元ベクトルであったのに対し、DeepCapsNetの入力ベクトル$S_{pqr}$は3次元ベクトルです。違いはそれだけです。実装コードは同じです。

\begin{align*}
\hat{S}_{pqr} &= Squash\_3D(S_{pqr})\\
&=\frac{\|S_{pqr}\|^2}{1+\|S_{pqr}\|^2}\frac{S_{pqr}}{\|S_{pqr}\|}
\end{align*}

Softmax3D関数

3次元ベクトル対応のSoftmax関数です。

\begin{align*}
K_s &= Softmax\_3D(B_s)
\\
k_{pqrs} &= \frac{exp(b_{pqrs})}{\sum_x \sum_y \sum_z b_{xyzs}}
\end{align*}

入力層

これは一般的なDNNのものと同じです。

from keras.layers import Input

x = Input(shape=(28, 28, 1))
l = x

畳み込み層

deepcaps_(1) 1.png

DeepCapsNetでは、一般的なCNNモデルと同じく、最初に畳み込み層で入力画像を処理します。

from keras.layers import Conv2D

l = Conv2D(128, (3, 3), strides=(1, 1), activation='relu', padding="same")(l)  # common conv layer
l = BatchNormalization()(l)
l = ConvertToCaps()(l)

CapsCell 1~3

deepcaps_(1)_(1).png

CapsCell 1~3ではそれぞれ同様のことをしているので、まとめて説明します。

まず、ここで用いられているConvCaps layerについて。これは出力がスカッシュされた畳み込み層です。最終的な出力の形状は4次元テンソルで、$(H,W,C,Caps)$($Caps$:カプセル数)です。また、ルーティング数が1なので、これは単なる畳み込みと変わりません。

以下のように定義します。

def squeeze(s):
    sq = K.sum(K.square(s), axis=-1, keepdims=True)
    return (sq / (1 + sq)) * (s / K.sqrt(sq + K.epsilon()))

class Conv2DCaps(layers.Layer):

    def __init__(self, ch_j, n_j,
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 r_num=1,
                 b_alphas=[8, 8, 8],
                 padding='same',
                 data_format='channels_last',
                 dilation_rate=(1, 1),
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 **kwargs):
        super(Conv2DCaps, self).__init__(**kwargs)
        rank = 2
        self.ch_j = ch_j  # Number of capsules in layer J
        self.n_j = n_j  # Number of neurons in a capsule in J
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank, 'kernel_size')
        self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
        self.r_num = r_num
        self.b_alphas = b_alphas
        self.padding = conv_utils.normalize_padding(padding)
        #self.data_format = conv_utils.normalize_data_format(data_format)
        self.data_format = K.normalize_data_format(data_format)
        self.dilation_rate = (1, 1)
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.input_spec = InputSpec(ndim=rank + 3)

    def build(self, input_shape):

        self.h_i, self.w_i, self.ch_i, self.n_i = input_shape[1:5]

        self.h_j, self.w_j = [conv_utils.conv_output_length(input_shape[i + 1],
                                                            self.kernel_size[i],
                                                            padding=self.padding,
                                                            stride=self.strides[i],
                                                            dilation=self.dilation_rate[i]) for i in (0, 1)]

        self.ah_j, self.aw_j = [conv_utils.conv_output_length(input_shape[i + 1],
                                                              self.kernel_size[i],
                                                              padding=self.padding,
                                                              stride=1,
                                                              dilation=self.dilation_rate[i]) for i in (0, 1)]

        self.w_shape = self.kernel_size + (self.ch_i, self.n_i,
                                           self.ch_j, self.n_j)

        self.w = self.add_weight(shape=self.w_shape,
                                 initializer=self.kernel_initializer,
                                 name='kernel',
                                 regularizer=self.kernel_regularizer,
                                 constraint=self.kernel_constraint)

        self.built = True

    def call(self, inputs):
        if self.r_num == 1:
            # if there is no routing (and this is so when r_num is 1 and all c are equal)
            # then this is a common convolution
            outputs = K.conv2d(K.reshape(inputs, (-1, self.h_i, self.w_i,
                                                  self.ch_i * self.n_i)),
                               K.reshape(self.w, self.kernel_size +
                                         (self.ch_i * self.n_i, self.ch_j * self.n_j)),
                               data_format='channels_last',
                               strides=self.strides,
                               padding=self.padding,
                               dilation_rate=self.dilation_rate)

            outputs = squeeze(K.reshape(outputs, ((-1, self.h_j, self.w_j,
                                                   self.ch_j, self.n_j))))

        return outputs

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.h_j, self.w_j, self.ch_j, self.n_j)

    def get_config(self):
        config = {
            'ch_j': self.ch_j,
            'n_j': self.n_j,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'b_alphas': self.b_alphas,
            'padding': self.padding,
            'data_format': self.data_format,
            'dilation_rate': self.dilation_rate,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint)
        }
        base_config = super(Conv2DCaps, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

squash()ではなくsqueeze()を使っている理由は、畳み込み層の出力がベクトルではなくテンソル(ベクター)だからだと思います。

CapsCell 1~3では(加算レイヤを除けば)このConvCaps layerのみを使用しています。ただ、スキップ接続という工夫はされています。これは、深層化による勾配消失を減少させることが狙いです。また、低レベルのカプセルを高レベルのカプセルと繋げることもできます。

スキップ接続後は2つのカプセルを加算レイヤで結合します。カプセルはベクトルで表現されるため、チャネル単位の連結は同じカプセルが重複するため使用しませんが、要素単位の加算はバイアスを減らし、ノイズの影響を受けにくくするそうです。

以上のことから、最終的には次のように実装されます。

# CapsCell 1
l = Conv2DCaps(32, 4, kernel_size=(3, 3), strides=(2, 2), r_num=1, b_alphas=[1, 1, 1])(l)
l_skip = Conv2DCaps(32, 4, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = Conv2DCaps(32, 4, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = Conv2DCaps(32, 4, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = layers.Add()([l, l_skip])

# CapsCell 2
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(2, 2), r_num=1, b_alphas=[1, 1, 1])(l)
l_skip = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = layers.Add()([l, l_skip])

# CapsCell 3
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(2, 2), r_num=1, b_alphas=[1, 1, 1])(l)
l_skip = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = layers.Add()([l, l_skip])
l1 = l

ちなみに、MNISTのような情報量の少ない画像に対してはCapsCell1~2のみで十分だと著者は言っています。

dynamic rounting 3D

DeepCapsNetでもdynamic rountingを使います。しかし、3Dバージョンです。以下のようなアルゴリズムが定義されています。

スクリーンショット_2022-03-01_8.42.18.png

実装は以下の通りです。

def update_routing(votes, biases, logit_shape, num_dims, input_dim, output_dim,
                   num_routing):
    if num_dims == 6:
        votes_t_shape = [3, 0, 1, 2, 4, 5]
        r_t_shape = [1, 2, 3, 0, 4, 5]
    elif num_dims == 4:
        votes_t_shape = [3, 0, 1, 2]
        r_t_shape = [1, 2, 3, 0]
    else:
        raise NotImplementedError('Not implemented')

    votes_trans = tf.transpose(votes, votes_t_shape)
    _, _, _, height, width, caps = votes_trans.get_shape()

    def _body(i, logits, activations):
        """Routing while loop."""
        # route: [batch, input_dim, output_dim, ...]
        a,b,c,d,e = logits.get_shape()
        a = logit_shape[0]
        b = logit_shape[1]
        c = logit_shape[2]
        d = logit_shape[3]
        e = logit_shape[4]
        print(logit_shape)
        logit_temp = tf.reshape(logits, [a,b,-1])
        route_temp = tf.nn.softmax(logit_temp, dim=-1)
        route = tf.reshape(route_temp, [a, b, c, d, e])
        preactivate_unrolled = route * votes_trans
        preact_trans = tf.transpose(preactivate_unrolled, r_t_shape)
        preactivate = tf.reduce_sum(preact_trans, axis=1) + biases
        # activation = _squash(preactivate)
        activation = squash(preactivate, axis=[-1, -2, -3])
        activations = activations.write(i, activation)

        act_3d = K.expand_dims(activation, 1)
        tile_shape = np.ones(num_dims, dtype=np.int32).tolist()
        tile_shape[1] = input_dim
        act_replicated = tf.tile(act_3d, tile_shape)
        distances = tf.reduce_sum(votes * act_replicated, axis=3)
        logits += distances
        return (i + 1, logits, activations)

    activations = tf.TensorArray(
        dtype=tf.float32, size=num_routing, clear_after_read=False)
    logits = tf.fill(logit_shape, 0.0)

    i = tf.constant(0, dtype=tf.int32)
    _, logits, activations = tf.while_loop(
        lambda i, logits, activations: i < num_routing,
        _body,
        loop_vars=[i, logits, activations],
        swap_memory=True)
    a = K.cast(activations.read(num_routing - 1), dtype='float32')
    return K.cast(activations.read(num_routing - 1), dtype='float32')

CapsCell 4

deepcaps_(1)_(1) 1.png

CapsCell 4はCapsCell 1~3とは少し違い、ConvCaps3Dというレイヤを使用します。また、スキップ接続もルーティング数が3です。

ConvCaps3Dについて説明します。ConvCaps layerとの違いは、Conv2DではなくConv3Dを使用していること、それに合わせてsquash3D()を使っていること、またdynamic rounting3Dを3回まわしていることです。

定義は以下の通りです。

class ConvCapsuleLayer3D(layers.Layer):

    def __init__(self, kernel_size, num_capsule, num_atoms, strides=1, padding='valid', routings=3,
                 kernel_initializer='he_normal', **kwargs):
        super(ConvCapsuleLayer3D, self).__init__(**kwargs)
        self.kernel_size = kernel_size
        self.num_capsule = num_capsule
        self.num_atoms = num_atoms
        self.strides = strides
        self.padding = padding
        self.routings = routings
        self.kernel_initializer = initializers.get(kernel_initializer)

    def build(self, input_shape):
        assert len(input_shape) == 5, "The input Tensor should have shape=[None, input_height, input_width," \
                                      " input_num_capsule, input_num_atoms]"
        self.input_height = input_shape[1]
        self.input_width = input_shape[2]
        self.input_num_capsule = input_shape[3]
        self.input_num_atoms = input_shape[4]

        # Transform matrix
        self.W = self.add_weight(shape=[self.input_num_atoms, self.kernel_size, self.kernel_size, 1, self.num_capsule * self.num_atoms],
                                 initializer=self.kernel_initializer,
                                 name='W')

        self.b = self.add_weight(shape=[self.num_capsule, self.num_atoms, 1, 1],
                                 initializer=initializers.constant(0.1),
                                 name='b')

        self.built = True

    def call(self, input_tensor, training=None):

        input_transposed = tf.transpose(input_tensor, [0, 3, 4, 1, 2])
        input_shape = K.shape(input_transposed)
        input_tensor_reshaped = K.reshape(input_tensor, [input_shape[0], 1, self.input_num_capsule * self.input_num_atoms, self.input_height, self.input_width])

        input_tensor_reshaped.set_shape((None, 1, self.input_num_capsule * self.input_num_atoms, self.input_height, self.input_width))

        # conv = Conv3D(input_tensor_reshaped, self.W, (self.strides, self.strides),
        #                 padding=self.padding, data_format='channels_first')

        conv = K.conv3d(input_tensor_reshaped, self.W, strides=(self.input_num_atoms, self.strides, self.strides), padding=self.padding, data_format='channels_first')

        votes_shape = K.shape(conv)
        _, _, _, conv_height, conv_width = conv.get_shape()
        conv = tf.transpose(conv, [0, 2, 1, 3, 4])
        votes = K.reshape(conv, [input_shape[0], self.input_num_capsule, self.num_capsule, self.num_atoms, votes_shape[3], votes_shape[4]])
        votes.set_shape((None, self.input_num_capsule, self.num_capsule, self.num_atoms, conv_height.value, conv_width.value))

        logit_shape = K.stack([input_shape[0], self.input_num_capsule, self.num_capsule, votes_shape[3], votes_shape[4]])
        biases_replicated = K.tile(self.b, [1, 1, conv_height.value, conv_width.value])

        activations = update_routing(
            votes=votes,
            biases=biases_replicated,
            logit_shape=logit_shape,
            num_dims=6,
            input_dim=self.input_num_capsule,
            output_dim=self.num_capsule,
            num_routing=self.routings)

        a2 = tf.transpose(activations, [0, 3, 4, 1, 2])
        return a2

    def compute_output_shape(self, input_shape):
        space = input_shape[1:-2]
        new_space = []
        for i in range(len(space)):
            new_dim = conv_output_length(space[i], self.kernel_size, padding=self.padding, stride=self.strides, dilation=1)
            new_space.append(new_dim)

        return (input_shape[0],) + tuple(new_space) + (self.num_capsule, self.num_atoms)

    def get_config(self):
        config = {
            'kernel_size': self.kernel_size,
            'num_capsule': self.num_capsule,
            'num_atoms': self.num_atoms,
            'strides': self.strides,
            'padding': self.padding,
            'routings': self.routings,
            'kernel_initializer': initializers.serialize(self.kernel_initializer)
        }
        base_config = super(ConvCapsuleLayer3D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

最終的には以下のように実装されます。

l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(2, 2), r_num=1, b_alphas=[1, 1, 1])(l)
l_skip = ConvCapsuleLayer3D(kernel_size=3, num_capsule=32, num_atoms=8, strides=1, padding='same', routings=3)(l)
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = layers.Add()([l, l_skip])
l2 = l

collect capsules

deepcaps_(1)_(1) 2.png

CapsCell 3とCapsCell 4の出力を結合させます。このステップの狙いは、多様なデータセットに対してモデルを一般化すること、みたいです。

結合する前に平坦化します。そのためにFlattenCaps layerを以下のように定義します。

class FlattenCaps(layers.Layer):

    def __init__(self, **kwargs):
        super(FlattenCaps, self).__init__(**kwargs)
        self.input_spec = InputSpec(min_ndim=4)

    def compute_output_shape(self, input_shape):
        if not all(input_shape[1:]):
            raise ValueError('The shape of the input to "FlattenCaps" '
                             'is not fully defined '
                             '(got ' + str(input_shape[1:]) + '. '
                             'Make sure to pass a complete "input_shape" '
                             'or "batch_input_shape" argument to the first '
                             'layer in your model.')
        return (input_shape[0], np.prod(input_shape[1:-1]), input_shape[-1])

    def call(self, inputs):
        shape = K.int_shape(inputs)
        return K.reshape(inputs, (-1, np.prod(shape[1:-1]), shape[-1]))

出力の形状は(バッチサイズ, カプセルサイズ, カプセルの個数)になっていると思います。

では、collect capsulesを以下のように実装できます。

la = FlattenCaps()(l2)
lb = FlattenCaps()(l1)
l = layers.Concatenate(axis=-2)([la, lb])

Flat Caps Layer(DigitCaps層)

deepcaps_(1)_(1) 3.png

次はFlat Caps Layerです。この層はCapsNetでいうDigitCaps層に当たると思います。

また、処理についてもCapsNetのPrimaryCaps層⇒DigitCaps層の処理と代わりません。dynamic routingを用います。定義は以下の通りです。

class CapsuleLayer(layers.Layer):

    def __init__(self, num_capsule, dim_capsule, channels, routings=3,
                 kernel_initializer='glorot_uniform',
                 **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.channels = channels
        self.kernel_initializer = initializers.get(kernel_initializer)

    def build(self, input_shape):
        assert len(input_shape) >= 3, "The input Tensor should have shape=[None, input_num_capsule, input_dim_capsule]"
        self.input_num_capsule = input_shape[1]
        self.input_dim_capsule = input_shape[2]

        if(self.channels != 0):
            assert int(self.input_num_capsule / self.channels) / (self.input_num_capsule / self.channels) == 1, "error"
            self.W = self.add_weight(shape=[self.num_capsule, self.channels,
                                            self.dim_capsule, self.input_dim_capsule],
                                     initializer=self.kernel_initializer,
                                     name='W')

            self.B = self.add_weight(shape=[self.num_capsule, self.dim_capsule],
                                     initializer=self.kernel_initializer,
                                     name='B')
        else:
            self.W = self.add_weight(shape=[self.num_capsule, self.input_num_capsule,
                                            self.dim_capsule, self.input_dim_capsule],
                                     initializer=self.kernel_initializer,
                                     name='W')
            self.B = self.add_weight(shape=[self.num_capsule, self.dim_capsule],
                                     initializer=self.kernel_initializer,
                                     name='B')

        self.built = True

    def call(self, inputs, training=None):
        # inputs.shape=[None, input_num_capsule, input_dim_capsule]
        # inputs_expand.shape=[None, 1, input_num_capsule, input_dim_capsule]
        inputs_expand = K.expand_dims(inputs, 1)

        # Replicate num_capsule dimension to prepare being multiplied by W
        # inputs_tiled.shape=[None, num_capsule, input_num_capsule, input_dim_capsule]
        inputs_tiled = K.tile(inputs_expand, [1, self.num_capsule, 1, 1])

        if(self.channels != 0):
            W2 = K.repeat_elements(self.W, int(self.input_num_capsule / self.channels), 1)
        else:
            W2 = self.W
        # Compute `inputs * W` by scanning inputs_tiled on dimension 0.
        # x.shape=[num_capsule, input_num_capsule, input_dim_capsule]
        # W.shape=[num_capsule, input_num_capsule, dim_capsule, input_dim_capsule]
        # Regard the first two dimensions as `batch` dimension,
        # then matmul: [input_dim_capsule] x [dim_capsule, input_dim_capsule]^T -> [dim_capsule].
        # inputs_hat.shape = [None, num_capsule, input_num_capsule, dim_capsule]
        inputs_hat = K.map_fn(lambda x: own_batch_dot(x, W2, [2, 3]), elems=inputs_tiled)

        # Begin: Routing algorithm ---------------------------------------------------------------------#
        # The prior for coupling coefficient, initialized as zeros.
        # b.shape = [None, self.num_capsule, self.input_num_capsule].
        b = tf.zeros(shape=[K.shape(inputs_hat)[0], self.num_capsule, self.input_num_capsule])

        assert self.routings > 0, 'The routings should be > 0.'
        for i in range(self.routings):
            # c.shape=[batch_size, num_capsule, input_num_capsule]
            c = tf.nn.softmax(b, dim=1)

            # c.shape =  [batch_size, num_capsule, input_num_capsule]
            # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
            # The first two dimensions as `batch` dimension,
            # then matmal: [input_num_capsule] x [input_num_capsule, dim_capsule] -> [dim_capsule].
            # outputs.shape=[None, num_capsule, dim_capsule]
            outputs = squash(own_batch_dot(c, inputs_hat, [2, 2]) + self.B)  # [None, 10, 16]

            if i < self.routings - 1:
                # outputs.shape =  [None, num_capsule, dim_capsule]
                # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
                # The first two dimensions as `batch` dimension,
                # then matmal: [dim_capsule] x [input_num_capsule, dim_capsule]^T -> [input_num_capsule].
                # b.shape=[batch_size, num_capsule, input_num_capsule]
                b += own_batch_dot(outputs, inputs_hat, [2, 3])
        # End: Routing algorithm -----------------------------------------------------------------------#

        return outputs

    def compute_output_shape(self, input_shape):
        return tuple([None, self.num_capsule, self.dim_capsule])

実装は以下の通りです。

digits_caps = CapsuleLayer(num_capsule=n_class, dim_capsule=32, routings=routings, channels=0, name='digit_caps')(l)

出力

deepcaps_(1)_(1) 4.png

カプセルはベクトルであり、その大きさは対象オブジェクトの存在確率を表します。そのため、最終出力のためにdigitCaps層の各カプセルを$L2$ノルムで[0, 1]の値にします。これが、各クラスに対する存在確率となります。

capsNet-CapsNetのコピー.drawio_(1).png

class CapsToScalars(layers.Layer):

    def __init__(self, **kwargs):
        super(CapsToScalars, self).__init__(**kwargs)
        self.input_spec = InputSpec(min_ndim=3)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1])

    def call(self, inputs):
        return K.sqrt(K.sum(K.square(inputs + K.epsilon()), axis=-1))

l = CapsToScalars(name='capsnet')(digits_caps)

Loss

CapsNetではMargin Lossと呼ばれる損失関数を使います。以下のような式です。

$$
L_k=T_k~max(0,m^+-|v_k|)^2+\lambda(1-T_k)~max(0,|v_k|-m^-)^2
$$

$T_k$はクラス$k$の存在確率($0~or~1$)を表します。また、原論文では$T_k=0$の時の損失の軽減率$\lambda$は$0.5$とし、$m^+=0.9,~m^-=0.1$としています。

def margin_loss(y_true, y_pred):
    """
    Margin loss for Eq.(4). When y_true[i, :] contains not just one `1`, this loss should work too. Not test it.
    :param y_true: [None, n_classes]
    :param y_pred: [None, num_capsule]
    :return: a scalar loss value.
    """
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
        0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))

    return K.mean(K.sum(L, 1))

まとめ

kerasによるDeepCapsNetの実装の説明は以上になります。

本当はこの後Decoder部分もあります。簡単に説明しますと、DigitCaps層は各クラスの空間情報を持っていることから、その情報から各クラスの再構成します。その情報の取り出し方については、他のクラスをMaskする方法が原論文では提案されています。DeepCapsNetの提案論文ではそのクラス部分のみを取り出しています。

スクリーンショット_2022-03-01_9.15.12.png

Decoder部分については割愛します。

参考

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?