2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

CapsNetの実装(Keras) 説明

Last updated at Posted at 2022-02-18

概要

CapsNetをどうやって実装するのかをまとめてみました。

CapsNetについて

CapsNet(Capusle Network)はCNNに用いられているPoolig層は位置頑健性の問題があると指摘され、それを解決するために提案されたモデルです。詳しくは原論文と以下の記事を参照(とてもわかりやすいです!!)。

Understanding Hinton's Capsule Networks. Part 1. Intuition.

カプセルネットワークはニューラルネットワークを超えるか。 - Qiita

CapsNetについての調べ - Qiita

CapsNet (Capsule Network) の PyTorch 実装 - Qiita

CapsNetがConvNet(CNN)に勝るのかを毒キノコ画像判別で試してみた

実装

Untitled.png

capsNet-CapsNetのコピー.drawio_(1).png

上の図を参考に、入力部分から一つずつ実装していきます。

入力画像はMNIST(28, 28, 1)を想定しています。

Squash関数

各層の説明の前に、Squash関数の説明をさせてください。これはCapsNetで用いられる活性化関数です。

CapsNetにおけるカプセルの出力ベクトルの長さは、そのカプセルが表現する実体が現在の入力に存在する確率を表すようにしたいです。そこで、短いベクトルはほぼゼロの長さに縮小され、長いベクトルは1よりわずかに下の長さに縮小されるように、このSquash関数を用います。

$$
v_j=\frac{|s_j|^2}{1+|s_j|^2}\frac{s_j}{|s_j|}
$$

この出力ベクトルの長さが、カプセルと対応する特徴量が存在する確率として解釈されます。

def squash(vectors, axis=-1):
    s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
    scale = s_squared_norm / (1 + s_squared_norm) / K.sqrt(s_squared_norm + K.epsilon())
    return scale * vectors

入力層

これは一般的なDNNのものと同じです。

from keras.layers import Input

x = Input(shape=(28, 28, 1))

畳み込み層

Untitled_(1).png

capsNet-CapsNetのコピー.drawio_(1) 1.png

CapsNetでは、一般的なCNNモデルと同じく、最初に畳み込み層で入力画像を処理します。

from keras.layers import Conv2D

conv1 = Conv2D(filters=256, 
							  kernel_size=9, 
								strides=1,
							  padding='valid', 
								activation='relu', 
								name='conv1')(x)

PrimaryCaps層

Untitled_(1) 1.png

capsNet-CapsNetのコピー.drawio_(1) 2.png

PrimaryCaps層は名前の通りカプセルになります。ただし、低次のレイヤのカプセルにあたります。

例えば、入力画像が「顔」である場合(「顔」は分類するクラスの1つだとします)、同じレイヤに属する異なる3つの 低次の Capsule は、それぞれ「鼻」、「口」、「目」の特徴を表現します。

この実装では、以降の利用のためにPrimaryCaps層の形状は(None, 6, 6, 8, 256)ではなく、(None, 1152, 8)という形状で出力しています。形状の意味は(バッチサイズ, カプセル数, カプセルの次元数)です。

from keras.layers import Conv2D, Lambda, Reshape
from keras import backend as K

def squash(vectors, axis=-1):
    s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
    scale = s_squared_norm / (1 + s_squared_norm) / K.sqrt(s_squared_norm + K.epsilon())
    return scale * vectors

def PrimaryCap(inputs, dim_capsule, n_channels, kernel_size, strides, padding):
    # output.shape = [None, 6, 6, 256]
		output = Conv2D(filters=dim_capsule*n_channels, 
										kernel_size=kernel_size, 
										strides=strides, 
										padding=padding,
	                  name='primarycap_conv2d')(inputs)
    # outputs.shape = [None, 6, 6, 8, 256]
    outputs = Reshape(target_shape=[-1, dim_capsule], 
											name='primarycap_reshape')(output)
		# shape = [None, num_capsule, dim_capsule]
    # shape = [None, 1152, 8]
    return Lambda(squash, name='primarycap_squash')(outputs)

