LoginSignup
5
3

More than 3 years have passed since last update.

Kerasのwrapperを使ってSpectral Normalizationを実装する

Last updated at Posted at 2020-03-25

はじめに

Tensorflowのバージョンは2.xです。
GANの安定化においてSpectral Normalizationは大きなブレークスルーとなっています。そんなSpectral Normalizationですが、後述する通り実装するにはカーネルの重みに細工をしなければならないので、少し実装が面倒くさいです。そこで今回はKerasのWrapperを使いSpectral Normalizationを実装します。

Spectral Normalizationとは

Spectral NormalizationはGANの安定化の手法で、GANのdiscriminatorをリプシッツ連続にして安定性を向上させるために、カーネルのテンソルをその最大特異値で割って正規化します。ここでの最大特異値はべき乗法で近似します。詳しくはググってください。

実装

import tensorflow as tf
import tensorflow.python.keras.backend as K
from tensorflow.keras.layers import *

class SpectralNormalization(Wrapper):
    def __init__(self, layer, **kwargs):
        super(SpectralNormalization, self).__init__(layer, **kwargs)

    def build(self, input_shape):
        if not self.layer.built:
            self.layer.build(input_shape)
            self.w = self.layer.kernel
            self.u = tf.Variable(
            tf.random.normal((tuple([1, self.layer.kernel.shape.as_list()[-1]])), dtype=tf.float32), 
            aggregation=tf.VariableAggregation.MEAN, trainable=False)
        super(SpectralNormalization, self).build()

    def call(self, inputs, training=None):
        def _l2normalize(v, eps=1e-12):
            return v / (K.sum(v ** 2) ** 0.5 + eps)
        def power_iteration(W, u):
            _u = u
            _v = _l2normalize(K.dot(_u, K.transpose(W)))
            _u = _l2normalize(K.dot(_v, W))
            return _u, _v
        w_shape = self.w.shape.as_list()
        w_reshaped = K.reshape(self.w, [-1, w_shape[-1]])
        _u, _v = power_iteration(w_reshaped, self.u)
        sigma = K.dot(_v, w_reshaped)
        sigma = K.dot(sigma, K.transpose(_u))
        w_bar = w_reshaped / sigma
        if training == False:
            w_bar = K.reshape(w_bar, w_shape)
        else:
            with tf.control_dependencies([self.u.assign(_u)]):
                 w_bar = K.reshape(w_bar, w_shape) 
        output = self.layer(inputs)
        return output

    def compute_output_shape(self, input_shape):
        return tf.TensorShape(
            self.layer.compute_output_shape(input_shape).as_list())

これを用いれば比較的簡単に以下のようにSpectral Normalizationを実装できます。

class ConvSN2D(Layer):
    def __init__(self, filters, kernel_size, strides=(1, 1), padding='valid', 
                 data_format=None, dilation_rate=(1, 1), activation=None, 
                 use_bias=True, kernel_initializer='glorot_uniform', 
                 bias_initializer='zeros', kernel_regularizer=None, 
                 bias_regularizer=None, activity_regularizer=None, 
                 kernel_constraint=None, bias_constraint=None, **kwargs):
        super(ConvSN2D, self).__init__()
        self.conv2d = Conv2D(
            filters, kernel_size, strides, padding, data_format, 
            dilation_rate, activation, use_bias, 
            kernel_initializer, bias_initializer, 
            kernel_regularizer, bias_regularizer, activity_regularizer, 
            kernel_constraint, bias_constraint, **kwargs)
        self.convsn2d = SpectralNormalization(self.conv2d)

    def call(self, inputs):
        return self.convsn2d(inputs)

ここでConv2Dを他のLayerに変えればそのLayerのSpectral Normalizationバージョンが作れます。

コードについての注意

べき乗法はpower_iterationという関数で実装されています。
self.uaggregation=tf.VariableAggregation.MEANは分散学習で各デバイスに分散されたself.uが戻ってくるときにどのような平均を取るという指示です(多分)。

if training == False:
        w_bar = K.reshape(w_bar, w_shape)
    else:
        with tf.control_dependencies([self.u.assign(_u)]):
             w_bar = K.reshape(w_bar, w_shape) 

の部分は、tf.control_dependencies([self.u.assign(_u)])等を外すと、tf.Graph関連のエラーが出てしまいます。詳細にはわかりませんが恐らくこれがないとself.uにGraphが通らないっぽいです。

追記 2020/04/14

kerasのレイヤーにはget_configという関数がないとmodel.saveでモデルを保存できない仕様があるので、上記のまま実装するとエラーがモデルを保存できません。
これを解決するにはget_config関数を追加すれば良いので、その方法を記します。
例えばConvSN2Dget_config関数を追加したいならば以下のように関数を追加すれば良いです。

def get_config(self):
    conf = self.conv2d.get_config()
    # Spectral Normalizationに特別な追加パラメタはないので、
    # configは名前だけ変更する
    conf['name'] = conf['name'].replace('conv', 'convsn')
    return conf

参考サイト

https://medium.com/@FloydHsiu0618/spectral-normalization-implementation-of-tensorflow-2-0-keras-api-d9060d26de77
https://github.com/IShengFang/SpectralNormalizationKeras/blob/master/SpectralNormalizationKeras.py

5
3
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
3