LoginSignup
0
0

More than 5 years have passed since last update.

Spatial Transformer Network の実装に関するメモ

Posted at

ポイント

  • Spatial Transformer Network を実装し、具体的な数値で確認。
  • 今後、パフォーマンス検証を実施予定。

リファレンス

1. Spatial Transformer Networks

image.png


image.png

             (参照論文より引用)

検証方法

  • Distorted MNIST を使用。
  • Spatial Transformer あり・なしで Loss と Accuracy を比較。

データ

Distorted MNIST 作成に関するメモ

検証結果

# Hyperparameters used for the runs reported below.
n_loc_fc = 20  # hidden units of the localization net's first layer
filter_size = 3  # side length of the square convolution kernels
n_filters_1 = 16  # output channels of conv_1
n_filters_2 = 16  # output channels of conv_2
n_fc = 1024  # units of the fc_1 layer
learning_rate = 0.001  # learning rate handed to the Adam optimizer
batch_size = 64  # examples drawn per training / evaluation batch

  1. Spatial Transformer なし
    image.png


    image.png

  2. Spatial Transformer あり
    image.png


    image.png

サンプルコード

class SpatialTransformer():
  def __init__(self):
    pass

  def weight_variable(self, name, shape):
    """Return a float32 variable initialised from a narrow truncated normal.

    Args:
      name: Variable name, resolved inside the current variable scope.
      shape: Shape of the variable.

    Returns:
      The variable obtained from `tf.get_variable`, initialised with
      mean 0.0 and standard deviation 0.01.
    """
    return tf.get_variable(
        name,
        shape,
        initializer=tf.truncated_normal_initializer(
            mean=0.0, stddev=0.01, dtype=tf.float32))

  def bias_variable(self, name, shape):
    """Return a float32 variable initialised to zeros.

    Args:
      name: Variable name, resolved inside the current variable scope.
      shape: Shape of the variable.

    Returns:
      The variable obtained from `tf.get_variable`.
    """
    return tf.get_variable(
        name,
        shape,
        initializer=tf.constant_initializer(value=0.0, dtype=tf.float32))

  def localization_net(self, x, n_units, keep_prob):
    """Regress six affine parameters from a batch of flattened images.

    The network is two fully connected tanh layers with dropout in between.
    The output bias starts at the flattened identity transform
    [1, 0, 0, 0, 1, 0].

    Args:
      x: Input batch; it is multiplied by a [28 * 28, n_units] weight matrix,
        so each row must hold 784 values.
      n_units: Width of the hidden layer.
      keep_prob: Keep probability for the dropout after the hidden layer.

    Returns:
      Tensor with six values per example (the flattened 2x3 affine matrix).
    """
    # Hidden layer: 784 -> n_units.
    hidden_w = self.weight_variable('w_1', [28 * 28, n_units])
    hidden_b = self.bias_variable('b_1', [n_units])
    hidden = tf.nn.tanh(tf.matmul(x, hidden_w) + hidden_b)
    hidden = tf.nn.dropout(hidden, keep_prob)

    # Output layer: n_units -> 6, bias initialised to the identity transform.
    out_w = self.weight_variable('w_2', [n_units, 6])
    identity = np.array([1, 0, 0, 0, 1, 0], dtype=np.float32)
    out_b = tf.get_variable(
        'b_2',
        shape=[6],
        initializer=tf.constant_initializer(value=identity, dtype=tf.float32))

    # NOTE(review): tanh bounds every parameter to (-1, 1), so the identity
    # bias yields tanh(1) ~= 0.76 on the diagonal at initialisation rather
    # than exactly 1 -- confirm this is intended.
    return tf.nn.tanh(tf.matmul(hidden, out_w) + out_b)

  def grid_generator(self, batch_size, hight, width, theta):
    """Map the regular output grid through the affine transform `theta`.

    Target coordinates are normalised to [-1, 1): pixel column i maps to
    -1 + 2 * i / width and pixel row j to -1 + 2 * j / hight. They are laid
    out row-major (x varies fastest) as homogeneous columns [x, y, 1].

    Args:
      batch_size: Number of examples; used as a static Python int.
      hight: Grid height in pixels.
      width: Grid width in pixels.
      theta: Affine parameters, reshaped here to [batch_size, 2, 3].

    Returns:
      float32 tensor of shape [batch_size, 2, hight * width]; row 0 holds the
      source x coordinates and row 1 the source y coordinates.
    """
    n_pixels = hight * width

    # tf.truediv keeps the division floating-point (float64 for int32 input)
    # on every Python version; a bare `/` on an int32 tensor floors under
    # Python 2 and would collapse the whole grid to -1.
    x_lin = -1 + tf.truediv(tf.range(width), width) * 2
    y_lin = -1 + tf.truediv(tf.range(hight), hight) * 2

    # x repeats once per row: [x0..xW-1, x0..xW-1, ...].
    x_t = tf.reshape(tf.tile(x_lin, [hight]), [1, 1, n_pixels])
    x_t = tf.cast(tf.tile(x_t, [batch_size, 1, 1]), dtype=tf.float32)

    # y holds each row value `width` times: [y0..y0, y1..y1, ...].
    y_t = tf.tile(tf.reshape(y_lin, [hight, 1]), [1, width])
    y_t = tf.reshape(y_t, [1, 1, n_pixels])
    y_t = tf.cast(tf.tile(y_t, [batch_size, 1, 1]), dtype=tf.float32)

    ones_t = tf.ones(shape=[batch_size, 1, n_pixels], dtype=tf.float32)

    # Homogeneous target coordinates, shape [batch_size, 3, n_pixels].
    grids_t = tf.concat([x_t, y_t, ones_t], axis=1)

    # Only the first two rows of the homogeneous result are needed, so the
    # 2x3 matrix is applied directly instead of padding it to 3x3 and
    # slicing the third row away again.
    theta = tf.reshape(theta, [batch_size, 2, 3])
    grids_s = tf.matmul(theta, grids_t)

    return grids_s

  def sampler(self, u, grids, batch_size, hight, width):
    """Bilinearly sample the input images at the given source coordinates.

    Normalised coordinates g are converted to pixel positions with
    (1 + g) * size / 2. The four neighbouring pixels are gathered and
    blended with bilinear weights. Neighbour indices falling outside the
    image are clamped to the border.

    Args:
      u: Input images; reshaped here to [-1, hight, width].
      grids: Source coordinates; `grids[:, 0]` is read as x and
        `grids[:, 1]` as y, each with hight * width entries per example.
      batch_size: Number of examples; must be a static Python int because
        the batch index is built with NumPy.
      hight: Image height in pixels.
      width: Image width in pixels.

    Returns:
      Tensor of shape [batch_size, hight * width] with the sampled images.
    """
    # The image size comes from the arguments rather than a fixed 28 * 28.
    n_pixels = hight * width
    img_in = tf.reshape(u, [-1, hight, width])

    # Normalised coordinates -> pixel coordinates.
    x = (1.0 + grids[:, 0]) * width / 2.0
    y = (1.0 + grids[:, 1]) * hight / 2.0

    x0 = tf.cast(tf.floor(x), tf.int32)
    x1 = x0 + 1
    y0 = tf.cast(tf.floor(y), tf.int32)
    y1 = y0 + 1

    x0_f = tf.cast(x0, tf.float32)
    x1_f = tf.cast(x1, tf.float32)
    y0_f = tf.cast(y0, tf.float32)
    y1_f = tf.cast(y1, tf.float32)

    # Bilinear weights, computed from the unclamped corner positions.
    w_ul = (x1_f - x) * (y1_f - y)
    w_ll = (x1_f - x) * (y - y0_f)
    w_ur = (x - x0_f) * (y1_f - y)
    w_lr = (x - x0_f) * (y - y0_f)

    # Clamp corner indices so the gather below never leaves the image.
    x0_img = tf.maximum(0, tf.minimum(x0, width - 1))
    x1_img = tf.maximum(0, tf.minimum(x1, width - 1))
    y0_img = tf.maximum(0, tf.minimum(y0, hight - 1))
    y1_img = tf.maximum(0, tf.minimum(y1, hight - 1))

    # Batch index of every sampled pixel: [0]*n_pixels, [1]*n_pixels, ...
    idx_batch = np.repeat(np.arange(batch_size), n_pixels).reshape(-1, 1)
    idx_batch = tf.convert_to_tensor(idx_batch, dtype=tf.int32)

    def gather_corner(rows, cols):
      """Fetch img_in[b, rows, cols] for every output pixel."""
      idx = tf.concat(
          [tf.expand_dims(rows, axis=2), tf.expand_dims(cols, axis=2)],
          axis=2)
      idx = tf.reshape(idx, [batch_size * n_pixels, 2])
      idx = tf.concat([idx_batch, idx], axis=1)
      return tf.reshape(tf.gather_nd(img_in, idx), [batch_size, n_pixels])

    img_ul = gather_corner(y0_img, x0_img)
    img_ll = gather_corner(y1_img, x0_img)
    img_ur = gather_corner(y0_img, x1_img)
    img_lr = gather_corner(y1_img, x1_img)

    v = w_ul * img_ul + w_ll * img_ll + w_ur * img_ur + w_lr * img_lr

    return v

  # With Spatial Transformer
  # With Spatial Transformer
  def inference(self, u, hight, width, n_loc_fc, filter_size, n_filters_1, n_filters_2, n_fc, keep_prob, batch_size):
    """Build the classifier with a spatial transformer in front of it.

    Pipeline: localization net -> grid generator -> bilinear sampler ->
    two stride-2 convolutions -> fully connected layer with dropout ->
    10-way softmax.

    Args:
      u: Flattened input images.
      hight: Image height in pixels.
      width: Image width in pixels.
      n_loc_fc: Hidden units of the localization net.
      filter_size: Side length of the square convolution kernels.
      n_filters_1: Output channels of the first convolution.
      n_filters_2: Output channels of the second convolution.
      n_fc: Units of the hidden fully connected layer.
      keep_prob: Dropout keep probability.
      batch_size: Static batch size required by the grid generator/sampler.

    Returns:
      Softmax probabilities over the 10 classes.
    """
    with tf.variable_scope('loc'):
      theta = self.localization_net(u, n_loc_fc, keep_prob)
      grids_s = self.grid_generator(batch_size, hight, width, theta)
      v = self.sampler(u, grids_s, batch_size, hight, width)
      # Use the requested image size instead of a fixed 28 x 28.
      v_reshaped = tf.reshape(v, [-1, hight, width, 1])

    with tf.variable_scope('conv_1'):
      w = self.weight_variable('w', [filter_size, filter_size, 1, n_filters_1])
      b = self.bias_variable('b', [n_filters_1])

      # no max_pooling
      conv_1 = tf.nn.relu(tf.nn.conv2d(v_reshaped, w, strides = [1, 2, 2, 1], padding = 'SAME') + b)

      # with max_pooling
      #conv_1 = tf.nn.relu(tf.nn.conv2d(v_reshaped, w, strides = [1, 1, 1, 1], padding = 'SAME') + b)
      #conv_1 = tf.nn.max_pool(conv_1, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = 'SAME')

    with tf.variable_scope('conv_2'):
      w = self.weight_variable('w', [filter_size, filter_size, n_filters_1, n_filters_2])
      b = self.bias_variable('b', [n_filters_2])

      # no max_pooling
      conv_2 = tf.nn.relu(tf.nn.conv2d(conv_1, w, strides = [1, 2, 2, 1], padding = 'SAME') + b)

      # with max_pooling
      #conv_2 = tf.nn.relu(tf.nn.conv2d(conv_1, w, strides = [1, 1, 1, 1], padding = 'SAME') + b)
      #conv_2 = tf.nn.max_pool(conv_2, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = 'SAME')

    # Two stride-2 'SAME' layers shrink each spatial side to ceil(size / 4),
    # i.e. 7 x 7 for 28 x 28 inputs.
    feat_h = (hight + 3) // 4
    feat_w = (width + 3) // 4
    n_flat = feat_h * feat_w * n_filters_2

    conv_2_flat = tf.reshape(conv_2, [-1, n_flat])

    with tf.variable_scope('fc_1'):
      w = self.weight_variable('w', [n_flat, n_fc])
      b = self.bias_variable('b', [n_fc])

      fc_1 = tf.nn.relu(tf.matmul(conv_2_flat, w) + b)

    fc_1_dropout = tf.nn.dropout(fc_1, keep_prob)

    with tf.variable_scope('fc_2'):
      w = self.weight_variable('w', [n_fc, 10])
      b = self.bias_variable('b', [10])

      fc_2 = tf.matmul(fc_1_dropout, w) + b

    output = tf.nn.softmax(fc_2, axis = 1)

    return output

  # Without Spatial Transformer
  def inference_2(self, x, filter_size, n_filters_1, n_filters_2, n_fc, keep_prob):
    """Build the baseline classifier without a spatial transformer.

    Pipeline: two stride-2 convolutions -> fully connected layer with
    dropout -> 10-way softmax.

    Args:
      x: Flattened input images; reshaped here to [-1, 28, 28, 1].
      filter_size: Side length of the square convolution kernels.
      n_filters_1: Output channels of the first convolution.
      n_filters_2: Output channels of the second convolution.
      n_fc: Units of the hidden fully connected layer.
      keep_prob: Dropout keep probability.

    Returns:
      Softmax probabilities over the 10 classes.
    """
    net = tf.reshape(x, [-1, 28, 28, 1])

    # Both conv layers downsample with stride 2 and use no max pooling.
    # (The alternative considered was a stride-1 convolution followed by a
    # 2x2 max_pool with stride 2 and 'SAME' padding.)
    channels_in = 1
    for scope_name, channels_out in (('conv_1', n_filters_1), ('conv_2', n_filters_2)):
      with tf.variable_scope(scope_name):
        kernel = self.weight_variable(
            'w', [filter_size, filter_size, channels_in, channels_out])
        offset = self.bias_variable('b', [channels_out])
        net = tf.nn.relu(
            tf.nn.conv2d(net, kernel, strides=[1, 2, 2, 1], padding='SAME') + offset)
      channels_in = channels_out

    # 28 -> 14 -> 7 after the two stride-2 layers.
    flat_dim = 7 * 7 * n_filters_2
    net = tf.reshape(net, [-1, flat_dim])

    with tf.variable_scope('fc_1'):
      fc_w = self.weight_variable('w', [flat_dim, n_fc])
      fc_b = self.bias_variable('b', [n_fc])
      hidden = tf.nn.relu(tf.matmul(net, fc_w) + fc_b)

    hidden = tf.nn.dropout(hidden, keep_prob)

    with tf.variable_scope('fc_2'):
      out_w = self.weight_variable('w', [n_fc, 10])
      out_b = self.bias_variable('b', [10])
      logits = tf.matmul(hidden, out_w) + out_b

    return tf.nn.softmax(logits, axis=1)

  def loss(self, y, t):
    """Return the batch-mean cross-entropy between `y` and targets `t`.

    Args:
      y: Predicted probabilities; clipped to [1e-10, 1.0] before the log.
      t: Target weights with the same shape as `y`.

    Returns:
      Scalar tensor: mean over the batch of -sum(t * log(y)) along axis 1.
    """
    log_probs = tf.log(tf.clip_by_value(y, 1e-10, 1.0))
    per_example = tf.reduce_sum(t * log_probs, axis=1)
    return -tf.reduce_mean(per_example)

  def accuracy(self, y, t):
    """Return the fraction of rows whose argmax in `y` matches that in `t`.

    Args:
      y: Predicted scores or probabilities per class.
      t: Target scores per class, same shape as `y`.

    Returns:
      Scalar float32 tensor in [0, 1].
    """
    predicted = tf.argmax(y, axis=1)
    expected = tf.argmax(t, axis=1)
    hits = tf.cast(tf.equal(predicted, expected), tf.float32)
    return tf.reduce_mean(hits)

  def training(self, loss, learning_rate):
    """Return an Adam update op that minimises `loss`.

    Args:
      loss: Scalar tensor to minimise.
      learning_rate: Learning rate passed to the Adam optimizer.

    Returns:
      The op produced by `AdamOptimizer.minimize`.
    """
    return tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

  def training_clipped(self, loss, learning_rate, clip_norm):
    """Return an Adam update op with per-variable gradient-norm clipping.

    Args:
      loss: Scalar tensor to minimise.
      learning_rate: Learning rate passed to the Adam optimizer.
      clip_norm: Maximum L2 norm allowed for each individual gradient.

    Returns:
      The op produced by `AdamOptimizer.apply_gradients`.
    """
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

    # compute_gradients yields (None, var) for variables that do not affect
    # the loss; tf.clip_by_norm cannot take None, so those pairs are skipped.
    clipped_grads_and_vars = [
        (tf.clip_by_norm(grad, clip_norm=clip_norm), var)
        for grad, var in optimizer.compute_gradients(loss)
        if grad is not None
    ]

    return optimizer.apply_gradients(clipped_grads_and_vars)

  def fit(self, images_train, labels_train, images_test, labels_Test,
          n_loc_fc, filter_size, n_filters_1, n_filters_2, n_fc,
          learning_rate, n_iter, batch_size, show_step, is_saving, model_path):
    """Train the classifier and plot per-iteration loss and accuracy.

    The default graph is reset and the model rebuilt. Each iteration runs
    one Adam step on a random training batch, then evaluates a random test
    batch. Histories for both are plotted with matplotlib at the end.

    Args:
      images_train: Training images, indexable by an integer array; each row
        holds 28 * 28 values.
      labels_train: Training labels with 10 values per row.
      images_test: Test images, same layout as `images_train`.
      labels_Test: Test labels, same layout as `labels_train`.
      n_loc_fc: Hidden units of the localization net (only used when the
        spatial-transformer model is enabled below).
      filter_size: Side length of the square convolution kernels.
      n_filters_1: Output channels of the first convolution.
      n_filters_2: Output channels of the second convolution.
      n_fc: Units of the hidden fully connected layer.
      learning_rate: Learning rate for the Adam optimizer.
      n_iter: Number of training iterations.
      batch_size: Examples per training and per evaluation batch.
      show_step: Progress is printed every `show_step` iterations.
      is_saving: When true, the trained variables are saved to `model_path`.
      model_path: Checkpoint path used when `is_saving` is true.

    Returns:
      None.
    """
    # The signature spells this parameter `labels_Test` (kept so existing
    # keyword callers still work), but the loop below reads `labels_test`.
    # Without this binding the test evaluation raised NameError.
    labels_test = labels_Test

    tf.reset_default_graph()

    x = tf.placeholder(shape = [None, 28 * 28], dtype = tf.float32)
    t = tf.placeholder(shape = [None, 10], dtype = tf.float32)
    keep_prob = tf.placeholder(shape = (), dtype = tf.float32)

    # With Spatial Transformer
    #y = self.inference(x, 28, 28, n_loc_fc, filter_size, n_filters_1, n_filters_2, n_fc, keep_prob, batch_size)

    # Without Spatial Transformer
    y = self.inference_2(x, filter_size, n_filters_1, n_filters_2, n_fc, keep_prob)

    loss = self.loss(y, t)

    # Without Gradient Clipping
    train_step = self.training(loss, learning_rate)
    # With Gradient Clipping
    #train_step = self.training_clipped(loss, learning_rate, 0.1)

    acc = self.accuracy(y, t)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    history_loss_train = []
    history_acc_train = []
    history_loss_test = []
    history_acc_test = []

    with tf.Session() as sess:

      sess.run(init)

      for i in range(n_iter):
        # Train
        rand_index = np.random.choice(len(images_train), size = batch_size)
        feed_dict = {x: images_train[rand_index],
                     t: labels_train[rand_index],
                     keep_prob: 0.7}

        sess.run(train_step, feed_dict = feed_dict)

        # Fetch both metrics in one run so they come from the same forward
        # pass (and the same dropout mask).
        temp_loss, temp_acc = sess.run([loss, acc], feed_dict = feed_dict)

        history_loss_train.append(temp_loss)
        history_acc_train.append(temp_acc)

        if (i + 1) % show_step == 0:
          print ('--------------------')
          print ('Iteration: ' + str(i + 1) + '  Loss: ' + str(temp_loss) + \
                '  Accuracy: ' + str(temp_acc))

        # Test (dropout disabled)
        rand_index = np.random.choice(len(images_test), size = batch_size)
        feed_dict = {x: images_test[rand_index],
                     t: labels_test[rand_index],
                     keep_prob: 1.0}

        temp_loss, temp_acc = sess.run([loss, acc], feed_dict = feed_dict)

        history_loss_test.append(temp_loss)
        history_acc_test.append(temp_acc)

      if is_saving:
        model_path = saver.save(sess, model_path)
        print ('done saving at ', model_path)

    fig = plt.figure(figsize = (10, 3))
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.plot(range(n_iter), history_loss_train, 'b-', label = 'Train')
    ax1.plot(range(n_iter), history_loss_test, 'r--', label = 'Test')
    ax1.set_title('Loss')
    ax1.legend(loc = 'upper right')

    ax2 = fig.add_subplot(1, 2, 2)
    ax2.plot(range(n_iter), history_acc_train, 'b-', label = 'Train')
    ax2.plot(range(n_iter), history_acc_test, 'r--', label = 'Test')
    ax2.set_title('Accuracy')
    ax2.legend(loc = 'lower right')

    plt.show()

  def check_results(self, u, batch_size, hight, width, n_loc_fc, model_path):
    """Run the spatial transformer alone on `u` using saved weights.

    The 'loc' scope is opened with reuse=True, so its variables must already
    exist in the default graph. Variable values are restored from
    `model_path` before evaluation. Dropout is disabled (keep probability
    1.0).

    Args:
      u: Flattened input images; no feed dict is used, so this must be a
        concrete value such as an array or constant tensor.
      batch_size: Number of examples in `u`, as a static Python int.
      hight: Image height in pixels.
      width: Image width in pixels.
      n_loc_fc: Hidden units of the localization net.
      model_path: Checkpoint path to restore from.

    Returns:
      Tuple (theta, v): the evaluated affine parameters and the transformed
      images.
    """
    with tf.variable_scope('loc', reuse = True):
      theta = self.localization_net(u, n_loc_fc, 1.0)
      grids_s = self.grid_generator(batch_size, hight, width, theta)
      v = self.sampler(u, grids_s, batch_size, hight, width)

    saver = tf.train.Saver()

    with tf.Session() as sess:

      saver.restore(sess, model_path)

      # One run evaluates both outputs from a single forward pass.
      ret_theta, ret_v = sess.run([theta, v])

    return ret_theta, ret_v
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0