LoginSignup
13
8

More than 3 years have passed since last update.

styleGAN2 keras実装+顔イラストで学習

Posted at

1.はじめに

 styleGAN2をkerasで実装しました。高解像・高品質な画像が生成できることで有名なモデルです。すべてkerasのLayerとして実装しました。そのほうがモデルの構築や再利用時に楽ですので。

論文リンク
Analyzing and Improving the Image Quality of StyleGAN

使用ライブラリ
・Tensorflow==2.2.0
・numpy

2.詳細

コードはここに掲載しますが、近いうちにgithubにも上げようと思います。長いので一部折りたたんであります。

1. Wscale(EQUALIZED LEARNING RATE)

 これはProgressive Growing of GANs for Improved Quality, Stability, and Variationで提案されている手法で、上記論文の4.1で述べられています。
 重みの初期化方法についてで、初期値は$N(0, 1)$の標準正規分布で初期化し、計算時に定数倍しています。この定数は以下のように計算されます(He's initializer)。

c = \frac{gain}{\sqrt{fan\_in}}

$fan\_in$は入力ユニット数で$gain$は1.0を使っています。ただ、Heの初期化の標準偏差は$c=\sqrt{\frac{2}{fan\_in}}$なので$gain=\sqrt{2}$にしないと一致しませんが、著者実装を見ると1.0を使っているのでそうしてあります($\sqrt{2}$はどこへ?)。この操作によって、異なるダイナミックレンジを持つパラメータもすべて同じスピードで学習することが可能になります。これはstyleGAN2でもすべてのweightの初期化に使用されています。

コード
class BaseWscale(tf.keras.layers.Layer):
    def __init__(self, gain=1.0, lrmul=1.0, trainable=True, name=None, dtype=None, dynamic=False, **kwargs):
        super(BaseWscale, self).__init__(trainable=trainable, name=name, dtype=dtype, dynamic=dynamic, **kwargs)

        self.gain = gain
        self.lrmul = lrmul

    def _get_wscale_weight(self, kernel_shape):
        fan_in = np.prod(kernel_shape[:-1])
        he_std = self.gain / np.sqrt(fan_in)
        init_std = 1.0 / self.lrmul
        multiplier = he_std * self.lrmul
        return tf.keras.initializers.RandomNormal(stddev=init_std), multiplier

    @abstractmethod
    def _get_kernel(self):
        pass

    def get_config(self):
        base_config = super(BaseWscale, self).get_config()
        config = dict(
            gain=self.gain,
            lrmul=self.lrmul,
        )
        return dict(list(base_config.items()) + list(config.items()))

class WscaleConv2D(BaseWscale):
    def __init__(self, k_size, filters, use_bias=True, gain=1.0, lrmul=1.0, **kwargs):
        super(WscaleConv2D, self).__init__(gain=gain, lrmul=lrmul, **kwargs)

        self.k_size = k_size
        self.filters = filters
        self.use_bias = use_bias

    def build(self, input_shape):
        k_shape = [self.k_size, self.k_size, input_shape[-1], self.filters]
        initializer, multiplier = self._get_wscale_weight(k_shape)
        self.multiplier = self.add_weight(name='multiplier', 
                                          shape=[],
                                          dtype=tf.float32,
                                          initializer=tf.keras.initializers.Constant(multiplier),
                                          trainable=False,
                                          aggregation=tf.VariableAggregation.MEAN)
        self.kernel = self.add_weight(name='kernel',
                                      shape=k_shape,
                                      dtype=tf.float32,
                                      initializer=initializer,
                                      trainable=True,
                                      aggregation=tf.VariableAggregation.MEAN)
        if self.use_bias:
            self.bias = self.add_weight(name='bias',
                                        shape=[self.filters,],
                                        dtype=tf.float32,
                                        initializer=tf.keras.initializers.Zeros(),
                                        trainable=True,
                                        aggregation=tf.VariableAggregation.MEAN)
        super(WscaleConv2D, self).build(input_shape)

    def _get_kernel(self):
        return self.kernel * self.multiplier

    def call(self, inputs, **kwargs):
        conv_kernel = self._get_kernel()
        x = tf.nn.conv2d(inputs, conv_kernel, [1, 1, 1, 1], padding='SAME')
        if self.use_bias:
            x = tf.nn.bias_add(x, self.bias * self.lrmul)
        return x

    def get_config(self):
        base_config = super(WscaleConv2D, self).get_config()
        config = dict(
            k_size=self.k_size,
            filters=self.filters,
            use_bias=self.use_bias,
        )
        return dict(list(base_config.items()) + list(config.items()))

class WscaleDense(BaseWscale):
    def __init__(self, units, use_bias=True, gain=1.0, lrmul=1.0, **kwargs):
        super(WscaleDense, self).__init__(gain=gain, lrmul=lrmul, **kwargs)

        self.units = units
        self.use_bias = use_bias


    def build(self, input_shape):
        k_shape = [input_shape[-1], self.units]
        initializer, multiplier = self._get_wscale_weight(k_shape)
        self.multiplier = self.add_weight(name='multiplier', 
                                          shape=[],
                                          dtype=tf.float32,
                                          initializer=tf.keras.initializers.Constant(multiplier),
                                          trainable=False,
                                          aggregation=tf.VariableAggregation.MEAN)
        self.kernel = self.add_weight(name='kernel',
                                      shape=k_shape,
                                      dtype=tf.float32,
                                      initializer=initializer,
                                      trainable=True,
                                      aggregation=tf.VariableAggregation.MEAN)
        if self.use_bias:
            self.bias = self.add_weight(name='bias',
                                        shape=[self.units,],
                                        dtype=tf.float32,
                                        initializer=tf.keras.initializers.Zeros(),
                                        trainable=True,
                                        aggregation=tf.VariableAggregation.MEAN)
        super(WscaleDense, self).build(input_shape)

    def _get_kernel(self):
        return self.kernel * self.multiplier

    def call(self, inputs, **kwargs):
        kernel = self._get_kernel()
        x = tf.matmul(inputs, kernel)
        if self.use_bias:
            x = tf.nn.bias_add(x, self.bias * self.lrmul)
        return x

    def get_config(self):
        base_config = super(WscaleDense, self).get_config()
        config = dict(
            units=self.units,
            use_bias=self.use_bias,
        )
        return dict(list(base_config.items()) + list(config.items()))

 

2.ModulatedConv2D

 ここがstyleGANとstyleGAN2の決定的に異なる箇所で、styleGAN2のメインの処理となっています。styleGANではstyleの適用をfeaturemapに対して行っていましたが、styleGAN2ではweightに対してstyleを適用します。詳細は論文の2章2節を参照。
 以下のように実装しました。

class ModulatedConv2D(WscaleConv2D):
    def __init__(self, k_size, filters, use_bias=True, demod=True, noise_addition=True, random_noise=True, gain=1.0, lrmul=1.0, **kwargs):
        super(ModulatedConv2D, self).__init__(k_size, filters, use_bias=use_bias, gain=gain, lrmul=lrmul, **kwargs)

        self.demod = demod
        self.noise_addition = noise_addition
        self.random_noise = random_noise

    def build(self, input_shape):
        self.style_projecter = WscaleDense(input_shape[0][-1])
        if self.noise_addition:
            self.noise_weight = self.add_weight(name='noise_weight',
                                                shape=[],
                                                dtype=tf.float32,
                                                initializer=tf.keras.initializers.Zeros(),
                                                trainable=True,
                                                aggregation=tf.VariableAggregation.MEAN)
            self.noise = self.add_weight(name='noise',
                                        shape=self._get_noise_shape(input_shape[0]),
                                        dtype=tf.float32,
                                        initializer=tf.keras.initializers.RandomNormal(stddev=1.0),
                                        trainable=False,
                                        aggregation=tf.VariableAggregation.MEAN)
        super(ModulatedConv2D, self).build(input_shape[0])

    def _get_noise_shape(self, input_shape):
        return [1, input_shape[1], input_shape[2], 1]

    def _get_kernel(self, style, training=None):
        kernel = super(ModulatedConv2D, self)._get_kernel()
        conv_kernel = kernel[np.newaxis] * style[:, np.newaxis, np.newaxis, :, np.newaxis]
        d =  tf.math.rsqrt(tf.reduce_sum(tf.square(conv_kernel), axis=[1, 2, 3]) + 1e-8)
        if training:
            if self.demod:
                conv_kernel *= d[:, np.newaxis, np.newaxis, np.newaxis, :]
            s = self.kernel.shape
            kernel = tf.reshape(tf.transpose(conv_kernel, [1, 2, 3, 0, 4]), [s[0], s[1], s[2], -1])
        return kernel, d

    def call(self, inputs, training=None, **kwargs):
        x, y = inputs[0], inputs[1]
        style = self.style_projecter(y)
        kernel, d = self._get_kernel(style=style, training=training)
        if training:
            x = tf.transpose(x, [1, 2, 0, 3])
            x = tf.reshape(x, [1, x.shape[0], x.shape[1], -1])
            x = tf.nn.conv2d(x, kernel, strides=1, padding='SAME')
            x = tf.transpose(tf.reshape(x, [x.shape[1], x.shape[2], -1, self.filters]), [2, 0, 1, 3])
        else:
            x = x * style[:, np.newaxis, np.newaxis, :]
            x = tf.nn.conv2d(x, kernel, strides=1, padding='SAME')
            if self.demod:
                x = x * d[:, np.newaxis, np.newaxis, :]
        if self.noise_addition:
            if self.random_noise:
                noise = tf.random.normal([tf.shape(x)[0], x.shape[1], x.shape[2], 1])
            else:
                noise = self.noise
            x = x + noise*self.noise_weight
        if self.use_bias:
            x = tf.nn.bias_add(x, self.bias * self.lrmul)
        return x

    def get_config(self):
        base_config = super(ModulatedConv2D, self).get_config()
        config = dict(
            demod=self.demod,
            noise_addition=self.noise_addition,
            random_noise=self.random_noise,
        )
        return dict(list(base_config.items()) + list(config.items()))

 styleGAN2では、weightに対してstyleを適用するためバッチごとに異なる重みを用いて畳み込みを行わなければなりません。こんなんどうやって実装するんだ?と思って著者実装を見ると、Group convolutionを使っていました(よく読んだら論文の4ページに書いてあった)。どうやら
 入力のチャンネル数 = カーネルの入力チャンネル数 * n
の時にn個のグループに分割して処理してくれる機能のようです(pytorchではnn.Conv2dのgroup引数で指定できるっぽい?)。ただこれはcuDNNの機能を使っているらしく(参考:Add support for cudnn's group convolution.)、GPU環境でないとエラーで動きません。学習に関してはGPU(今回はcolabのTPUを用いて行った)を使用すればよいと思いますが、推論はCPU環境でも行えるようにしたかったので、2種類の処理を実装しました。以下の図のcとdの処理です。学習時はdを、推論時はcを使うようにしてあります(callのtraining引数でフラグ管理)。両方の処理で同じ出力を得られることは確認してあります。

stylegan_fig2.PNG
Analyzing and Improving the Image Quality of StyleGANより引用

3.その他

1.Upsample, Downsample

 非公式の実装を見ているとシンプルにbilinearで画像の拡大処理を行っているものもあり、最初は自分もそのように実装していましたが、tf.image.resizeはTPUでのバックプロパゲーションに対応していないらしく(tensorflow 2.2.0時点)、

Lookup Error : Gradient Registry has no entry for : ~

のエラーが発生しました。なので、自分で実装しました。Upsampleは転置畳み込みを行ったのち、ブラーをかけています。Downsampleはブラーをかけたのち、ストライド付きの畳み込みを行っています。この処理は著者実装と全く同じ処理になっているはずです。

コード
class UpFIR2d(tf.keras.layers.Layer):
    def __init__(self, scale=2, k=None, gain=1.0, up=False, down=False, conv=False, conv_k_size=None, trainable=True, name=None, dtype=None, dynamic=False, **kwargs):
        super(UpFIR2d, self).__init__(trainable=trainable, name=name, dtype=dtype, dynamic=dynamic, **kwargs)

        self.scale = scale
        self.gain = gain
        self.up = up
        self.down = down
        assert not(self.up == self.down)
        self.conv = conv
        self.conv_k_size = conv_k_size
        if k is None:
            self.k = (1,) * self.scale
        else:
            self.k = tuple(k)
        f = self._get_filter()
        self.filter = self.add_weight(name='resample_kernel',
                                      shape=f.shape,
                                      dtype=tf.float32,
                                      initializer=tf.keras.initializers.Constant(f),
                                      trainable=False,
                                      aggregation=tf.VariableAggregation.MEAN)

    def _get_filter(self):
        k = np.asarray(self.k, dtype=np.float32)
        k = np.outer(k, k)
        k /= np.sum(k)
        if self.up:
            k = k * (self.gain * (self.scale ** 2))
        elif self.down:
            k = k * self.gain
        return k[:,:,np.newaxis, np.newaxis]

    def _get_params(self):
        p = self.filter.shape[0] - self.scale
        if self.conv:
            if self.up:
                p -= (self.conv_k_size - 1)
            elif self.down:
                p += (self.conv_k_size - 1)
        if self.up:
            up = 1 if self.conv else self.scale
            down = 1
            p0 = (p+1)//2+self.scale-1
            p1 = p//2 + 1 if self.conv else p//2
        elif self.down:
            up = 1
            down = 1 if self.conv else self.scale
            p0 = (p+1)//2
            p1 = p//2
        return dict(up=up, down=down, p0=p0, p1=p1)

    def _upfirdn2d_op(self, x, upx, upy, downx, downy, px0, px1, py0, py1):
        xs = tf.shape(x)
        #x = tf.reshape(x, [-1, x.shape[1], 1, x.shape[2], 1, x.shape[3]])
        x = tf.reshape(x, [-1, xs[1], 1, xs[2], 1, xs[3]])
        x = tf.pad(x, [[0, 0], [0, 0], [0, upy-1], [0, 0], [0, upx-1], [0, 0]])
        x = tf.reshape(x, [-1, xs[1]*upy, xs[2]*upx, xs[3]])

        x = tf.pad(x, [[0, 0], [max(py0, 0), max(py1, 0)], [max(px0, 0), max(px1, 0)], [0, 0]])
        x = x[:, max(-py0, 0):tf.shape(x)[1] - max(-py1, 0), max(-px0, 0):tf.shape(x)[2] - max(-px1, 0), :]

        x = tf.nn.depthwise_conv2d(x, tf.tile(self.filter, [1, 1, tf.shape(x)[-1], 1]), strides=[1, 1, 1, 1], padding='VALID')
        return x[:, ::downy, ::downx, :]

    def _upfirdn2d(self, x, up=1, down=1, p0=0, p1=0):
        x = self._upfirdn2d_op(x, upx=up, upy=up, downx=down, downy=down, px0=p0, px1=p1, py0=p0, py1=p1)
        return x

    def call(self, inputs, **kwargs):
        params = self._get_params()
        x = self._upfirdn2d(inputs, **params)
        return x

    def get_config(self):
        basse_config = super(UpFIR2d, self).get_config()
        config = dict(
            scale=self.scale,
            gain=self.gain,
            up=self.up,
            down=self.down,
            conv=self.conv,
            conv_k_size=self.conv_k_size,
            k=self.k,
        )
        return dict(list(basse_config.items()) + list(config.items()))

class UpsampleConv(ModulatedConv2D):
    def __init__(self, k_size, filters, scale=2, k=None, use_bias=True, demod=True, gain=1.0, lrmul=1.0, **kwargs):
        super(UpsampleConv, self).__init__(k_size, filters, use_bias=use_bias, demod=demod, gain=gain, lrmul=lrmul, **kwargs)

        self.scale = scale
        self.k = k
        self.fir = UpFIR2d(scale, k, gain, up=True, conv=True, conv_k_size=self.k_size)

    def _get_convolution_params(self, inputs, kernel):
        input_shape = tf.shape(inputs)
        k_shape = kernel.shape
        k_shape = tf.shape(kernel)
        strides = [1, self.scale, self.scale, 1]
        output_shape=[input_shape[0], (input_shape[1]-1)*self.scale + self.k_size, (input_shape[2]-1)*self.scale + self.k_size, k_shape[3]]
        n = input_shape[3]//k_shape[2]
        kernel = tf.reshape(kernel, [k_shape[0], k_shape[1], k_shape[2], n, -1])
        kernel = tf.transpose(kernel[::-1,::-1], [0, 1, 4, 3, 2])
        kernel = tf.reshape(kernel, [k_shape[0], k_shape[1], -1, n*k_shape[2]])
        return kernel, output_shape, strides

    def _get_noise_shape(self, input_shape):
        return [1, input_shape[1]*self.scale, input_shape[2]*self.scale, 1]

    def call(self, inputs, training=None, **kwargs):
        x, y = inputs[0], inputs[1]
        style = self.style_projecter(y)
        kernel, d = self._get_kernel(style=style, training=training)
        if training:
            x = tf.transpose(x, [1, 2, 0, 3])
            x = tf.reshape(x, [1, tf.shape(x)[0], tf.shape(x)[1], -1])
        else:
            x = x * style[:, np.newaxis, np.newaxis, :]
        kernel, output_shape, strides = self._get_convolution_params(x, kernel)
        x = tf.nn.conv2d_transpose(x, kernel, output_shape, strides, padding='VALID')
        x = self.fir(x)
        if training:
            x = tf.transpose(tf.reshape(x, [tf.shape(x)[1], tf.shape(x)[2], -1, self.filters]), [2, 0, 1, 3])
        elif self.demod:
            x = x * d[:, np.newaxis, np.newaxis, :]
        x.set_shape(self.compute_output_shape(inputs[0].shape))
        if self.noise_addition:
            if self.random_noise:
                noise = tf.random.normal([tf.shape(x)[0], x.shape[1], x.shape[2], 1])
            else:
                noise = self.noise
            x = x + noise*self.noise_weight
        if self.use_bias:
            x = tf.nn.bias_add(x, self.bias * self.lrmul)
        return x

    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], input_shape[1]*self.scale, input_shape[2]*self.scale, self.filters])

    def get_config(self):
        basse_config = super(UpsampleConv, self).get_config()
        config = dict(
            k=self.k,
            scale=self.scale,
        )
        return dict(list(basse_config.items()) + list(config.items()))

