はじめに
Tensorflowのバージョンは2.xです。
GANの安定化においてSpectral Normalizationは大きなブレークスルーとなっています。そんなSpectral Normalizationですが、後述する通り実装するにはカーネルの重みに細工をしなければならないので、少し実装が面倒くさいです。そこで今回はKerasのWrapperを使いSpectral Normalizationを実装します。
Spectral Normalizationとは
Spectral NormalizationはGANの安定化の手法で、GANのdiscriminatorをリプシッツ連続にして安定性を向上させるために、カーネルのテンソルをその最大特異値で割って正規化します。ここでの最大特異値はべき乗法で近似します。詳しくはググってください。
実装
import tensorflow as tf
import tensorflow.python.keras.backend as K
from tensorflow.keras.layers import *
class SpectralNormalization(Wrapper):
def __init__(self, layer, **kwargs):
super(SpectralNormalization, self).__init__(layer, **kwargs)
def build(self, input_shape):
if not self.layer.built:
self.layer.build(input_shape)
self.w = self.layer.kernel
self.u = tf.Variable(
tf.random.normal((tuple([1, self.layer.kernel.shape.as_list()[-1]])), dtype=tf.float32),
aggregation=tf.VariableAggregation.MEAN, trainable=False)
super(SpectralNormalization, self).build()
def call(self, inputs, training=None):
def _l2normalize(v, eps=1e-12):
return v / (K.sum(v ** 2) ** 0.5 + eps)
def power_iteration(W, u):
_u = u
_v = _l2normalize(K.dot(_u, K.transpose(W)))
_u = _l2normalize(K.dot(_v, W))
return _u, _v
w_shape = self.w.shape.as_list()
w_reshaped = K.reshape(self.w, [-1, w_shape[-1]])
_u, _v = power_iteration(w_reshaped, self.u)
sigma = K.dot(_v, w_reshaped)
sigma = K.dot(sigma, K.transpose(_u))
w_bar = w_reshaped / sigma
if training == False:
w_bar = K.reshape(w_bar, w_shape)
else:
with tf.control_dependencies([self.u.assign(_u)]):
w_bar = K.reshape(w_bar, w_shape)
output = self.layer(inputs)
return output
def compute_output_shape(self, input_shape):
return tf.TensorShape(
self.layer.compute_output_shape(input_shape).as_list())
これを用いれば比較的簡単に以下のようにSpectral Normalizationを実装できます。
class ConvSN2D(Layer):
def __init__(self, filters, kernel_size, strides=(1, 1), padding='valid',
data_format=None, dilation_rate=(1, 1), activation=None,
use_bias=True, kernel_initializer='glorot_uniform',
bias_initializer='zeros', kernel_regularizer=None,
bias_regularizer=None, activity_regularizer=None,
kernel_constraint=None, bias_constraint=None, **kwargs):
super(ConvSN2D, self).__init__()
self.conv2d = Conv2D(
filters, kernel_size, strides, padding, data_format,
dilation_rate, activation, use_bias,
kernel_initializer, bias_initializer,
kernel_regularizer, bias_regularizer, activity_regularizer,
kernel_constraint, bias_constraint, **kwargs)
self.convsn2d = SpectralNormalization(self.conv2d)
def call(self, inputs):
return self.convsn2d(inputs)
ここでConv2Dを他のLayerに変えればそのLayerのSpectral Normalizationバージョンが作れます。
コードについての注意
べき乗法はpower_iteration
という関数で実装されています。
self.u
のaggregation=tf.VariableAggregation.MEAN
は分散学習で各デバイスに分散されたself.u
が戻ってくるときにどのような平均を取るという指示です(多分)。
if training == False:
w_bar = K.reshape(w_bar, w_shape)
else:
with tf.control_dependencies([self.u.assign(_u)]):
w_bar = K.reshape(w_bar, w_shape)
の部分は、tf.control_dependencies([self.u.assign(_u)])
等を外すと、tf.Graph関連のエラーが出てしまいます。詳細にはわかりませんが恐らくこれがないとself.u
にGraphが通らないっぽいです。
#追記 2020/04/14
kerasのレイヤーにはget_config
という関数がないとmodel.save
でモデルを保存できない仕様があるので、上記のまま実装するとエラーがモデルを保存できません。
これを解決するにはget_config
関数を追加すれば良いので、その方法を記します。
例えばConvSN2D
にget_config
関数を追加したいならば以下のように関数を追加すれば良いです。
def get_config(self):
conf = self.conv2d.get_config()
# Spectral Normalizationに特別な追加パラメタはないので、
# configは名前だけ変更する
conf['name'] = conf['name'].replace('conv', 'convsn')
return conf
参考サイト
https://medium.com/@FloydHsiu0618/spectral-normalization-implementation-of-tensorflow-2-0-keras-api-d9060d26de77
https://github.com/IShengFang/SpectralNormalizationKeras/blob/master/SpectralNormalizationKeras.py