#1.はじめに
styleGAN2をkerasで実装しました。高解像・高品質な画像が生成できることで有名なモデルです。すべてkerasのLayerとして実装しました。そのほうがモデルの構築や再利用時に楽ですので。
論文リンク
Analyzing and Improving the Image Quality of StyleGAN
使用ライブラリ
・Tensorflow==2.2.0
・numpy
#2.詳細
コードはここに掲載しますが、近いうちにgithubにも上げようと思います。長いので一部折りたたんであります。
###1. Wscale(EQUALIZED LEARNING RATE)
これはProgressive Growing of GANs for Improved Quality, Stability, and Variationで提案されている手法で、上記論文の4.1で述べられています。
重みの初期化方法についてで、初期値は$N(0, 1)$の標準正規分布で初期化し、計算時に定数倍しています。この定数は以下のように計算されます(He's initializer)。
c = \frac{gain}{\sqrt{fan\_in}}
$fan\_in$は入力ユニット数で$gain$は1.0を使っています。ただ、Heの初期化の標準偏差は$c=\sqrt{\frac{2}{fan\_in}}$なので$gain=\sqrt{2}$にしないと一致しませんが、著者実装を見ると1.0を使っているのでそうしてあります($\sqrt{2}$はどこへ?)。この操作によって、異なるダイナミックレンジを持つパラメータもすべて同じスピードで学習することが可能になります。これはstyleGAN2でもすべてのweightの初期化に使用されています。
コード
class BaseWscale(tf.keras.layers.Layer):
def __init__(self, gain=1.0, lrmul=1.0, trainable=True, name=None, dtype=None, dynamic=False, **kwargs):
super(BaseWscale, self).__init__(trainable=trainable, name=name, dtype=dtype, dynamic=dynamic, **kwargs)
self.gain = gain
self.lrmul = lrmul
def _get_wscale_weight(self, kernel_shape):
fan_in = np.prod(kernel_shape[:-1])
he_std = self.gain / np.sqrt(fan_in)
init_std = 1.0 / self.lrmul
multiplier = he_std * self.lrmul
return tf.keras.initializers.RandomNormal(stddev=init_std), multiplier
@abstractmethod
def _get_kernel(self):
pass
def get_config(self):
base_config = super(BaseWscale, self).get_config()
config = dict(
gain=self.gain,
lrmul=self.lrmul,
)
return dict(list(base_config.items()) + list(config.items()))
class WscaleConv2D(BaseWscale):
def __init__(self, k_size, filters, use_bias=True, gain=1.0, lrmul=1.0, **kwargs):
super(WscaleConv2D, self).__init__(gain=gain, lrmul=lrmul, **kwargs)
self.k_size = k_size
self.filters = filters
self.use_bias = use_bias
def build(self, input_shape):
k_shape = [self.k_size, self.k_size, input_shape[-1], self.filters]
initializer, multiplier = self._get_wscale_weight(k_shape)
self.multiplier = self.add_weight(name='multiplier',
shape=[],
dtype=tf.float32,
initializer=tf.keras.initializers.Constant(multiplier),
trainable=False,
aggregation=tf.VariableAggregation.MEAN)
self.kernel = self.add_weight(name='kernel',
shape=k_shape,
dtype=tf.float32,
initializer=initializer,
trainable=True,
aggregation=tf.VariableAggregation.MEAN)
if self.use_bias:
self.bias = self.add_weight(name='bias',
shape=[self.filters,],
dtype=tf.float32,
initializer=tf.keras.initializers.Zeros(),
trainable=True,
aggregation=tf.VariableAggregation.MEAN)
super(WscaleConv2D, self).build(input_shape)
def _get_kernel(self):
return self.kernel * self.multiplier
def call(self, inputs, **kwargs):
conv_kernel = self._get_kernel()
x = tf.nn.conv2d(inputs, conv_kernel, [1, 1, 1, 1], padding='SAME')
if self.use_bias:
x = tf.nn.bias_add(x, self.bias * self.lrmul)
return x
def get_config(self):
base_config = super(WscaleConv2D, self).get_config()
config = dict(
k_size=self.k_size,
filters=self.filters,
use_bias=self.use_bias,
)
return dict(list(base_config.items()) + list(config.items()))
class WscaleDense(BaseWscale):
def __init__(self, units, use_bias=True, gain=1.0, lrmul=1.0, **kwargs):
super(WscaleDense, self).__init__(gain=gain, lrmul=lrmul, **kwargs)
self.units = units
self.use_bias = use_bias
def build(self, input_shape):
k_shape = [input_shape[-1], self.units]
initializer, multiplier = self._get_wscale_weight(k_shape)
self.multiplier = self.add_weight(name='multiplier',
shape=[],
dtype=tf.float32,
initializer=tf.keras.initializers.Constant(multiplier),
trainable=False,
aggregation=tf.VariableAggregation.MEAN)
self.kernel = self.add_weight(name='kernel',
shape=k_shape,
dtype=tf.float32,
initializer=initializer,
trainable=True,
aggregation=tf.VariableAggregation.MEAN)
if self.use_bias:
self.bias = self.add_weight(name='bias',
shape=[self.units,],
dtype=tf.float32,
initializer=tf.keras.initializers.Zeros(),
trainable=True,
aggregation=tf.VariableAggregation.MEAN)
super(WscaleDense, self).build(input_shape)
def _get_kernel(self):
return self.kernel * self.multiplier
def call(self, inputs, **kwargs):
kernel = self._get_kernel()
x = tf.matmul(inputs, kernel)
if self.use_bias:
x = tf.nn.bias_add(x, self.bias * self.lrmul)
return x
def get_config(self):
base_config = super(WscaleDense, self).get_config()
config = dict(
units=self.units,
use_bias=self.use_bias,
)
return dict(list(base_config.items()) + list(config.items()))
class ModulatedConv2D(WscaleConv2D):
def __init__(self, k_size, filters, use_bias=True, demod=True, noise_addition=True, random_noise=True, gain=1.0, lrmul=1.0, **kwargs):
super(ModulatedConv2D, self).__init__(k_size, filters, use_bias=use_bias, gain=gain, lrmul=lrmul, **kwargs)
self.demod = demod
self.noise_addition = noise_addition
self.random_noise = random_noise
def build(self, input_shape):
self.style_projecter = WscaleDense(input_shape[0][-1])
if self.noise_addition:
self.noise_weight = self.add_weight(name='noise_weight',
shape=[],
dtype=tf.float32,
initializer=tf.keras.initializers.Zeros(),
trainable=True,
aggregation=tf.VariableAggregation.MEAN)
self.noise = self.add_weight(name='noise',
shape=self._get_noise_shape(input_shape[0]),
dtype=tf.float32,
initializer=tf.keras.initializers.RandomNormal(stddev=1.0),
trainable=False,
aggregation=tf.VariableAggregation.MEAN)
super(ModulatedConv2D, self).build(input_shape[0])
def _get_noise_shape(self, input_shape):
return [1, input_shape[1], input_shape[2], 1]
def _get_kernel(self, style, training=None):
kernel = super(ModulatedConv2D, self)._get_kernel()
conv_kernel = kernel[np.newaxis] * style[:, np.newaxis, np.newaxis, :, np.newaxis]
d = tf.math.rsqrt(tf.reduce_sum(tf.square(conv_kernel), axis=[1, 2, 3]) + 1e-8)
if training:
if self.demod:
conv_kernel *= d[:, np.newaxis, np.newaxis, np.newaxis, :]
s = self.kernel.shape
kernel = tf.reshape(tf.transpose(conv_kernel, [1, 2, 3, 0, 4]), [s[0], s[1], s[2], -1])
return kernel, d
def call(self, inputs, training=None, **kwargs):
x, y = inputs[0], inputs[1]
style = self.style_projecter(y)
kernel, d = self._get_kernel(style=style, training=training)
if training:
x = tf.transpose(x, [1, 2, 0, 3])
x = tf.reshape(x, [1, x.shape[0], x.shape[1], -1])
x = tf.nn.conv2d(x, kernel, strides=1, padding='SAME')
x = tf.transpose(tf.reshape(x, [x.shape[1], x.shape[2], -1, self.filters]), [2, 0, 1, 3])
else:
x = x * style[:, np.newaxis, np.newaxis, :]
x = tf.nn.conv2d(x, kernel, strides=1, padding='SAME')
if self.demod:
x = x * d[:, np.newaxis, np.newaxis, :]
if self.noise_addition:
if self.random_noise:
noise = tf.random.normal([tf.shape(x)[0], x.shape[1], x.shape[2], 1])
else:
noise = self.noise
x = x + noise*self.noise_weight
if self.use_bias:
x = tf.nn.bias_add(x, self.bias * self.lrmul)
return x
def get_config(self):
base_config = super(ModulatedConv2D, self).get_config()
config = dict(
demod=self.demod,
noise_addition=self.noise_addition,
random_noise=self.random_noise,
)
return dict(list(base_config.items()) + list(config.items()))
styleGAN2では、weightに対してstyleを適用するためバッチごとに異なる重みを用いて畳み込みを行わなければなりません。こんなんどうやって実装するんだ?と思って著者実装を見ると、Group convolutionを使っていました(よく読んだら論文の4ページに書いてあった)。どうやら
入力のチャンネル数 = カーネルの入力チャンネル数 * n
の時にn個のグループに分割して処理してくれる機能のようです(pytorchではnn.Conv2dのgroup引数で指定できるっぽい?)。ただこれはcuDNNの機能を使っているらしく(参考:Add support for cudnn's group convolution.)、GPU環境でないとエラーで動きません。学習に関してはGPU(今回はcolabのTPUを用いて行った)を使用すればよいと思いますが、推論はCPU環境でも行えるようにしたかったので、2種類の処理を実装しました。以下の図のcとdの処理です。学習時はdを、推論時はcを使うようにしてあります(callのtraining引数でフラグ管理)。両方の処理で同じ出力を得られることは確認してあります。
###3.その他
####1.Upsample, Downsample
非公式の実装を見ているとシンプルにbilinearで画像の拡大処理を行っているものもあり、最初は自分もそのように実装していましたが、tf.image.resizeはTPUでのバックプロパゲーションに対応していないらしく(tensorflow 2.2.0時点)、
Lookup Error : Gradient Registry has no entry for : ~
のエラーが発生しました。なので、自分で実装しました。Upsampleは転置畳み込みを行ったのち、ブラーをかけています。Downsampleはブラーをかけたのち、ストライド付きの畳み込みを行っています。この処理は著者実装と全く同じ処理になっているはずです。
コード
class UpFIR2d(tf.keras.layers.Layer):
def __init__(self, scale=2, k=None, gain=1.0, up=False, down=False, conv=False, conv_k_size=None, trainable=True, name=None, dtype=None, dynamic=False, **kwargs):
super(UpFIR2d, self).__init__(trainable=trainable, name=name, dtype=dtype, dynamic=dynamic, **kwargs)
self.scale = scale
self.gain = gain
self.up = up
self.down = down
assert not(self.up == self.down)
self.conv = conv
self.conv_k_size = conv_k_size
if k is None:
self.k = (1,) * self.scale
else:
self.k = tuple(k)
f = self._get_filter()
self.filter = self.add_weight(name='resample_kernel',
shape=f.shape,
dtype=tf.float32,
initializer=tf.keras.initializers.Constant(f),
trainable=False,
aggregation=tf.VariableAggregation.MEAN)
def _get_filter(self):
k = np.asarray(self.k, dtype=np.float32)
k = np.outer(k, k)
k /= np.sum(k)
if self.up:
k = k * (self.gain * (self.scale ** 2))
elif self.down:
k = k * self.gain
return k[:,:,np.newaxis, np.newaxis]
def _get_params(self):
p = self.filter.shape[0] - self.scale
if self.conv:
if self.up:
p -= (self.conv_k_size - 1)
elif self.down:
p += (self.conv_k_size - 1)
if self.up:
up = 1 if self.conv else self.scale
down = 1
p0 = (p+1)//2+self.scale-1
p1 = p//2 + 1 if self.conv else p//2
elif self.down:
up = 1
down = 1 if self.conv else self.scale
p0 = (p+1)//2
p1 = p//2
return dict(up=up, down=down, p0=p0, p1=p1)
def _upfirdn2d_op(self, x, upx, upy, downx, downy, px0, px1, py0, py1):
xs = tf.shape(x)
#x = tf.reshape(x, [-1, x.shape[1], 1, x.shape[2], 1, x.shape[3]])
x = tf.reshape(x, [-1, xs[1], 1, xs[2], 1, xs[3]])
x = tf.pad(x, [[0, 0], [0, 0], [0, upy-1], [0, 0], [0, upx-1], [0, 0]])
x = tf.reshape(x, [-1, xs[1]*upy, xs[2]*upx, xs[3]])
x = tf.pad(x, [[0, 0], [max(py0, 0), max(py1, 0)], [max(px0, 0), max(px1, 0)], [0, 0]])
x = x[:, max(-py0, 0):tf.shape(x)[1] - max(-py1, 0), max(-px0, 0):tf.shape(x)[2] - max(-px1, 0), :]
x = tf.nn.depthwise_conv2d(x, tf.tile(self.filter, [1, 1, tf.shape(x)[-1], 1]), strides=[1, 1, 1, 1], padding='VALID')
return x[:, ::downy, ::downx, :]
def _upfirdn2d(self, x, up=1, down=1, p0=0, p1=0):
x = self._upfirdn2d_op(x, upx=up, upy=up, downx=down, downy=down, px0=p0, px1=p1, py0=p0, py1=p1)
return x
def call(self, inputs, **kwargs):
params = self._get_params()
x = self._upfirdn2d(inputs, **params)
return x
def get_config(self):
basse_config = super(UpFIR2d, self).get_config()
config = dict(
scale=self.scale,
gain=self.gain,
up=self.up,
down=self.down,
conv=self.conv,
conv_k_size=self.conv_k_size,
k=self.k,
)
return dict(list(basse_config.items()) + list(config.items()))
class UpsampleConv(ModulatedConv2D):
def __init__(self, k_size, filters, scale=2, k=None, use_bias=True, demod=True, gain=1.0, lrmul=1.0, **kwargs):
super(UpsampleConv, self).__init__(k_size, filters, use_bias=use_bias, demod=demod, gain=gain, lrmul=lrmul, **kwargs)
self.scale = scale
self.k = k
self.fir = UpFIR2d(scale, k, gain, up=True, conv=True, conv_k_size=self.k_size)
def _get_convolution_params(self, inputs, kernel):
input_shape = tf.shape(inputs)
k_shape = kernel.shape
k_shape = tf.shape(kernel)
strides = [1, self.scale, self.scale, 1]
output_shape=[input_shape[0], (input_shape[1]-1)*self.scale + self.k_size, (input_shape[2]-1)*self.scale + self.k_size, k_shape[3]]
n = input_shape[3]//k_shape[2]
kernel = tf.reshape(kernel, [k_shape[0], k_shape[1], k_shape[2], n, -1])
kernel = tf.transpose(kernel[::-1,::-1], [0, 1, 4, 3, 2])
kernel = tf.reshape(kernel, [k_shape[0], k_shape[1], -1, n*k_shape[2]])
return kernel, output_shape, strides
def _get_noise_shape(self, input_shape):
return [1, input_shape[1]*self.scale, input_shape[2]*self.scale, 1]
def call(self, inputs, training=None, **kwargs):
x, y = inputs[0], inputs[1]
style = self.style_projecter(y)
kernel, d = self._get_kernel(style=style, training=training)
if training:
x = tf.transpose(x, [1, 2, 0, 3])
x = tf.reshape(x, [1, tf.shape(x)[0], tf.shape(x)[1], -1])
else:
x = x * style[:, np.newaxis, np.newaxis, :]
kernel, output_shape, strides = self._get_convolution_params(x, kernel)
x = tf.nn.conv2d_transpose(x, kernel, output_shape, strides, padding='VALID')
x = self.fir(x)
if training:
x = tf.transpose(tf.reshape(x, [tf.shape(x)[1], tf.shape(x)[2], -1, self.filters]), [2, 0, 1, 3])
elif self.demod:
x = x * d[:, np.newaxis, np.newaxis, :]
x.set_shape(self.compute_output_shape(inputs[0].shape))
if self.noise_addition:
if self.random_noise:
noise = tf.random.normal([tf.shape(x)[0], x.shape[1], x.shape[2], 1])
else:
noise = self.noise
x = x + noise*self.noise_weight
if self.use_bias:
x = tf.nn.bias_add(x, self.bias * self.lrmul)
return x
def compute_output_shape(self, input_shape):
return tf.TensorShape([input_shape[0], input_shape[1]*self.scale, input_shape[2]*self.scale, self.filters])
def get_config(self):
basse_config = super(UpsampleConv, self).get_config()
config = dict(
k=self.k,
scale=self.scale,
)
return dict(list(basse_config.items()) + list(config.items()))
class DownsampleConv(WscaleConv2D):
def __init__(self, k_size, filters, scale=2, k=None, use_bias=True, gain=1.0, lrmul=1.0, **kwargs):
super(DownsampleConv, self).__init__(k_size, filters, use_bias=use_bias, gain=gain, lrmul=lrmul, **kwargs)
self.scale = scale
self.k = k
self.fir = UpFIR2d(scale=scale, k=k, gain=gain, down=True, conv=True, conv_k_size=self.k_size)
def call(self, inputs, **kwargs):
x = self.fir(inputs)
conv_kernel = self._get_kernel()
x = tf.nn.conv2d(x, conv_kernel, [1, self.scale, self.scale, 1], padding='VALID')
x.set_shape(tf.TensorShape([inputs.shape[0], inputs.shape[1]//self.scale, inputs.shape[2]//self.scale, self.filters]))
if self.use_bias:
x = tf.nn.bias_add(x, self.bias * self.lrmul)
return x
def get_config(self):
basse_config = super(DownsampleConv, self).get_config()
config = dict(
k=self.k,
scale=self.scale,
)
return dict(list(basse_config.items()) + list(config.items()))
####2.Firstlayer
styleGAN(2も)では、生成部分の入力は定数になっており学習対象のパラメータなので、それを定義しているLayerになります。
コード
class FirstLayer(ModulatedConv2D):
def __init__(self, shape, k_size, filters, use_bias=True, demod=True, gain=1.0, lrmul=1.0, **kwargs):
super(FirstLayer, self).__init__(k_size, filters, use_bias=use_bias, demod=demod, gain=gain, lrmul=lrmul, **kwargs)
self.shape = tuple(shape)
self.constant_inputs = self.add_weight(name='constant_input',
shape=self.shape,
dtype=tf.float32,
initializer=tf.keras.initializers.RandomNormal(),
trainable=True,
aggregation=tf.VariableAggregation.MEAN)
def build(self, input_shape):
super(FirstLayer, self).build([self.shape, input_shape])
def call(self, inputs, training=None, **kwargs):
s = tf.shape(inputs)
return super(FirstLayer, self).call([tf.tile(self.constant_inputs, [s[0], 1, 1, 1]), inputs], training=training, **kwargs)
def get_config(self):
base_config = super(FirstLayer, self).get_config()
config = dict(shape=self.shape)
return dict(list(base_config.items()) + list(config.items()))
####3.ScaledLeakyRelu
著者の実装では活性化関数を通した後に、定数倍($\sqrt{2}$倍)している箇所があります。論文中には
we scale our activation functions so that they retain the expected signal variance
という記述があるので出力の分散を維持するための操作でしょうか?一応これもLayerとして実装してあります。
コード
class ScaledLeakyReLU(tf.keras.layers.LeakyReLU):
def __init__(self, alpha=0.2, act_gain=np.sqrt(2), **kwargs):
super(ScaledLeakyReLU, self).__init__(alpha=alpha, **kwargs)
self.act_gain = act_gain
def call(self, inputs):
return super(ScaledLeakyReLU, self).call(inputs) * self.act_gain
def get_config(self):
base_config = super(ScaledLeakyReLU, self).get_config()
config = dict(act_gain=self.act_gain)
return dict(list(base_config.items()) + list(config.items()))
#3.学習経過
colabのTPUで約24時間(インスタンスが2回落ちるまで)学習を行いました。パラメータの更新回数は127000回です。まだまだモデルが収束していないので歪な画像も多いです。128x128の画像です。
##analogy
こちらは、style mappingネットワークに入力するノイズを変化させていった時の変化です。
##style mixing
1列目の画像のstyle入力の一部を1行目の画像のstyleに差し替えて生成させたものです。2、3行目は4x4、8x8の解像度に入力を、4、5行目は16x16、32x32への入力を、6、7行目は64x64、128x128への入力をそれぞれ差し替えました。