概要
CapsNetをどうやって実装するのかをまとめてみました。
CapsNetについて
CapsNet(Capusle Network)はCNNに用いられているPoolig層は位置頑健性の問題があると指摘され、それを解決するために提案されたモデルです。詳しくは原論文と以下の記事を参照(とてもわかりやすいです!!)。
Understanding Hinton's Capsule Networks. Part 1. Intuition.
カプセルネットワークはニューラルネットワークを超えるか。 - Qiita
CapsNet (Capsule Network) の PyTorch 実装 - Qiita
CapsNetがConvNet(CNN)に勝るのかを毒キノコ画像判別で試してみた
実装
上の図を参考に、入力部分から一つずつ実装していきます。
入力画像はMNIST(28, 28, 1)を想定しています。
Squash関数
各層の説明の前に、Squash関数の説明をさせてください。これはCapsNetで用いられる活性化関数です。
CapsNetにおけるカプセルの出力ベクトルの長さは、そのカプセルが表現する実体が現在の入力に存在する確率を表すようにしたいです。そこで、短いベクトルはほぼゼロの長さに縮小され、長いベクトルは1よりわずかに下の長さに縮小されるように、このSquash関数を用います。
$$
v_j=\frac{|s_j|^2}{1+|s_j|^2}\frac{s_j}{|s_j|}
$$
この出力ベクトルの長さが、カプセルと対応する特徴量が存在する確率として解釈されます。
def squash(vectors, axis=-1):
s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
scale = s_squared_norm / (1 + s_squared_norm) / K.sqrt(s_squared_norm + K.epsilon())
return scale * vectors
入力層
これは一般的なDNNのものと同じです。
from keras.layers import Input
x = Input(shape=(28, 28, 1))
畳み込み層
CapsNetでは、一般的なCNNモデルと同じく、最初に畳み込み層で入力画像を処理します。
from keras.layers import Conv2D
conv1 = Conv2D(filters=256,
kernel_size=9,
strides=1,
padding='valid',
activation='relu',
name='conv1')(x)
PrimaryCaps層
PrimaryCaps層は名前の通りカプセルになります。ただし、低次のレイヤのカプセルにあたります。
例えば、入力画像が「顔」である場合(「顔」は分類するクラスの1つだとします)、同じレイヤに属する異なる3つの 低次の Capsule は、それぞれ「鼻」、「口」、「目」の特徴を表現します。
この実装では、以降の利用のためにPrimaryCaps層の形状は(None, 6, 6, 8, 256)ではなく、(None, 1152, 8)という形状で出力しています。形状の意味は(バッチサイズ, カプセル数, カプセルの次元数)です。
from keras.layers import Conv2D, Lambda, Reshape
from keras import backend as K
def squash(vectors, axis=-1):
s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
scale = s_squared_norm / (1 + s_squared_norm) / K.sqrt(s_squared_norm + K.epsilon())
return scale * vectors
def PrimaryCap(inputs, dim_capsule, n_channels, kernel_size, strides, padding):
# output.shape = [None, 6, 6, 256]
output = Conv2D(filters=dim_capsule*n_channels,
kernel_size=kernel_size,
strides=strides,
padding=padding,
name='primarycap_conv2d')(inputs)
# outputs.shape = [None, 6, 6, 8, 256]
outputs = Reshape(target_shape=[-1, dim_capsule],
name='primarycap_reshape')(output)
# shape = [None, num_capsule, dim_capsule]
# shape = [None, 1152, 8]
return Lambda(squash, name='primarycap_squash')(outputs)
primarycaps = PrimaryCap(conv1,
dim_capsule=8,
n_channels=32,
kernel_size=9,
strides=2,
padding='valid')(conv1)
dynamic routing
DigitCaps層の実装の前に、dynamic routingを説明します。この処理を簡単に言うと、PrimaryCaps層のカプセルに対する重みを計算する過程になります。一般的にDNNはバックプロバケーションを用いてそれを行なっており、CapsNetも例外ではありません。しかし、この部分だけdynamic rountingを用いています。余談になりますが、この処理の並列化がうまくできないことから、CapsNetは計算が遅いと言われています。
詳しく説明します。
PrimaryCaps層は低次レイヤです。「目」がある方向に存在する確率をベクトルとして持っています。「鼻」や「口」についても同様です。
それに対して、DigitCaps層は高次レイヤです。この例の場合、「顔」の空間的情報を持っていると考えられます。
dynamic rountingで行いたいことは、低次レイヤ(目、鼻、口)と高次レイヤ(顔)との相対的な空間情報から、「顔」が存在する確率を求めるようにすることです。以下の図のように、「目」の位置に対して顔がどの位置にどれだけの確率で存在するのかを予測したいと言うことです。
(出典:https://pechyonkin.me/capsules-2/)
そこで、以下のようなアルゴリズムが提案されました。これがdynamic rountingです。
PrimaryCaps層 - DigitCaps層間の計算過程の概略図を以下の図に示します。
(出典: https://github.com/naturomics/CapsNet-Tensorflow)
dynamic rountingでは、PraimaryCaps層のカプセル$u_i$を$w_{ij}$でアフィン変換した$\hat{u}_{j|i}$を用います。
$b_{ij}$は計算過程における一時的な重みです。最終的には$c_{ij}$に格納され、これが利用されます。
4行目で行っていることは、低次レイヤの各カプセルに対して割合をつけています。sofmax()を使っているため、その割合の合計は1になります。つまり、どの低次レイヤのカプセルが重要であるかを選択していることになります。
5行目では重みづけした低次レイヤの各カプセルの合計を取っています。
6行目ではsquash()によって$s_{j}$を「顔」の存在確率に変換します。これが高次レイヤの(中の1つの)カプセルになります。
7行目では重みを更新しています。$\hat{u}_{j|i}\cdot v_j$は入力ベクトル$\hat{u}_{j|i}$と出力ベクトル$v_j$の類似性を表しています。これを$b_{ij}$に加算することで更新させます。
DigitCaps層
PrimaryCaps層より高次のレイヤのカプセルであり、入力画像の特徴を持った層になります。具体的には、低次の異なる特徴間の相対的な位置関係なども含めた特徴を持っています。
そして、カプセル間には先述のdynamic routingの処理が行われています。
from keras import initializers
from keras import backend as K
import tensorflow as tf
class CapsuleLayer(layers.Layer):
def __init__(self, num_capsule, dim_capsule, routings=3,
kernel_initializer='glorot_uniform',
**kwargs):
super(CapsuleLayer, self).__init__(**kwargs)
self.num_capsule = num_capsule
self.dim_capsule = dim_capsule
self.routings = routings
self.kernel_initializer = initializers.get(kernel_initializer)
def build(self, input_shape):
assert len(input_shape) >= 3, "The input Tensor should have shape=[None, input_num_capsule, input_dim_capsule]"
self.input_num_capsule = input_shape[1]
self.input_dim_capsule = input_shape[2]
# Transform matrix
self.W = self.add_weight(shape=[self.num_capsule, self.input_num_capsule,
self.dim_capsule, self.input_dim_capsule],
initializer=self.kernel_initializer,
name='W')
self.built = True
def call(self, inputs, training=None):
# inputs_expand.shape=[None, 1, input_num_capsule, input_dim_capsule]
# inputs_expand.shape=[None, 1, 1152, 8]
inputs_expand = K.expand_dims(inputs, 1)
# inputs_tiled.shape=[None, num_capsule, input_num_capsule, input_dim_capsule]
# inputs_tiled.shape=[None, 10, 11520, 8]
inputs_tiled = K.tile(inputs_expand, [1, self.num_capsule, 1, 1])
# u^_ij <- u_ij・w_i
# then matmul:
# [input_dim_capsule] x [dim_capsule, input_dim_capsule]^T -> [dim_capsule].
# inputs_hat.shape = [None, num_capsule, input_num_capsule, dim_capsule]
inputs_hat = K.map_fn(lambda x: K.batch_dot(x, self.W, [2, 3]),
elems=inputs_tiled)
# Begin: Routing algorithm ---------------------------------------------------------------------#
b = tf.zeros(shape=[K.shape(inputs_hat)[0], self.num_capsule, self.input_num_capsule])
for i in range(self.routings):
c = tf.nn.softmax(b, dim=1)
outputs = squash(K.batch_dot(c, inputs_hat, [2, 2])) # [None, 10, 16]
if i < self.routings - 1:
b += K.batch_dot(outputs, inputs_hat, [2, 3])
return outputs
# End: Routing algorithm -----------------------------------------------------------------------#
n_class = 10
routings = 3
digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings,
name='digitcaps')(primarycaps)
出力
カプセルはベクトルであり、その大きさは対象オブジェクトの存在確率を表します。そのため、最終出力のためにdigitCaps層の各カプセルを$L2$ノルムで[0, 1]の値にします。これが、各クラスに対する存在確率となります。
class Length(layers.Layer):
"""
Compute the length of vectors. This is used to compute a Tensor that has the same shape with y_true in margin_loss.
Using this layer as model's output can directly predict labels by using `y_pred = np.argmax(model.predict(x), 1)`
inputs: shape=[None, num_vectors, dim_vector]
output: shape=[None, num_vectors]
"""
def call(self, inputs, **kwargs):
return K.sqrt(K.sum(K.square(inputs), -1) + K.epsilon())
def compute_output_shape(self, input_shape):
return input_shape[:-1]
def get_config(self):
config = super(Length, self).get_config()
return config
out_caps = Length(name='capsnet')(digitcaps)
Loss
CapsNetではMargin Lossと呼ばれる損失関数を使います。以下のような式です。
$$
L_k=T_k~max(0,m^+-|v_k|)^2+\lambda(1-T_k)~max(0,|v_k|-m^-)^2
$$
$T_k$はクラス$k$の存在確率($0~or~1$)を表します。また、原論文では$T_k=0$の時の損失の軽減率$\lambda$は$0.5$とし、$m^+=0.9,~m^-=0.1$としています。
def margin_loss(y_true, y_pred):
"""
Margin loss for Eq.(4). When y_true[i, :] contains not just one `1`, this loss should work too. Not test it.
:param y_true: [None, n_classes]
:param y_pred: [None, num_capsule]
:return: a scalar loss value.
"""
L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))
return K.mean(K.sum(L, 1))
まとめ
kerasによるCapsNetの実装の説明は以上になります。
本当はこの後Decoder部分もあります。簡単に説明しますと、DigitCaps層は各クラスの空間情報を持っていることから、その情報から各クラスの再構成します。その情報の取り出し方については、他のクラスをMaskする方法が原論文では提案されています。また、CapsNetを深層化したDeepCapsNetの提案論文ではそのクラス部分のみを取り出していました。
Decoder部分については当記事では割愛させていただきます。
参考にしたソースコード