概要
CapsNetを用いたGANであるCapsGANをKerasで構築して学習させてみます。ここで構築するCapsGANは、Discriminator部分のみをCapsNetで構築したものです。
CapsNetについては私が以前書いた記事をお読みください。
構築
早速、CapsGANを構築していきます。
CapsNet
Discriminatorに用いるCapsNetを定義します。当記事では詳しく説明しません。
class Length(layers.Layer):
    """
    Replace every capsule vector by its Euclidean length.

    The result has the same shape as ``y_true`` in ``margin_loss``, so this layer can be used
    directly as a model output; class predictions are then ``np.argmax(model.predict(x), 1)``.

    inputs: shape=[None, num_vectors, dim_vector]
    output: shape=[None, num_vectors]
    """

    def call(self, inputs, **kwargs):
        # Sum of squares over the last (vector) axis; K.epsilon() keeps sqrt finite-gradient at 0.
        squared_length = K.sum(K.square(inputs), -1)
        return K.sqrt(squared_length + K.epsilon())

    def compute_output_shape(self, input_shape):
        # Only the trailing vector dimension is reduced away.
        return input_shape[:-1]

    def get_config(self):
        # No extra constructor arguments beyond the base Layer config.
        return super(Length, self).get_config()
class Mask(layers.Layer):
    """
    Mask a Tensor with shape=[None, num_capsule, dim_vector] either by the capsule with max length or by an
    additional input mask. Except the max-length capsule (or specified capsule), all vectors are masked to zeros.
    Then flatten the masked Tensor.

    For example:
        ```
        x = keras.layers.Input(shape=[3, 2])  # each sample contains 3 capsules with dim_vector=2
        y = keras.layers.Input(shape=[3])     # true labels: 3 classes, one-hot coding
        out = Mask()(x)         # out.shape=[None, 6]
        # or
        out2 = Mask()([x, y])   # out2.shape=[None, 6]. Masked with true labels y (y may also be manipulated).
        ```

    The two-input form accepts either a list or a tuple ``[capsules, mask]``.
    """

    def call(self, inputs, **kwargs):
        if isinstance(inputs, (list, tuple)):
            # An explicit mask is provided, e.g. one-hot true labels with shape=[None, n_classes].
            assert len(inputs) == 2
            inputs, mask = inputs
        else:
            # No explicit mask: keep only the longest capsule. Mainly used for prediction.
            # lengths.shape=[None, num_capsule]
            lengths = K.sqrt(K.sum(K.square(inputs), -1))
            # One-hot mask selecting the longest capsule; mask.shape=[None, num_capsule]
            mask = K.one_hot(indices=K.argmax(lengths, 1), num_classes=lengths.get_shape().as_list()[1])
        # inputs.shape=[None, num_capsule, dim_capsule]
        # mask.shape=[None, num_capsule] -> broadcast over dim_capsule via expand_dims
        # masked.shape=[None, num_capsule * dim_capsule]
        masked = K.batch_flatten(inputs * K.expand_dims(mask, -1))
        return masked

    def compute_output_shape(self, input_shape):
        # With two inputs, input_shape is a sequence of per-input shapes; the first is the capsules.
        # With one input, input_shape[0] is the batch dimension (None or an int).
        if isinstance(input_shape[0], (tuple, list)):
            return tuple([None, input_shape[0][1] * input_shape[0][2]])
        return tuple([None, input_shape[1] * input_shape[2]])

    def get_config(self):
        config = super(Mask, self).get_config()
        return config
def squash(vectors, axis=-1):
    """
    Capsule non-linearity: v = (|s|^2 / (1 + |s|^2)) * (s / |s|).

    Long vectors are shrunk to a length just below 1 and short vectors to a length near 0,
    while the direction is preserved.

    :param vectors: some vectors to be squashed, N-dim tensor
    :param axis: the axis along which each vector lies
    :return: a Tensor with the same shape as ``vectors``
    """
    squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
    # epsilon avoids division by zero for all-zero vectors
    norm = K.sqrt(squared_norm + K.epsilon())
    return vectors * (squared_norm / (1 + squared_norm) / norm)
