概要
DeepCapsNetをどうやって実装するのかをまとめてみました。
DeepCapsNetについて
CapsNet(Capusle Network)を深層化したNetwork。
CapsNetについては原論文と以下参照
Understanding Hinton's Capsule Networks. Part 1. Intuition.
カプセルネットワークはニューラルネットワークを超えるか。 - Qiita
CapsNet (Capsule Network) の PyTorch 実装 - Qiita
CapsNetがConvNet(CNN)に勝るのかを毒キノコ画像判別で試してみた
実装
上の図を参考に、入力部分から一つずつ実装していきます。
入力画像はMNIST(28, 28, 1)を想定しています。
Squash関数
各層の説明の前に、Squash関数の説明をさせてください。これはCapsNetで用いられる活性化関数です。
CapsNetにおけるカプセルの出力ベクトルの長さは、そのカプセルが表現する実体が現在の入力に存在する確率を表すようにしたいです。そこで、短いベクトルはほぼゼロの長さに縮小され、長いベクトルは1よりわずかに下の長さに縮小されるように、このSquash関数を用います。
$$
v_j=\frac{|s_j|^2}{1+|s_j|^2}\frac{s_j}{|s_j|}
$$
この出力ベクトルの長さが、カプセルと対応する特徴量が存在する確率として解釈されます。
def squash(vectors, axis=-1):
s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
scale = s_squared_norm / (1 + s_squared_norm) / K.sqrt(s_squared_norm + K.epsilon())
return scale * vectors
Squash3D関数
DeepCapsNetで用いるSquash関数です。CapsNetにおける入力ベクトル$s_j$は2次元ベクトルであったのに対し、DeepCapsNetの入力ベクトル$S_{pqr}$は3次元ベクトルです。違いはそれだけです。実装コードは同じです。
\begin{align*}
\hat{S}_{pqr} &= Squash\_3D(S_{pqr})\\
&=\frac{\|S_{pqr}\|^2}{1+\|S_{pqr}\|^2}\frac{S_{pqr}}{\|S_{pqr}\|}
\end{align*}
Softmax3D関数
3次元ベクトル対応のSoftmax関数です。
\begin{align*}
K_s &= Softmax\_3D(B_s)
\\
k_{pqrs} &= \frac{exp(b_{pqrs})}{\sum_x \sum_y \sum_z b_{xyzs}}
\end{align*}
入力層
これは一般的なDNNのものと同じです。
from keras.layers import Input
x = Input(shape=(28, 28, 1))
l = x
畳み込み層
DeepCapsNetでは、一般的なCNNモデルと同じく、最初に畳み込み層で入力画像を処理します。
from keras.layers import Conv2D
l = Conv2D(128, (3, 3), strides=(1, 1), activation='relu', padding="same")(l) # common conv layer
l = BatchNormalization()(l)
l = ConvertToCaps()(l)
CapsCell 1~3
CapsCell 1~3ではそれぞれ同様のことをしているので、まとめて説明します。
まず、ここで用いられているConvCaps layerについて。これは出力がスカッシュされた畳み込み層です。最終的な出力の形状は4次元テンソルで、$(H,W,C,Caps)$($Caps$:カプセル数)です。また、ルーティング数が1なので、これは単なる畳み込みと変わりません。
以下のように定義します。
def squeeze(s):
sq = K.sum(K.square(s), axis=-1, keepdims=True)
return (sq / (1 + sq)) * (s / K.sqrt(sq + K.epsilon()))
class Conv2DCaps(layers.Layer):
def __init__(self, ch_j, n_j,
kernel_size=(3, 3),
strides=(1, 1),
r_num=1,
b_alphas=[8, 8, 8],
padding='same',
data_format='channels_last',
dilation_rate=(1, 1),
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
kernel_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
**kwargs):
super(Conv2DCaps, self).__init__(**kwargs)
rank = 2
self.ch_j = ch_j # Number of capsules in layer J
self.n_j = n_j # Number of neurons in a capsule in J
self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank, 'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
self.r_num = r_num
self.b_alphas = b_alphas
self.padding = conv_utils.normalize_padding(padding)
#self.data_format = conv_utils.normalize_data_format(data_format)
self.data_format = K.normalize_data_format(data_format)
self.dilation_rate = (1, 1)
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.activity_regularizer = regularizers.get(activity_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.input_spec = InputSpec(ndim=rank + 3)
def build(self, input_shape):
self.h_i, self.w_i, self.ch_i, self.n_i = input_shape[1:5]
self.h_j, self.w_j = [conv_utils.conv_output_length(input_shape[i + 1],
self.kernel_size[i],
padding=self.padding,
stride=self.strides[i],
dilation=self.dilation_rate[i]) for i in (0, 1)]
self.ah_j, self.aw_j = [conv_utils.conv_output_length(input_shape[i + 1],
self.kernel_size[i],
padding=self.padding,
stride=1,
dilation=self.dilation_rate[i]) for i in (0, 1)]
self.w_shape = self.kernel_size + (self.ch_i, self.n_i,
self.ch_j, self.n_j)
self.w = self.add_weight(shape=self.w_shape,
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
self.built = True
def call(self, inputs):
if self.r_num == 1:
# if there is no routing (and this is so when r_num is 1 and all c are equal)
# then this is a common convolution
outputs = K.conv2d(K.reshape(inputs, (-1, self.h_i, self.w_i,
self.ch_i * self.n_i)),
K.reshape(self.w, self.kernel_size +
(self.ch_i * self.n_i, self.ch_j * self.n_j)),
data_format='channels_last',
strides=self.strides,
padding=self.padding,
dilation_rate=self.dilation_rate)
outputs = squeeze(K.reshape(outputs, ((-1, self.h_j, self.w_j,
self.ch_j, self.n_j))))
return outputs
def compute_output_shape(self, input_shape):
return (input_shape[0], self.h_j, self.w_j, self.ch_j, self.n_j)
def get_config(self):
config = {
'ch_j': self.ch_j,
'n_j': self.n_j,
'kernel_size': self.kernel_size,
'strides': self.strides,
'b_alphas': self.b_alphas,
'padding': self.padding,
'data_format': self.data_format,
'dilation_rate': self.dilation_rate,
'kernel_initializer': initializers.serialize(self.kernel_initializer),
'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
'activity_regularizer': regularizers.serialize(self.activity_regularizer),
'kernel_constraint': constraints.serialize(self.kernel_constraint)
}
base_config = super(Conv2DCaps, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
squash()ではなくsqueeze()を使っている理由は、畳み込み層の出力がベクトルではなくテンソル(ベクター)だからだと思います。
CapsCell 1~3では(加算レイヤを除けば)このConvCaps layerのみを使用しています。ただ、スキップ接続という工夫はされています。これは、深層化による勾配消失を減少させることが狙いです。また、低レベルのカプセルを高レベルのカプセルと繋げることもできます。
スキップ接続後は2つのカプセルを加算レイヤで結合します。カプセルはベクトルで表現されるため、チャネル単位の連結は同じカプセルが重複するため使用しませんが、要素単位の加算はバイアスを減らし、ノイズの影響を受けにくくするそうです。
以上のことから、最終的には次のように実装されます。
# CapsCell 1
l = Conv2DCaps(32, 4, kernel_size=(3, 3), strides=(2, 2), r_num=1, b_alphas=[1, 1, 1])(l)
l_skip = Conv2DCaps(32, 4, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = Conv2DCaps(32, 4, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = Conv2DCaps(32, 4, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = layers.Add()([l, l_skip])
# CapsCell 2
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(2, 2), r_num=1, b_alphas=[1, 1, 1])(l)
l_skip = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = layers.Add()([l, l_skip])
# CapsCell 3
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(2, 2), r_num=1, b_alphas=[1, 1, 1])(l)
l_skip = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = layers.Add()([l, l_skip])
l1 = l
ちなみに、MNISTのような情報量の少ない画像に対してはCapsCell1~2のみで十分だと著者は言っています。
dynamic rounting 3D
DeepCapsNetでもdynamic rountingを使います。しかし、3Dバージョンです。以下のようなアルゴリズムが定義されています。
実装は以下の通りです。
def update_routing(votes, biases, logit_shape, num_dims, input_dim, output_dim,
num_routing):
if num_dims == 6:
votes_t_shape = [3, 0, 1, 2, 4, 5]
r_t_shape = [1, 2, 3, 0, 4, 5]
elif num_dims == 4:
votes_t_shape = [3, 0, 1, 2]
r_t_shape = [1, 2, 3, 0]
else:
raise NotImplementedError('Not implemented')
votes_trans = tf.transpose(votes, votes_t_shape)
_, _, _, height, width, caps = votes_trans.get_shape()
def _body(i, logits, activations):
"""Routing while loop."""
# route: [batch, input_dim, output_dim, ...]
a,b,c,d,e = logits.get_shape()
a = logit_shape[0]
b = logit_shape[1]
c = logit_shape[2]
d = logit_shape[3]
e = logit_shape[4]
print(logit_shape)
logit_temp = tf.reshape(logits, [a,b,-1])
route_temp = tf.nn.softmax(logit_temp, dim=-1)
route = tf.reshape(route_temp, [a, b, c, d, e])
preactivate_unrolled = route * votes_trans
preact_trans = tf.transpose(preactivate_unrolled, r_t_shape)
preactivate = tf.reduce_sum(preact_trans, axis=1) + biases
# activation = _squash(preactivate)
activation = squash(preactivate, axis=[-1, -2, -3])
activations = activations.write(i, activation)
act_3d = K.expand_dims(activation, 1)
tile_shape = np.ones(num_dims, dtype=np.int32).tolist()
tile_shape[1] = input_dim
act_replicated = tf.tile(act_3d, tile_shape)
distances = tf.reduce_sum(votes * act_replicated, axis=3)
logits += distances
return (i + 1, logits, activations)
activations = tf.TensorArray(
dtype=tf.float32, size=num_routing, clear_after_read=False)
logits = tf.fill(logit_shape, 0.0)
i = tf.constant(0, dtype=tf.int32)
_, logits, activations = tf.while_loop(
lambda i, logits, activations: i < num_routing,
_body,
loop_vars=[i, logits, activations],
swap_memory=True)
a = K.cast(activations.read(num_routing - 1), dtype='float32')
return K.cast(activations.read(num_routing - 1), dtype='float32')
CapsCell 4
CapsCell 4はCapsCell 1~3とは少し違い、ConvCaps3Dというレイヤを使用します。また、スキップ接続もルーティング数が3です。
ConvCaps3Dについて説明します。ConvCaps layerとの違いは、Conv2DではなくConv3Dを使用していること、それに合わせてsquash3D()を使っていること、またdynamic rounting3Dを3回まわしていることです。
定義は以下の通りです。
class ConvCapsuleLayer3D(layers.Layer):
def __init__(self, kernel_size, num_capsule, num_atoms, strides=1, padding='valid', routings=3,
kernel_initializer='he_normal', **kwargs):
super(ConvCapsuleLayer3D, self).__init__(**kwargs)
self.kernel_size = kernel_size
self.num_capsule = num_capsule
self.num_atoms = num_atoms
self.strides = strides
self.padding = padding
self.routings = routings
self.kernel_initializer = initializers.get(kernel_initializer)
def build(self, input_shape):
assert len(input_shape) == 5, "The input Tensor should have shape=[None, input_height, input_width," \
" input_num_capsule, input_num_atoms]"
self.input_height = input_shape[1]
self.input_width = input_shape[2]
self.input_num_capsule = input_shape[3]
self.input_num_atoms = input_shape[4]
# Transform matrix
self.W = self.add_weight(shape=[self.input_num_atoms, self.kernel_size, self.kernel_size, 1, self.num_capsule * self.num_atoms],
initializer=self.kernel_initializer,
name='W')
self.b = self.add_weight(shape=[self.num_capsule, self.num_atoms, 1, 1],
initializer=initializers.constant(0.1),
name='b')
self.built = True
def call(self, input_tensor, training=None):
input_transposed = tf.transpose(input_tensor, [0, 3, 4, 1, 2])
input_shape = K.shape(input_transposed)
input_tensor_reshaped = K.reshape(input_tensor, [input_shape[0], 1, self.input_num_capsule * self.input_num_atoms, self.input_height, self.input_width])
input_tensor_reshaped.set_shape((None, 1, self.input_num_capsule * self.input_num_atoms, self.input_height, self.input_width))
# conv = Conv3D(input_tensor_reshaped, self.W, (self.strides, self.strides),
# padding=self.padding, data_format='channels_first')
conv = K.conv3d(input_tensor_reshaped, self.W, strides=(self.input_num_atoms, self.strides, self.strides), padding=self.padding, data_format='channels_first')
votes_shape = K.shape(conv)
_, _, _, conv_height, conv_width = conv.get_shape()
conv = tf.transpose(conv, [0, 2, 1, 3, 4])
votes = K.reshape(conv, [input_shape[0], self.input_num_capsule, self.num_capsule, self.num_atoms, votes_shape[3], votes_shape[4]])
votes.set_shape((None, self.input_num_capsule, self.num_capsule, self.num_atoms, conv_height.value, conv_width.value))
logit_shape = K.stack([input_shape[0], self.input_num_capsule, self.num_capsule, votes_shape[3], votes_shape[4]])
biases_replicated = K.tile(self.b, [1, 1, conv_height.value, conv_width.value])
activations = update_routing(
votes=votes,
biases=biases_replicated,
logit_shape=logit_shape,
num_dims=6,
input_dim=self.input_num_capsule,
output_dim=self.num_capsule,
num_routing=self.routings)
a2 = tf.transpose(activations, [0, 3, 4, 1, 2])
return a2
def compute_output_shape(self, input_shape):
space = input_shape[1:-2]
new_space = []
for i in range(len(space)):
new_dim = conv_output_length(space[i], self.kernel_size, padding=self.padding, stride=self.strides, dilation=1)
new_space.append(new_dim)
return (input_shape[0],) + tuple(new_space) + (self.num_capsule, self.num_atoms)
def get_config(self):
config = {
'kernel_size': self.kernel_size,
'num_capsule': self.num_capsule,
'num_atoms': self.num_atoms,
'strides': self.strides,
'padding': self.padding,
'routings': self.routings,
'kernel_initializer': initializers.serialize(self.kernel_initializer)
}
base_config = super(ConvCapsuleLayer3D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
最終的には以下のように実装されます。
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(2, 2), r_num=1, b_alphas=[1, 1, 1])(l)
l_skip = ConvCapsuleLayer3D(kernel_size=3, num_capsule=32, num_atoms=8, strides=1, padding='same', routings=3)(l)
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = Conv2DCaps(32, 8, kernel_size=(3, 3), strides=(1, 1), r_num=1, b_alphas=[1, 1, 1])(l)
l = layers.Add()([l, l_skip])
l2 = l
collect capsules
CapsCell 3とCapsCell 4の出力を結合させます。このステップの狙いは、多様なデータセットに対してモデルを一般化すること、みたいです。
結合する前に平坦化します。そのためにFlattenCaps layerを以下のように定義します。
class FlattenCaps(layers.Layer):
def __init__(self, **kwargs):
super(FlattenCaps, self).__init__(**kwargs)
self.input_spec = InputSpec(min_ndim=4)
def compute_output_shape(self, input_shape):
if not all(input_shape[1:]):
raise ValueError('The shape of the input to "FlattenCaps" '
'is not fully defined '
'(got ' + str(input_shape[1:]) + '. '
'Make sure to pass a complete "input_shape" '
'or "batch_input_shape" argument to the first '
'layer in your model.')
return (input_shape[0], np.prod(input_shape[1:-1]), input_shape[-1])
def call(self, inputs):
shape = K.int_shape(inputs)
return K.reshape(inputs, (-1, np.prod(shape[1:-1]), shape[-1]))
出力の形状は(バッチサイズ, カプセルサイズ, カプセルの個数)になっていると思います。
では、collect capsulesを以下のように実装できます。
la = FlattenCaps()(l2)
lb = FlattenCaps()(l1)
l = layers.Concatenate(axis=-2)([la, lb])
Flat Caps Layer(DigitCaps層)
次はFlat Caps Layerです。この層はCapsNetでいうDigitCaps層に当たると思います。
また、処理についてもCapsNetのPrimaryCaps層⇒DigitCaps層の処理と代わりません。dynamic routingを用います。定義は以下の通りです。
class CapsuleLayer(layers.Layer):
def __init__(self, num_capsule, dim_capsule, channels, routings=3,
kernel_initializer='glorot_uniform',
**kwargs):
super(CapsuleLayer, self).__init__(**kwargs)
self.num_capsule = num_capsule
self.dim_capsule = dim_capsule
self.routings = routings
self.channels = channels
self.kernel_initializer = initializers.get(kernel_initializer)
def build(self, input_shape):
assert len(input_shape) >= 3, "The input Tensor should have shape=[None, input_num_capsule, input_dim_capsule]"
self.input_num_capsule = input_shape[1]
self.input_dim_capsule = input_shape[2]
if(self.channels != 0):
assert int(self.input_num_capsule / self.channels) / (self.input_num_capsule / self.channels) == 1, "error"
self.W = self.add_weight(shape=[self.num_capsule, self.channels,
self.dim_capsule, self.input_dim_capsule],
initializer=self.kernel_initializer,
name='W')
self.B = self.add_weight(shape=[self.num_capsule, self.dim_capsule],
initializer=self.kernel_initializer,
name='B')
else:
self.W = self.add_weight(shape=[self.num_capsule, self.input_num_capsule,
self.dim_capsule, self.input_dim_capsule],
initializer=self.kernel_initializer,
name='W')
self.B = self.add_weight(shape=[self.num_capsule, self.dim_capsule],
initializer=self.kernel_initializer,
name='B')
self.built = True
def call(self, inputs, training=None):
# inputs.shape=[None, input_num_capsule, input_dim_capsule]
# inputs_expand.shape=[None, 1, input_num_capsule, input_dim_capsule]
inputs_expand = K.expand_dims(inputs, 1)
# Replicate num_capsule dimension to prepare being multiplied by W
# inputs_tiled.shape=[None, num_capsule, input_num_capsule, input_dim_capsule]
inputs_tiled = K.tile(inputs_expand, [1, self.num_capsule, 1, 1])
if(self.channels != 0):
W2 = K.repeat_elements(self.W, int(self.input_num_capsule / self.channels), 1)
else:
W2 = self.W
# Compute `inputs * W` by scanning inputs_tiled on dimension 0.
# x.shape=[num_capsule, input_num_capsule, input_dim_capsule]
# W.shape=[num_capsule, input_num_capsule, dim_capsule, input_dim_capsule]
# Regard the first two dimensions as `batch` dimension,
# then matmul: [input_dim_capsule] x [dim_capsule, input_dim_capsule]^T -> [dim_capsule].
# inputs_hat.shape = [None, num_capsule, input_num_capsule, dim_capsule]
inputs_hat = K.map_fn(lambda x: own_batch_dot(x, W2, [2, 3]), elems=inputs_tiled)
# Begin: Routing algorithm ---------------------------------------------------------------------#
# The prior for coupling coefficient, initialized as zeros.
# b.shape = [None, self.num_capsule, self.input_num_capsule].
b = tf.zeros(shape=[K.shape(inputs_hat)[0], self.num_capsule, self.input_num_capsule])
assert self.routings > 0, 'The routings should be > 0.'
for i in range(self.routings):
# c.shape=[batch_size, num_capsule, input_num_capsule]
c = tf.nn.softmax(b, dim=1)
# c.shape = [batch_size, num_capsule, input_num_capsule]
# inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
# The first two dimensions as `batch` dimension,
# then matmal: [input_num_capsule] x [input_num_capsule, dim_capsule] -> [dim_capsule].
# outputs.shape=[None, num_capsule, dim_capsule]
outputs = squash(own_batch_dot(c, inputs_hat, [2, 2]) + self.B) # [None, 10, 16]
if i < self.routings - 1:
# outputs.shape = [None, num_capsule, dim_capsule]
# inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
# The first two dimensions as `batch` dimension,
# then matmal: [dim_capsule] x [input_num_capsule, dim_capsule]^T -> [input_num_capsule].
# b.shape=[batch_size, num_capsule, input_num_capsule]
b += own_batch_dot(outputs, inputs_hat, [2, 3])
# End: Routing algorithm -----------------------------------------------------------------------#
return outputs
def compute_output_shape(self, input_shape):
return tuple([None, self.num_capsule, self.dim_capsule])
実装は以下の通りです。
digits_caps = CapsuleLayer(num_capsule=n_class, dim_capsule=32, routings=routings, channels=0, name='digit_caps')(l)
出力
カプセルはベクトルであり、その大きさは対象オブジェクトの存在確率を表します。そのため、最終出力のためにdigitCaps層の各カプセルを$L2$ノルムで[0, 1]の値にします。これが、各クラスに対する存在確率となります。
class CapsToScalars(layers.Layer):
def __init__(self, **kwargs):
super(CapsToScalars, self).__init__(**kwargs)
self.input_spec = InputSpec(min_ndim=3)
def compute_output_shape(self, input_shape):
return (input_shape[0], input_shape[1])
def call(self, inputs):
return K.sqrt(K.sum(K.square(inputs + K.epsilon()), axis=-1))
l = CapsToScalars(name='capsnet')(digits_caps)
Loss
CapsNetではMargin Lossと呼ばれる損失関数を使います。以下のような式です。
$$
L_k=T_k~max(0,m^+-|v_k|)^2+\lambda(1-T_k)~max(0,|v_k|-m^-)^2
$$
$T_k$はクラス$k$の存在確率($0~or~1$)を表します。また、原論文では$T_k=0$の時の損失の軽減率$\lambda$は$0.5$とし、$m^+=0.9,~m^-=0.1$としています。
def margin_loss(y_true, y_pred):
"""
Margin loss for Eq.(4). When y_true[i, :] contains not just one `1`, this loss should work too. Not test it.
:param y_true: [None, n_classes]
:param y_pred: [None, num_capsule]
:return: a scalar loss value.
"""
L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))
return K.mean(K.sum(L, 1))
まとめ
kerasによるDeepCapsNetの実装の説明は以上になります。
本当はこの後Decoder部分もあります。簡単に説明しますと、DigitCaps層は各クラスの空間情報を持っていることから、その情報から各クラスの再構成します。その情報の取り出し方については、他のクラスをMaskする方法が原論文では提案されています。DeepCapsNetの提案論文ではそのクラス部分のみを取り出しています。
Decoder部分については割愛します。
参考