class DownsampleConv(WscaleConv2D):
    def __init__(self, k_size, filters, scale=2, k=None, use_bias=True, gain=1.0, lrmul=1.0, **kwargs):
        super(DownsampleConv, self).__init__(k_size, filters, use_bias=use_bias, gain=gain, lrmul=lrmul, **kwargs)

        self.scale = scale
        self.k = k
        self.fir = UpFIR2d(scale=scale, k=k, gain=gain, down=True, conv=True, conv_k_size=self.k_size)

    def call(self, inputs, **kwargs):
        x = self.fir(inputs)
        conv_kernel = self._get_kernel()
        x = tf.nn.conv2d(x, conv_kernel, [1, self.scale, self.scale, 1], padding='VALID')
        x.set_shape(tf.TensorShape([inputs.shape[0], inputs.shape[1]//self.scale, inputs.shape[2]//self.scale, self.filters]))
        if self.use_bias:
            x = tf.nn.bias_add(x, self.bias * self.lrmul)
        return x

    def get_config(self):
        basse_config = super(DownsampleConv, self).get_config()
        config = dict(
            k=self.k,
            scale=self.scale,
        )
        return dict(list(basse_config.items()) + list(config.items()))

2.Firstlayer

 styleGAN(2も)では、生成部分の入力は定数になっており学習対象のパラメータなので、それを定義しているLayerになります。

コード
class FirstLayer(ModulatedConv2D):
    def __init__(self, shape, k_size, filters, use_bias=True, demod=True, gain=1.0, lrmul=1.0, **kwargs):
        super(FirstLayer, self).__init__(k_size, filters, use_bias=use_bias, demod=demod, gain=gain, lrmul=lrmul, **kwargs)

        self.shape = tuple(shape)
        self.constant_inputs = self.add_weight(name='constant_input',
                                               shape=self.shape,
                                               dtype=tf.float32,
                                               initializer=tf.keras.initializers.RandomNormal(),
                                               trainable=True,
                                               aggregation=tf.VariableAggregation.MEAN)

    def build(self, input_shape):
        super(FirstLayer, self).build([self.shape, input_shape])

    def call(self, inputs, training=None, **kwargs):
        s = tf.shape(inputs)
        return super(FirstLayer, self).call([tf.tile(self.constant_inputs, [s[0], 1, 1, 1]), inputs], training=training, **kwargs)

    def get_config(self):
        base_config = super(FirstLayer, self).get_config()
        config = dict(shape=self.shape)
        return dict(list(base_config.items()) + list(config.items()))

3.ScaledLeakyRelu

 著者の実装では活性化関数を通した後に、定数倍($\sqrt{2}$倍)している箇所があります。論文中には

we scale our activation functions so that they retain the expected signal variance

という記述があるので出力の分散を維持するための操作でしょうか?一応これもLayerとして実装してあります。

コード
class ScaledLeakyReLU(tf.keras.layers.LeakyReLU):
    def __init__(self, alpha=0.2, act_gain=np.sqrt(2), **kwargs):
        super(ScaledLeakyReLU, self).__init__(alpha=alpha, **kwargs)

        self.act_gain = act_gain

    def call(self, inputs):
        return super(ScaledLeakyReLU, self).call(inputs) * self.act_gain

    def get_config(self):
        base_config = super(ScaledLeakyReLU, self).get_config()
        config = dict(act_gain=self.act_gain)
        return dict(list(base_config.items()) + list(config.items()))

3.学習経過

 colabのTPUで約24時間(インスタンスが2回落ちるまで)学習を行いました。パラメータの更新回数は127000回です。まだまだモデルが収束していないので歪な画像も多いです。128x128の画像です。

0125000.png

analogy

こちらは、style mappingネットワークに入力するノイズを変化させていった時の変化です。
analogy2d (1).png

style mixing

1列目の画像のstyle入力の一部を1行目の画像のstyleに差し替えて生成させたものです。2、3行目は4x4、8x8の解像度に入力を、4、5行目は16x16、32x32への入力を、6、7行目は64x64、128x128への入力をそれぞれ差し替えました。
mix_grid.png

13
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
8