class CapsuleLayer(layers.Layer):
    """
    The capsule layer. It is similar to a Dense layer. A Dense layer has `in_num` scalar inputs (the outputs of the
    neurons in the former layer) and `out_num` output neurons. CapsuleLayer expands each neuron's output from a
    scalar to a vector, so its input shape = [None, input_num_capsule, input_dim_capsule] and its output shape =
    [None, num_capsule, dim_capsule]. For a Dense layer, input_dim_capsule = dim_capsule = 1.

    Output capsules are computed with `routings` iterations of routing-by-agreement (see `call`).

    :param num_capsule: number of capsules in this layer
    :param dim_capsule: dimension of the output vectors of the capsules in this layer
    :param routings: number of iterations for the routing algorithm (must be > 0)
    :param kernel_initializer: initializer for the transform matrix W
    """

    def __init__(self, num_capsule, dim_capsule, routings=3,
                 kernel_initializer='glorot_uniform',
                 **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.kernel_initializer = initializers.get(kernel_initializer)

    def build(self, input_shape):
        assert len(input_shape) >= 3, "The input Tensor should have shape=[None, input_num_capsule, input_dim_capsule]"
        self.input_num_capsule = input_shape[1]
        self.input_dim_capsule = input_shape[2]
        # Transform matrix: one [dim_capsule, input_dim_capsule] matrix per (output capsule, input capsule) pair.
        self.W = self.add_weight(shape=[self.num_capsule, self.input_num_capsule,
                                        self.dim_capsule, self.input_dim_capsule],
                                 initializer=self.kernel_initializer,
                                 name='W')
        self.built = True

    def call(self, inputs, training=None):
        # inputs.shape=[None, input_num_capsule, input_dim_capsule]
        # inputs_expand.shape=[None, 1, input_num_capsule, input_dim_capsule]
        inputs_expand = K.expand_dims(inputs, 1)

        # Replicate num_capsule dimension to prepare being multiplied by W
        # inputs_tiled.shape=[None, num_capsule, input_num_capsule, input_dim_capsule]
        inputs_tiled = K.tile(inputs_expand, [1, self.num_capsule, 1, 1])

        # Compute `inputs * W` by scanning inputs_tiled on dimension 0.
        # x.shape=[num_capsule, input_num_capsule, input_dim_capsule]
        # W.shape=[num_capsule, input_num_capsule, dim_capsule, input_dim_capsule]
        # Regard the first two dimensions as `batch` dimension,
        # then matmul: [input_dim_capsule] x [dim_capsule, input_dim_capsule]^T -> [dim_capsule].
        # inputs_hat.shape = [None, num_capsule, input_num_capsule, dim_capsule]
        # NOTE(review): K.batch_dot axis semantics changed in TF 2.x; verify these K.batch_dot
        # calls against the installed Keras/TF version.
        inputs_hat = K.map_fn(lambda x: K.batch_dot(x, self.W, [2, 3]), elems=inputs_tiled)

        # Begin: Routing algorithm ---------------------------------------------------------------------#
        # The prior for coupling coefficient, initialized as zeros.
        # b.shape = [None, self.num_capsule, self.input_num_capsule].
        b = tf.zeros(shape=[K.shape(inputs_hat)[0], self.num_capsule, self.input_num_capsule])

        assert self.routings > 0, 'The routings should be > 0.'
        for i in range(self.routings):
            # Coupling coefficients: softmax over the output-capsule axis.
            # `axis=` replaces the deprecated `dim=` keyword, which TF 2.x no longer accepts.
            # c.shape=[batch_size, num_capsule, input_num_capsule]
            c = tf.nn.softmax(b, axis=1)

            # c.shape = [batch_size, num_capsule, input_num_capsule]
            # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
            # The first two dimensions as `batch` dimension,
            # then matmul: [input_num_capsule] x [input_num_capsule, dim_capsule] -> [dim_capsule].
            # outputs.shape=[None, num_capsule, dim_capsule]
            outputs = squash(K.batch_dot(c, inputs_hat, [2, 2]))

            if i < self.routings - 1:
                # Agreement update: dot product between each output capsule and each prediction.
                # outputs.shape = [None, num_capsule, dim_capsule]
                # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
                # The first two dimensions as `batch` dimension,
                # then matmul: [dim_capsule] x [input_num_capsule, dim_capsule]^T -> [input_num_capsule].
                # b.shape=[batch_size, num_capsule, input_num_capsule]
                b += K.batch_dot(outputs, inputs_hat, [2, 3])
        # End: Routing algorithm -----------------------------------------------------------------------#

        return outputs

    def compute_output_shape(self, input_shape):
        return tuple([None, self.num_capsule, self.dim_capsule])

    def get_config(self):
        # Include every constructor argument so that from_config() can rebuild an identical layer;
        # the initializer must be serialized because the stored attribute is an object.
        config = {
            'num_capsule': self.num_capsule,
            'dim_capsule': self.dim_capsule,
            'routings': self.routings,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
        }
        base_config = super(CapsuleLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
def PrimaryCap(inputs, dim_capsule, n_channels, kernel_size, strides, padding):
    """
    Build the primary-capsule layer.

    A single Conv2D with ``dim_capsule * n_channels`` filters is applied. Its feature map is reshaped
    into vectors of length ``dim_capsule`` and passed through ``squash``.

    :param inputs: 4D tensor, shape=[None, width, height, channels]
    :param dim_capsule: the dim of the output vector of capsule
    :param n_channels: the number of types of capsules
    :param kernel_size: Conv2D kernel size
    :param strides: Conv2D strides
    :param padding: Conv2D padding mode
    :return: output tensor, shape=[None, num_capsule, dim_capsule]
    """
    conv = layers.Conv2D(filters=dim_capsule * n_channels,
                         kernel_size=kernel_size,
                         strides=strides,
                         padding=padding,
                         name='primarycap_conv2d')
    to_capsules = layers.Reshape(target_shape=[-1, dim_capsule], name='primarycap_reshape')
    squash_layer = layers.Lambda(squash, name='primarycap_squash')
    return squash_layer(to_capsules(conv(inputs)))
Discriminator
Discriminator(識別器)は、入力された画像が学習画像と生成画像のどちらであるかを識別することを目的としたDNNです。DCGANとは異なり、CNNではなくCapsNetで構成されています。CapsNetはクラスごとに1つのカプセル(ベクトル)を出力し、その長さ(ノルム)は[0, 1)の範囲に収まります。ここではクラス数を1とし、このカプセルの長さが0に近いほど偽物(生成画像)、1に近いほど本物(学習画像)であるとします。
以下の画像はMNISTに対するDiscriminatorの概略図です。
# Discriminator architecture
def build_discriminator(shape, n_class, routings):
    """
    Build the CapsNet discriminator.

    :param shape: input image shape passed to ``Input`` (without the batch dimension)
    :param n_class: number of output capsules (classes)
    :param routings: number of routing iterations in the class-capsule layer
    :return: an uncompiled Model mapping images to per-class capsule lengths, shape=[None, n_class]
    """
    img_in = Input(shape=shape)
    # Layer 1: conventional convolution, 256 filters of 9x9 with ReLU.
    features = Conv2D(filters=256, kernel_size=9, strides=1, padding='valid',
                      activation='relu', name='conv1')(img_in)
    # Layer 2: primary capsules (Conv2D + reshape + squash) -> [None, num_capsule, 8].
    primary = PrimaryCap(features, dim_capsule=8, n_channels=32, kernel_size=9,
                         strides=2, padding='valid')
    # Layer 3: class capsules; dynamic routing happens here.
    class_caps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings,
                              name='digitcaps')(primary)
    # Layer 4: replace each capsule by its length so the output matches the label shape.
    # The output layer is named 'capsnet', which is the key used in compile(metrics=...).
    lengths = Length(name='capsnet')(class_caps)
    return models.Model([img_in], [lengths])
# Margin Loss definition
def margin_loss(y_true, y_pred):
    """
    Margin loss, Eq. (4) of the CapsNet paper.

    Present classes are penalised when their length falls below 0.9. Absent classes are penalised,
    with weight 0.5, when their length exceeds 0.1. Should also work when y_true[i, :] has several
    `1`s (untested).

    :param y_true: [None, n_classes]
    :param y_pred: [None, num_capsule]
    :return: a scalar loss value (sum over classes, mean over the batch).
    """
    m_plus, m_minus, lam = 0.9, 0.1, 0.5
    present_term = y_true * K.square(K.maximum(0., m_plus - y_pred))
    absent_term = lam * (1 - y_true) * K.square(K.maximum(0., y_pred - m_minus))
    return K.mean(K.sum(present_term + absent_term, 1))
# Build and compile the Discriminator.
# `shape` is a module-level global defined outside this excerpt (presumably the training-image shape).
# One output capsule (n_class=1) and 3 routing iterations: a length near 1 means "real", near 0 "generated".
discriminator = build_discriminator(shape, 1, 3)
# Adam positional args: learning rate 1e-4, beta_1 = 0.3.
# The metric key 'capsnet' is the name of the model's Length output layer.
discriminator.compile(optimizer=Adam(0.0001, 0.3),
                      loss=[margin_loss],
                      loss_weights=[1.],
                      metrics={'capsnet': 'accuracy'})
Generator
Generatorは学習画像によく似た画像を生成することを目的としています。Generatorの構造は通常のGAN(DCGAN)のものと変わりません。
def build_generator(z_size=100):
    """
    Build the DCGAN-style generator that maps a noise vector to an image.

    The output resolution is chosen from the module-level ``shape`` global:
    28x28 -> 1-channel output (e.g. MNIST), 32x32 -> 3-channel output (e.g. CIFAR-10).

    :param z_size: dimensionality of the input noise vector.
    :return: an uncompiled Model mapping noise of shape (None, z_size) to images in [-1, 1] (tanh).
    :raises ValueError: if ``shape`` is neither 28x28 nor 32x32.
    """
    # (height, width) -> (side length before the two 2x upsamplings, number of output channels)
    configs = {(28, 28): (7, 1), (32, 32): (8, 3)}
    key = (shape[0], shape[1])
    if key not in configs:
        raise ValueError(f'Unsupported image shape {shape}; expected 28x28 or 32x32.')
    base, channels = configs[key]

    x_noise = Input(shape=(z_size,))
    x = Dense(128 * base * base, activation="relu")(x_noise)
    x = Reshape((base, base, 128))(x)
    x = BatchNormalization(momentum=0.8)(x)
    # Two upsample + conv stages: base -> 2*base -> 4*base (7 -> 28 or 8 -> 32).
    for filters in (128, 64):
        x = UpSampling2D()(x)
        x = Conv2D(filters, kernel_size=3, padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization(momentum=0.8)(x)
    x = Conv2D(channels, kernel_size=3, padding="same")(x)
    gen_out = Activation("tanh")(x)
    return Model(x_noise, gen_out)
# Build and compile the Generator.
# Size of the latent noise vector; also read as a global by the sampling/training code.
z_size = 128
generator = build_generator(z_size)
# NOTE(review): in the visible code the generator is only used via predict() and inside `combined`,
# so this compile's loss/optimizer does not appear to drive its training; confirm before relying on it.
generator.compile(loss='binary_crossentropy', optimizer=Adam(0.0003, 0.5))
GAN
上で作成したDiscriminatorとGeneratorを繋げて、GANを作成します。
# Stack generator -> discriminator into the `combined` model used to train the generator.
z = Input(shape=(z_size,))
img = generator(z)
# Freeze the discriminator for `combined` only. `discriminator` was compiled earlier while trainable,
# so (per standard Keras behaviour, where `trainable` is captured at compile time; verify for the
# installed version) its own train_on_batch calls still update its weights.
discriminator.trainable = False
valid = discriminator(img)
combined = Model(z, valid)
# NOTE(review): the discriminator is trained with margin_loss, but the generator objective here is
# binary cross-entropy on the same capsule-length output; confirm this mix is intentional.
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0003, 0.5))
学習
それでは、上で作成したモデルを学習させていきます。
その前に、Generatorが作成した画像を描画して表示する関数を作ります。
def show_imgs(epoch):
    """
    Plot a 5x5 grid of images sampled from the generator.

    Relies on the module-level globals ``generator``, ``z_size`` and ``dataset_title``.
    For 'mnist' / 'f_mnist', channel 0 is shown with a gray colormap. For 'cifar10', RGB images are shown.
    For any other ``dataset_title``, a message is printed once and nothing is plotted.

    :param epoch: training step number, shown in the figure title.
    """
    r, c = 5, 5
    if dataset_title in ('mnist', 'f_mnist'):
        grayscale = True
    elif dataset_title == 'cifar10':
        grayscale = False
    else:
        # Report once rather than once per grid cell, and skip drawing an empty figure.
        print('Please indicate the image options.')
        return

    noise = np.random.normal(0, 1, (r * c, z_size))
    gen_imgs = generator.predict(noise)
    # Generator output is tanh in [-1, 1]; rescale images to 0 - 1 for imshow.
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    # axs is an r x c array; .flat walks it row by row, matching the sample order.
    for cnt, ax in enumerate(axs.flat):
        if grayscale:
            ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
        else:
            ax.imshow(gen_imgs[cnt, :, :, :])
        ax.axis('off')
    plt.suptitle(f'epoch: {epoch}')
    plt.show()
    plt.close(fig)