primarycaps = PrimaryCap(conv1, 
												dim_capsule=8, 
										    n_channels=32, 
											  kernel_size=9, 
											  strides=2, 
										 	  padding='valid')(conv1)

dynamic routing

DigitCaps層の実装の前に、dynamic routingを説明します。この処理を簡単に言うと、PrimaryCaps層のカプセルに対する重みを計算する過程になります。一般的にDNNはバックプロバケーションを用いてそれを行なっており、CapsNetも例外ではありません。しかし、この部分だけdynamic rountingを用いています。余談になりますが、この処理の並列化がうまくできないことから、CapsNetは計算が遅いと言われています。

詳しく説明します。

PrimaryCaps層は低次レイヤです。「目」がある方向に存在する確率をベクトルとして持っています。「鼻」や「口」についても同様です。

それに対して、DigitCaps層は高次レイヤです。この例の場合、「顔」の空間的情報を持っていると考えられます。

dynamic rountingで行いたいことは、低次レイヤ(目、鼻、口)と高次レイヤ(顔)との相対的な空間情報から、「顔」が存在する確率を求めるようにすることです。以下の図のように、「目」の位置に対して顔がどの位置にどれだけの確率で存在するのかを予測したいと言うことです。

Untitled 1.png

(出典:https://pechyonkin.me/capsules-2/)

そこで、以下のようなアルゴリズムが提案されました。これがdynamic rountingです。

Untitled 2.png

PrimaryCaps層 - DigitCaps層間の計算過程の概略図を以下の図に示します。

Untitled_(4).png

(出典: https://github.com/naturomics/CapsNet-Tensorflow)

dynamic rountingでは、PraimaryCaps層のカプセル$u_i$を$w_{ij}$でアフィン変換した$\hat{u}_{j|i}$を用います。

$b_{ij}$は計算過程における一時的な重みです。最終的には$c_{ij}$に格納され、これが利用されます。

4行目で行っていることは、低次レイヤの各カプセルに対して割合をつけています。sofmax()を使っているため、その割合の合計は1になります。つまり、どの低次レイヤのカプセルが重要であるかを選択していることになります。

5行目では重みづけした低次レイヤの各カプセルの合計を取っています。

6行目ではsquash()によって$s_{j}$を「顔」の存在確率に変換します。これが高次レイヤの(中の1つの)カプセルになります。

7行目では重みを更新しています。$\hat{u}_{j|i}\cdot v_j$は入力ベクトル$\hat{u}_{j|i}$と出力ベクトル$v_j$の類似性を表しています。これを$b_{ij}$に加算することで更新させます。

DigitCaps層

Untitled_(3).png

capsNet-CapsNetのコピー.drawio_(1) 3.png

PrimaryCaps層より高次のレイヤのカプセルであり、入力画像の特徴を持った層になります。具体的には、低次の異なる特徴間の相対的な位置関係なども含めた特徴を持っています。

そして、カプセル間には先述のdynamic routingの処理が行われています。

from keras import initializers
from keras import backend as K
import tensorflow as tf

class CapsuleLayer(layers.Layer):
    def __init__(self, num_capsule, dim_capsule, routings=3,
                 kernel_initializer='glorot_uniform',
                 **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
				self.kernel_initializer = initializers.get(kernel_initializer)

		def build(self, input_shape):
        assert len(input_shape) >= 3, "The input Tensor should have shape=[None, input_num_capsule, input_dim_capsule]"
        self.input_num_capsule = input_shape[1]
        self.input_dim_capsule = input_shape[2]
        # Transform matrix
        self.W = self.add_weight(shape=[self.num_capsule, self.input_num_capsule,
                                        self.dim_capsule, self.input_dim_capsule],
                                 initializer=self.kernel_initializer,
                                 name='W')
        self.built = True

    def call(self, inputs, training=None):
				# inputs_expand.shape=[None, 1, input_num_capsule, input_dim_capsule]
        # inputs_expand.shape=[None, 1, 1152, 8]
        inputs_expand = K.expand_dims(inputs, 1)
				# inputs_tiled.shape=[None, num_capsule, input_num_capsule, input_dim_capsule]
				# inputs_tiled.shape=[None, 10, 11520, 8]
        inputs_tiled = K.tile(inputs_expand, [1, self.num_capsule, 1, 1])
				# u^_ij <- u_ij・w_i
        # then matmul: 
				# [input_dim_capsule] x [dim_capsule, input_dim_capsule]^T -> [dim_capsule].
        # inputs_hat.shape = [None, num_capsule, input_num_capsule, dim_capsule]
        inputs_hat = K.map_fn(lambda x: K.batch_dot(x, self.W, [2, 3]), 
															elems=inputs_tiled)
        # Begin: Routing algorithm ---------------------------------------------------------------------#
				b = tf.zeros(shape=[K.shape(inputs_hat)[0], self.num_capsule, self.input_num_capsule])
        for i in range(self.routings):
            c = tf.nn.softmax(b, dim=1)
            outputs = squash(K.batch_dot(c, inputs_hat, [2, 2]))  # [None, 10, 16]
            if i < self.routings - 1:
                b += K.batch_dot(outputs, inputs_hat, [2, 3])
        return outputs
				# End: Routing algorithm -----------------------------------------------------------------------#

n_class = 10
routings = 3

digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings,
                             name='digitcaps')(primarycaps)

出力

カプセルはベクトルであり、その大きさは対象オブジェクトの存在確率を表します。そのため、最終出力のためにdigitCaps層の各カプセルを$L2$ノルムで[0, 1]の値にします。これが、各クラスに対する存在確率となります。

capsNet-CapsNetのコピー.drawio_(1) 4.png

class Length(layers.Layer):
    """
    Compute the length of vectors. This is used to compute a Tensor that has the same shape with y_true in margin_loss.
    Using this layer as model's output can directly predict labels by using `y_pred = np.argmax(model.predict(x), 1)`
    inputs: shape=[None, num_vectors, dim_vector]
    output: shape=[None, num_vectors]
    """
    def call(self, inputs, **kwargs):
        return K.sqrt(K.sum(K.square(inputs), -1) + K.epsilon())

    def compute_output_shape(self, input_shape):
        return input_shape[:-1]

    def get_config(self):
        config = super(Length, self).get_config()
        return config

out_caps = Length(name='capsnet')(digitcaps)

Loss

CapsNetではMargin Lossと呼ばれる損失関数を使います。以下のような式です。

$$
L_k=T_k~max(0,m^+-|v_k|)^2+\lambda(1-T_k)~max(0,|v_k|-m^-)^2
$$

$T_k$はクラス$k$の存在確率($0~or~1$)を表します。また、原論文では$T_k=0$の時の損失の軽減率$\lambda$は$0.5$とし、$m^+=0.9,~m^-=0.1$としています。

def margin_loss(y_true, y_pred):
    """
    Margin loss for Eq.(4). When y_true[i, :] contains not just one `1`, this loss should work too. Not test it.
    :param y_true: [None, n_classes]
    :param y_pred: [None, num_capsule]
    :return: a scalar loss value.
    """
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
        0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))

    return K.mean(K.sum(L, 1))

まとめ

kerasによるCapsNetの実装の説明は以上になります。

本当はこの後Decoder部分もあります。簡単に説明しますと、DigitCaps層は各クラスの空間情報を持っていることから、その情報から各クラスの再構成します。その情報の取り出し方については、他のクラスをMaskする方法が原論文では提案されています。また、CapsNetを深層化したDeepCapsNetの提案論文ではそのクラス部分のみを取り出していました。

Decoder部分については当記事では割愛させていただきます。

参考にしたソースコード

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?