Notes on Implementing a Gated CNN

Posted at 2018-05-31

Key Points

  • Implemented a Gated CNN and evaluated its performance on the MNIST handwritten digit data.
  • Compared 4 patterns (GLU, GTU, ReLU, Tanh); the definitions are given after this list.
  • Additional validation on Sequential MNIST and Permuted MNIST is planned.
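
For reference, the four output functions compared (following reference 1; X is the layer input, W and V are convolution kernels, b and c are biases, σ is the sigmoid function, and ⊗ is element-wise multiplication):

  • GLU: h = (X∗W + b) ⊗ σ(X∗V + c)
  • GTU: h = tanh(X∗W + b) ⊗ σ(X∗V + c)
  • ReLU: h = ReLU(X∗W + b)
  • Tanh: h = tanh(X∗W + b)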

References

1. Language Modeling with Gated Convolutional Networks
2. An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling

Validation Method

  • Used the MNIST handwritten digit data (28x28 images).
  • Compared the 4 patterns without applying Residual Blocks, Regularization, or Gradient Clipping.

Data

MNIST handwritten digits

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('***/mnist', one_hot = True)
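
Note: the tensorflow.examples.tutorials module was removed in later TensorFlow releases. A roughly equivalent load via tf.keras.datasets (a sketch; the flattening, scaling, and one-hot conversion are my additions to match input_data's output format):

import numpy as np
import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# flatten to [N, 784] and scale pixel values to [0, 1]
x_train = x_train.reshape(-1, 784).astype(np.float32) / 255.0
x_test = x_test.reshape(-1, 784).astype(np.float32) / 255.0
# one-hot encode the labels
y_train = np.eye(10, dtype = np.float32)[y_train]
y_test = np.eye(10, dtype = np.float32)[y_test]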

Validation Results

Settings used for the numerical experiments:

  • n_in = 28*28
  • filter_size = 5
  • n_units = 32
  • n_out = 10
  • n_layers = 7
  • learning_rate = 0.01
  • batch_size = 64

GLU
(Figure: training/test loss and accuracy curves)

GTU
(Figure: training/test loss and accuracy curves)

ReLU
(Figure: training/test loss and accuracy curves)

Tanh
(Figure: training/test loss and accuracy curves)

Sample Code
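
The methods below belong to a model class and call helpers (weight_variable, bias_variable, conv) that are not shown in the article. A minimal sketch of that scaffolding, assuming TensorFlow 1.x (the class name and the initializers are my assumptions):

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

class GatedCNN:
  def weight_variable(self, name, shape):
    # truncated-normal initializer; the original choice is not shown
    return tf.get_variable(name, shape = shape, \
             initializer = tf.truncated_normal_initializer(stddev = 0.1))

  def bias_variable(self, name, shape):
    # accepts either an int (e.g. shape[-1]) or a list (e.g. [n_out])
    if not isinstance(shape, (list, tuple)):
      shape = [shape]
    return tf.get_variable(name, shape = shape, \
             initializer = tf.zeros_initializer())

  def conv(self, x, shape):
    # plain (non-gated) convolution used for the initial projection
    w = self.weight_variable('w', shape)
    b = self.bias_variable('b', shape[-1])
    return tf.add(tf.nn.conv2d(x, w, strides = [1, 1, 1, 1], \
                      padding = 'VALID'), b)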

  def gated_conv(self, x, shape):
    w = self.weight_variable('w', shape)
    b = self.bias_variable('b', shape[-1])
    v = self.weight_variable('v', shape)
    c = self.bias_variable('c', shape[-1])

    f = tf.add(tf.nn.conv2d(x, w, strides = [1, 1, 1, 1], \
                                padding = 'VALID'), b)
    g = tf.add(tf.nn.conv2d(x, v, strides = [1, 1, 1, 1], \
                                padding = 'VALID'), c)

    # the four variants compared; return the one under test
    output1 = tf.multiply(f, tf.sigmoid(g))           # GLU
    output2 = tf.multiply(tf.tanh(f), tf.sigmoid(g))  # GTU
    output3 = tf.nn.relu(f)                           # ReLU
    output4 = tf.tanh(f)                              # Tanh

    return output4  # switch to output1/output2/output3 for the other variants

  # without residual block
  def inference(self, x, filter_size, n_in, n_units, \
                     n_out, n_layers):
    # treat each 28x28 image as a width-28 sequence with 28 channels
    width = np.sqrt(n_in).astype(np.int32)
    channel = np.sqrt(n_in).astype(np.int32)
    x = tf.reshape(x, [-1, 1, width, channel])

    shape = [1, 1, channel, n_units]
    with tf.variable_scope('initial'):
      y = self.conv(x, shape)

    shape = [1, filter_size, n_units, n_units]
    for i in range(n_layers):
      with tf.variable_scope('layer_{}'.format(i + 1)):
        # left-pad the width axis by filter_size - 1 to keep the convolution causal
        y = tf.pad(y, [[0, 0], [0, 0], \
                [filter_size - 1, 0], [0, 0]])
        y = self.gated_conv(y, shape)

    # take the output at the last sequence position
    y = y[:, :, -1, :]
    y = tf.squeeze(y, axis = 1)

    with tf.variable_scope('final'):
      w = self.weight_variable('w', [n_units, n_out])
      b = self.bias_variable('b', [n_out])

      y = tf.add(tf.matmul(y, w), b)
      y = tf.nn.softmax(y, axis = 1)

    return y
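
  # Shape trace with the settings above (a sketch, batch size 64):
  #   input x:                 [64, 784]
  #   after reshape:           [64, 1, 28, 28]  (height 1, width 28, 28 channels)
  #   after initial 1x1 conv:  [64, 1, 28, 32]
  #   each gated layer:        pad width to 32, 'VALID' conv -> [64, 1, 28, 32]
  #   last position + squeeze: [64, 32]
  #   final dense layer:       [64, 10]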

  def loss(self, y, t):
    cross_entropy = - tf.reduce_mean(tf.reduce_sum( \
      t * tf.log(tf.clip_by_value(y, 1e-10, 1.0)), axis = 1))
    return cross_entropy
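
  # Note: inference already applies softmax, so the cross entropy is computed
  # manually with clipping for numerical safety. A stabler alternative (a
  # sketch, not the author's code) would return logits from inference and use
  # tf.nn.softmax_cross_entropy_with_logits_v2(labels = t, logits = y).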

  # without gradient clipping
  def training(self, loss, learning_rate):
    optimizer = tf.train.AdamOptimizer(learning_rate = \
                       learning_rate)
    train_step = optimizer.minimize(loss)
    return train_step
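
  # For reference, gradient clipping (deliberately not applied here) could be
  # added as follows (a sketch, with an illustrative clip norm of 1.0):
  #   grads_and_vars = optimizer.compute_gradients(loss)
  #   clipped = [(tf.clip_by_norm(g, 1.0), v) for g, v in grads_and_vars]
  #   train_step = optimizer.apply_gradients(clipped)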

  def accuracy(self, y, t):
    correct_preds = tf.equal(tf.argmax(y, axis = 1), \
                          tf.argmax(t, axis = 1))
    accuracy = tf.reduce_mean(tf.cast(correct_preds, \
                          tf.float32))
    return accuracy

  def fit(self, images_train, labels_train, images_test, \
             labels_test, filter_size, n_in, n_units, \
             n_out, n_layers, learning_rate, n_iter, \
             batch_size, show_step, is_saving, model_path):

    tf.reset_default_graph()

    x = tf.placeholder(shape = [None, n_in], \
                                dtype = tf.float32)
    t = tf.placeholder(shape = [None, n_out], \
                                dtype = tf.float32)

    # without residual block
    y = self.inference(x, filter_size, n_in, n_units, \
                         n_out, n_layers)
    loss = self.loss(y, t)
    train_step = self.training(loss, learning_rate)
    acc = self.accuracy(y, t)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:

      sess.run(init)

      history_loss_train = []
      history_acc_train = []
      history_loss_test = []
      history_acc_test = []

      for i in range(n_iter):
        # Train
        rand_index = np.random.choice(len(images_train), \
                            size = batch_size)
        x_batch = images_train[rand_index]
        y_batch = labels_train[rand_index]

        feed_dict = {x: x_batch, t: y_batch}

        sess.run(train_step, feed_dict = feed_dict)

        temp_loss = sess.run(loss, feed_dict = feed_dict)
        temp_acc = sess.run(acc, feed_dict = feed_dict)

        history_loss_train.append(temp_loss)
        history_acc_train.append(temp_acc)

        if (i + 1) % show_step == 0:
          print ('--------------------')
          print ('Iteration: ' + str(i + 1) + '  Loss: ' \
               + str(temp_loss) + '  Accuracy: ' \
               + str(temp_acc))


        # Test
        rand_index = np.random.choice(len(images_test), \
                         size = batch_size)
        x_batch = images_test[rand_index]
        y_batch = labels_test[rand_index]

        feed_dict = {x: x_batch, t: y_batch}

        temp_loss = sess.run(loss, feed_dict = feed_dict)
        temp_acc = sess.run(acc, feed_dict = feed_dict)

        history_loss_test.append(temp_loss)
        history_acc_test.append(temp_acc)

      if is_saving:
        model_path = saver.save(sess, model_path)
        print ('done saving at ', model_path)

    fig = plt.figure(figsize = (10, 3))
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.plot(range(n_iter), history_loss_train, \
                            'b-', label = 'Train')
    ax1.plot(range(n_iter), history_loss_test, \
                            'r--', label = 'Test')
    ax1.set_title('Loss')
    ax1.legend(loc = 'upper right')

    ax2 = fig.add_subplot(1, 2, 2)
    ax2.plot(range(n_iter), history_acc_train, \
                            'b-', label = 'Train')
    ax2.plot(range(n_iter), history_acc_test, \
                            'r--', label = 'Test')
    ax2.set_title('Accuracy')
    ax2.legend(loc = 'lower right')

    plt.show()
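
A hypothetical invocation with the settings from the results section (mnist as loaded above; n_iter and show_step are illustrative values not given in the article):

model = GatedCNN()
model.fit(mnist.train.images, mnist.train.labels, \
          mnist.test.images, mnist.test.labels, \
          filter_size = 5, n_in = 28 * 28, n_units = 32, \
          n_out = 10, n_layers = 7, learning_rate = 0.01, \
          n_iter = 1000, batch_size = 64, show_step = 100, \
          is_saving = False, model_path = '***/model')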