次の関数が、学習するための関数になります。
# Per-iteration loss histories, appended to by train() and kept for plotting learning curves.
# Five distinct empty lists: real-batch D loss, fake-batch D loss, averaged D loss, D accuracy, G loss.
D_L_REAL, D_L_FAKE, D_L, D_ACC, G_L = [], [], [], [], []
# Training function
def train(epochs, batch_size=32, show_interval=50, seed=42):
    """
    Train the CapsGAN.

    Each iteration runs one discriminator update (one real batch and one generated batch),
    followed by one generator update through ``combined``.

    Uses the module-level globals ``X_train``, ``generator``, ``discriminator``, ``combined``,
    ``z_size`` and ``show_imgs``. Appends per-iteration values to the module-level lists
    ``D_L_REAL``, ``D_L_FAKE``, ``D_L``, ``D_ACC`` and ``G_L``.

    :param epochs: number of training iterations. Each iteration draws one random mini-batch,
        so this is not a full pass over the dataset.
    :param batch_size: generator batch size. The discriminator uses about a tenth of it (at least 1).
    :param show_interval: plot generated samples every ``show_interval`` iterations.
    :param seed: seed for ``np.random``.
    :return: dict of the loss-history lists (the same list objects as the module-level globals).
    """
    np.random.seed(seed)

    # The discriminator converges faster than the generator, and the two should stay competitive.
    # As a handicap the discriminator is trained on a smaller batch; max(1, ...) prevents
    # empty batches when batch_size < 10.
    small_batch = max(1, int(batch_size / 10))

    # Labels are the same every iteration, so build them once.
    d_real_labels = np.ones((small_batch, 1))   # real images -> 1
    d_fake_labels = np.zeros((small_batch, 1))  # generated images -> 0
    g_labels = np.ones((batch_size, 1))         # generator wants its samples judged real

    for epoch in range(epochs):
        # ---------------------
        # Train Discriminator
        # ---------------------
        # Select a random small batch of real images and generate an equally sized fake batch.
        idx = np.random.randint(0, X_train.shape[0], small_batch)
        imgs = X_train[idx]
        noise = np.random.normal(0, 1, (small_batch, z_size))
        gen_imgs = generator.predict(noise)

        # Train on real and generated images in separate calls; each returns [loss, accuracy].
        d_loss_real = discriminator.train_on_batch(imgs, d_real_labels)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, d_fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        # Train Generator
        # ---------------------
        # The generator wants the (frozen) discriminator to label the generated samples as valid (ones).
        noise = np.random.normal(0, 1, (batch_size, z_size))
        g_loss = combined.train_on_batch(noise, g_labels)

        # Plot the progress
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
        D_L_REAL.append(d_loss_real)
        D_L_FAKE.append(d_loss_fake)
        D_L.append(d_loss)
        D_ACC.append(d_loss[1])
        G_L.append(g_loss)

        # If at show interval, display generated image samples.
        if epoch % show_interval == 0:
            show_imgs(epoch)

    return {
        'd_loss_real': D_L_REAL,
        'd_loss_fake': D_L_FAKE,
        'd_loss': D_L,
        'd_acc': D_ACC,
        'g_loss': G_L,
    }
学習の実行は次のようにして行えます。
# Run training: 10,000 iterations with generator batch size 1024 (the discriminator uses about a
# tenth of that), plotting samples every 500 iterations.
# NOTE(review): check what train() returns before relying on `history`; the per-iteration losses
# are also accumulated in the module-level lists D_L_REAL, D_L_FAKE, D_L, D_ACC and G_L.
history = train(epochs=10000, batch_size=1024, show_interval=500)
結果
生成画像
学習が終わったGeneratorが生成した画像が次のものになります。
学習曲線
ソースコード