Python
DeepLearning
TensorFlow

[DL]weight decayって何?

deep neural networkではlayerが多層になるほど、そのモデルの表現能力が増します。
しかし、多層になるほどoverfittingのリスクも高くなります。
Modelの表現能力を維持したまま、parameterの自由度に制限を与えることでoverfittingのリスクを減らすことが行われます。
その手法の一つがweight decay(重み減衰)です。

weightの更新式は下記のように書かれます。

w \leftarrow w -\eta \frac{\partial C(w)}{\partial w} - \eta \lambda w

上記式は何がしたいのか少しわかりづらいですが、実際はcost functionを下記のようにしたものから来ています。

\tilde C(w) = C(w) + \frac{\lambda}{2}||w||^2

これはつまり、cost functionL2 regularization項をつけたものです。
この項によりweightの値は小さくなります。
なので、実際に実装するときはL2 regularizationの項をcostに加えることになります。

通常はbiasにはL2 regularizationを適応しません。
これはneuronのweightとbiasの役割の違いから来ています。
weightは入力を選択する役割なので、優先順位さえ変わらなければ値が小さくなっても構いませんが、
biasは閾値の役割のため値が大きくなければならないこともあるためです。

実際にweight decayありweight decayなしで学習させてweightのヒストグラムを見てみると下図のようになります。
左がweight decayなし、右がweight decayありです。
weightが小さくなっているのがわかると思います。

weightdecay1.png

accuracyは下記のようになりました。
青がweight decayなし、赤がweight decayあり、点線がtrainingデータ、実線がvalidationデータの結果です。
weightdecay2.png

image

import tensorflow as tf
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("./data/mnist/", one_hot=True)

image_size = 28
n_labels = 10
n_batch  = 128
n_train  = 10000
beta = 0.001

def accuracy(y, y_):
    return 100.0 * np.sum(np.argmax(y, 1) == np.argmax(y_, 1)) / y.shape[0]

with tf.variable_scope("slp"):
    x  = tf.placeholder(tf.float32, shape=(n_batch, image_size*image_size))
    y_ = tf.placeholder(tf.float32, shape=(n_batch, n_labels))
    w0 = tf.get_variable("w0", [image_size * image_size, n_labels], initializer=tf.truncated_normal_initializer(seed=0))
    b0 = tf.get_variable("b0", [n_labels], initializer=tf.constant_initializer(0.0))

    w1 = tf.get_variable("w1", [image_size * image_size, n_labels], initializer=tf.truncated_normal_initializer(seed=0))
    b1 = tf.get_variable("b1", [n_labels], initializer=tf.constant_initializer(0.0))

    y0 = tf.matmul( x, w0 ) + b0
    y1 = tf.matmul( x, w1 ) + b1

valid_data = mnist.validation.images
valid_labels = mnist.validation.labels
test_data = mnist.test.images
test_labels = mnist.test.labels
vx = tf.constant(valid_data)
vy_ = tf.constant(valid_labels)
tx = tf.constant(test_data)
ty_ = tf.constant(test_labels)

loss0 = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits( y0, y_))
optimizer0 = tf.train.GradientDescentOptimizer(0.5).minimize(loss0)

loss1 = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits( y1, y_) + beta * tf.nn.l2_loss(w1))
optimizer1 = tf.train.GradientDescentOptimizer(0.5).minimize(loss1)

train_prediction0 = tf.nn.softmax(y0)
valid_prediction0 = tf.nn.softmax(tf.matmul(vx, w0) + b0)
test_prediction0  = tf.nn.softmax(tf.matmul(tx, w0) + b0)

train_prediction1 = tf.nn.softmax(y1)
valid_prediction1 = tf.nn.softmax(tf.matmul(vx, w1) + b1)
test_prediction1  = tf.nn.softmax(tf.matmul(tx, w1) + b1)

sess = tf.InteractiveSession()
# sess = tf.Session()

init = tf.initialize_all_variables()
sess.run(init)
result_accuracy = []

for step in xrange(n_train):
    bx, by = mnist.train.next_batch(n_batch)    
    _, L0, tp0 = sess.run([optimizer0, loss0, train_prediction0], feed_dict={x: bx, y_: by})
    _, L1, tp1 = sess.run([optimizer1, loss1, train_prediction1], feed_dict={x: bx, y_: by})
    if step % 500 == 0:
        ac_wo_decay_train = accuracy(tp0, by)
        ac_wo_decay_valid = accuracy(valid_prediction0.eval(), valid_labels)
        ac_wt_decay_train = accuracy(tp1, by)
        ac_wt_decay_valid = accuracy(valid_prediction1.eval(), valid_labels)
        ac = {'step' : step, 'wo_decay' : {'training' : ac_wo_decay_train, 'validation' : ac_wo_decay_valid}, 'wt_decay' : {'training' : ac_wt_decay_train, 'validation' : ac_wt_decay_valid}}
        result_accuracy.append(ac)
        print "step = %d, train accuracy0: %.2f, validation accuracy0: %.2f, train accuracy1: %.2f, validation accuracy1: %.2f" % (step, ac_wo_decay_train, ac_wo_decay_valid, ac_wt_decay_train, ac_wt_decay_valid)

print "test accuracy0: %.2f" % accuracy(test_prediction0.eval(), test_labels)
print "test accuracy1: %.2f" % accuracy(test_prediction1.eval(), test_labels)

fig,axes = plt.subplots(ncols=2, figsize=(8,4))
axes[0].hist(w0.eval().flatten(), bins=sp.linspace(-3,3,50))
axes[0].set_title('without weight decay')
axes[0].set_xlabel('weight')
axes[1].hist(w1.eval().flatten(), bins=sp.linspace(-3,3,50))
axes[1].set_title('with weight decay')
axes[1].set_xlabel('weight')
fig.show()

tr_step = [ac['step'] for ac in result_accuracy]
ac_training_wo_decay = [ac['wo_decay']['training'] for ac in result_accuracy]
ac_training_wt_decay = [ac['wt_decay']['training'] for ac in result_accuracy]
ac_validation_wo_decay = [ac['wo_decay']['validation'] for ac in result_accuracy]
ac_validation_wt_decay = [ac['wt_decay']['validation'] for ac in result_accuracy]

fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(1,1,1)

ax.plot(tr_step, ac_training_wo_decay, color='blue', linestyle='dashed')
ax.plot(tr_step, ac_training_wt_decay, color='red', linestyle='dashed')
ax.plot(tr_step, ac_validation_wo_decay, color='blue', linestyle='solid')
ax.plot(tr_step, ac_validation_wt_decay, color='red', linestyle='solid')
ax.set_title('accuracy')
ax.set_xlabel('step')
ax.set_ylabel('accuracy')
ax.grid(True)
ax.set_xlim((0, 10000))
ax.set_ylim((0, 100))
fig.show()