
Notes on Implementing a Spatial Transformer Network


Points

  • Implemented a Spatial Transformer Network and checked its behavior with concrete numbers.
  • A performance evaluation will be carried out later.

References

1. Spatial Transformer Networks

(Figures quoted from the reference paper.)
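As background for the sample code below: the paper defines two operations that the `grid_generator` and `sampler` methods implement, namely the affine grid transform (equation (1)) and bilinear sampling (equation (5)):

```math
\begin{pmatrix} x_i^s \\ y_i^s \end{pmatrix}
= A_\theta \begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}
= \begin{bmatrix} \theta_{11} & \theta_{12} & \theta_{13} \\ \theta_{21} & \theta_{22} & \theta_{23} \end{bmatrix}
\begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}
```

```math
V_i = \sum_{n=1}^{H} \sum_{m=1}^{W} U_{nm}\,\max(0,\, 1 - |x_i^s - m|)\,\max(0,\, 1 - |y_i^s - n|)
```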

Evaluation Method

  • Use the Distorted MNIST dataset.
  • Compare loss and accuracy with and without the Spatial Transformer.

Data

Notes on creating the Distorted MNIST dataset
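
The linked note describes the actual data generation. As a rough illustration only, here is a minimal NumPy sketch that assumes the distortion is a random translation; the distortions used in the linked note may differ (e.g. rotation or scaling):

```python
import numpy as np

def make_distorted_mnist(images, max_shift = 6, seed = 0):
  # Hypothetical sketch: shift each 28x28 digit by a random offset with
  # zero padding, keeping the 28x28 size expected by the model below.
  # The distortions used in the linked note may differ.
  rng = np.random.RandomState(seed)
  imgs = images.reshape(-1, 28, 28)
  out = np.zeros_like(imgs)
  for i, img in enumerate(imgs):
    dy, dx = rng.randint(-max_shift, max_shift + 1, size = 2)
    src = img[max(0, -dy):28 - max(0, dy), max(0, -dx):28 - max(0, dx)]
    out[i, max(0, dy):max(0, dy) + src.shape[0],
           max(0, dx):max(0, dx) + src.shape[1]] = src
  return out.reshape(-1, 28 * 28)
```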

Results

n_loc_fc = 20
filter_size = 3
n_filters_1 = 16
n_filters_2 = 16
n_fc = 1024
learning_rate = 0.001
batch_size = 64

  1. Without Spatial Transformer

![image.png](https://qiita-image-store.s3.amazonaws.com/0/254604/c1b90fcf-b5f5-548f-888e-ee2ec2562d48.png)

  2. With Spatial Transformer

![image.png](https://qiita-image-store.s3.amazonaws.com/0/254604/25b30e30-8996-6f4e-58a1-cccd80b212dc.png)

Sample Code

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt


class SpatialTransformer():
  def __init__(self):
    pass
  
  def weight_variable(self, name, shape):
    initializer = tf.truncated_normal_initializer(mean = 0.0, stddev = 0.01, dtype = tf.float32)
    return tf.get_variable(name, shape, initializer = initializer)
  
  def bias_variable(self, name, shape):
    initializer = tf.constant_initializer(value = 0.0, dtype = tf.float32)
    return tf.get_variable(name, shape, initializer = initializer)
  
  def localization_net(self, x, n_units, keep_prob):
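    # Localisation network: a small fully connected net (tanh + dropout)
    # that regresses the 6 affine parameters theta from the flattened
    # 28x28 input; the output-layer bias is initialised with the
    # identity-transform parameters [[1, 0, 0], [0, 1, 0]].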
    w_1 = self.weight_variable('w_1', [28 * 28, n_units])
    b_1 = self.bias_variable('b_1', [n_units])

    loc_1 = tf.nn.tanh(tf.matmul(x, w_1) + b_1)
    loc_1_dropout = tf.nn.dropout(loc_1, keep_prob)

    w_2 = self.weight_variable('w_2', [n_units, 6])
    value = np.array([[1, 0, 0], [0, 1, 0]]).astype(np.float32)
    value = value.flatten()
    init = tf.constant_initializer(value = value, dtype = tf.float32)
    b_2 = tf.get_variable('b_2', shape = [6], initializer = init)

    loc_2 = tf.nn.tanh(tf.matmul(loc_1_dropout, w_2) + b_2)

    return loc_2
        
  def grid_generator(self, batch_size, hight, width, theta):
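    # Grid generator: build the regular target grid (x_t, y_t) in
    # normalised coordinates [-1, 1), append a row of ones and apply the
    # predicted affine matrix theta, giving the source coordinates
    # (x_s, y_s) for every output pixel.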
    x_t = -1 + tf.range(width) / width * 2
    x_t = tf.reshape(tf.concat([x_t] * hight, axis = 0), [1, hight * width])
    x_t = tf.expand_dims(x_t, axis = 0)
    x_t = tf.tile(x_t, [batch_size, 1, 1])
    x_t = tf.cast(x_t, dtype = tf.float32)
    
    y_t = -1 + tf.range(hight) / hight * 2
    y_t = tf.tile(tf.reshape(y_t, [hight, 1]), [1, width])
    y_t = tf.reshape(y_t, [1, hight * width])
    y_t = tf.expand_dims(y_t, axis = 0)
    y_t = tf.tile(y_t, [batch_size, 1, 1])
    y_t = tf.cast(y_t, dtype = tf.float32)

    ones_t = tf.ones(shape = [batch_size, 1, hight * width], dtype = tf.float32)

    grids_t = tf.concat([x_t, y_t, ones_t], axis = 1)
    
    theta = tf.reshape(theta, [batch_size, 2, 3])
    last_row = tf.constant(value = [[[0.0, 0.0, 1.0]]], dtype = tf.float32)
    last_row = tf.tile(last_row, [batch_size, 1, 1])
    A = tf.concat([theta, last_row], axis = 1)

    grids_s = tf.matmul(A, grids_t)
    grids_s = grids_s[:, :2, :]  # keep only the x_s and y_s rows

    return grids_s
  
  def sampler(self, u, grids, batch_size, hight, width):
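    # Bilinear sampler: map the source coordinates back to pixel space,
    # gather the four neighbouring pixels of the input image U and blend
    # them with bilinear weights to obtain the output image V.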
    img_in = tf.reshape(u, [-1, hight, width])
    
    x = (1.0 + grids[:, 0]) * width / 2.0
    y = (1.0 + grids[:, 1]) * hight / 2.0

    x0 = tf.cast(tf.floor(x), tf.int32) 
    x1 = x0 + 1
    y0 = tf.cast(tf.floor(y), tf.int32) 
    y1 = y0 + 1

    w_ul = (tf.cast(x1, tf.float32) - x) * (tf.cast(y1, tf.float32) - y)
    w_ll = (tf.cast(x1, tf.float32) - x) * (y - tf.cast(y0, tf.float32))
    w_ur = (x - tf.cast(x0, tf.float32)) * (tf.cast(y1, tf.float32) - y)
    w_lr = (x - tf.cast(x0, tf.float32)) * (y - tf.cast(y0, tf.float32))

    x0_img = tf.maximum(0, tf.minimum(x0, width - 1))
    x1_img = tf.maximum(0, tf.minimum(x1, width - 1))
    y0_img = tf.maximum(0, tf.minimum(y0, hight - 1))
    y1_img = tf.maximum(0, tf.minimum(y1, hight - 1))
    
    x0_img = tf.expand_dims(x0_img, axis = 2)
    x1_img = tf.expand_dims(x1_img, axis = 2)
    y0_img = tf.expand_dims(y0_img, axis = 2)
    y1_img = tf.expand_dims(y1_img, axis = 2)

    idx_ul = tf.concat([y0_img, x0_img], axis = 2)
    idx_ll = tf.concat([y1_img, x0_img], axis = 2)
    idx_ur = tf.concat([y0_img, x1_img], axis = 2)
    idx_lr = tf.concat([y1_img, x1_img], axis = 2)

    idx_ul = tf.reshape(idx_ul, [batch_size * hight * width, 2])
    idx_ll = tf.reshape(idx_ll, [batch_size * hight * width, 2])
    idx_ur = tf.reshape(idx_ur, [batch_size * hight * width, 2])
    idx_lr = tf.reshape(idx_lr, [batch_size * hight * width, 2])
    
    idx_batch = np.reshape(np.arange(batch_size), [batch_size, 1])
    idx_batch = np.repeat(idx_batch, hight * width, axis = 0)
    idx_batch = tf.convert_to_tensor(idx_batch, dtype = tf.int32)

    idx_ul = tf.concat([idx_batch, idx_ul], axis = 1)
    idx_ll = tf.concat([idx_batch, idx_ll], axis = 1)
    idx_ur = tf.concat([idx_batch, idx_ur], axis = 1)
    idx_lr = tf.concat([idx_batch, idx_lr], axis = 1)
    
    img_ul = tf.gather_nd(img_in, idx_ul)
    img_ll = tf.gather_nd(img_in, idx_ll)
    img_ur = tf.gather_nd(img_in, idx_ur)
    img_lr = tf.gather_nd(img_in, idx_lr)
    
    img_ul = tf.reshape(img_ul, [batch_size, hight * width])
    img_ll = tf.reshape(img_ll, [batch_size, hight * width])
    img_ur = tf.reshape(img_ur, [batch_size, hight * width])
    img_lr = tf.reshape(img_lr, [batch_size, hight * width])
       
    v = w_ul * img_ul + w_ll * img_ll + w_ur * img_ur + w_lr * img_lr
    
    return v
    
  # With Spatial Transformer
  def inference(self, u, hight, width, n_loc_fc, filter_size, n_filters_1, n_filters_2, n_fc, keep_prob, batch_size):
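    # Spatial Transformer (localisation net -> grid generator -> sampler)
    # followed by two strided convolutions and two fully connected layers.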
    
    with tf.variable_scope('loc'):
      theta = self.localization_net(u, n_loc_fc, keep_prob)    
      grids_s = self.grid_generator(batch_size, hight, width, theta)
      v = self.sampler(u, grids_s, batch_size, hight, width)
      v_reshaped = tf.reshape(v, [-1, 28, 28, 1])
    
    with tf.variable_scope('conv_1'):
      w = self.weight_variable('w', [filter_size, filter_size, 1, n_filters_1])
      b = self.bias_variable('b', [n_filters_1])
      
      # no max_pooling
      conv_1 = tf.nn.relu(tf.nn.conv2d(v_reshaped, w, strides = [1, 2, 2, 1], padding = 'SAME') + b)
      
      # with max_pooling
      #conv_1 = tf.nn.relu(tf.nn.conv2d(x, w, strides = [1, 1, 1, 1], padding = 'SAME') + b)
      #conv_1 = tf.nn.max_pool(conv_1, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = 'SAME')
      
    with tf.variable_scope('conv_2'):
      w = self.weight_variable('w', [filter_size, filter_size, n_filters_1, n_filters_2])
      b = self.bias_variable('b', [n_filters_2])
      
      # no max_pooling
      conv_2 = tf.nn.relu(tf.nn.conv2d(conv_1, w, strides = [1, 2, 2, 1], padding = 'SAME') + b)

      # with max_pooling
      #conv_2 = tf.nn.relu(tf.nn.conv2d(conv_1, w, strides = [1, 1, 1, 1], padding = 'SAME') + b)
      #conv_2 = tf.nn.max_pool(conv_2, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = 'SAME')

    conv_2_flat = tf.reshape(conv_2, [-1, 7 * 7 * n_filters_2])
    
    with tf.variable_scope('fc_1'):
      w = self.weight_variable('w', [7 * 7 * n_filters_2, n_fc])
      b = self.bias_variable('b', [n_fc])
      
      fc_1 = tf.nn.relu(tf.matmul(conv_2_flat, w) + b)
      
    fc_1_dropout = tf.nn.dropout(fc_1, keep_prob)
      
    with tf.variable_scope('fc_2'):
      w = self.weight_variable('w', [n_fc, 10])
      b = self.bias_variable('b', [10])
      
      fc_2 = tf.matmul(fc_1_dropout, w) + b
      
    output = tf.nn.softmax(fc_2, axis = 1)
      
    return output

  # Without Spatial Transformer
  def inference_2(self, x, filter_size, n_filters_1, n_filters_2, n_fc, keep_prob):
    x_reshaped = tf.reshape(x, [-1, 28, 28, 1])
    
    with tf.variable_scope('conv_1'):
      w = self.weight_variable('w', [filter_size, filter_size, 1, n_filters_1])
      b = self.bias_variable('b', [n_filters_1])
      
      # no max_pooling
      conv_1 = tf.nn.relu(tf.nn.conv2d(x_reshaped, w, strides = [1, 2, 2, 1], padding = 'SAME') + b)
      
      # with max_pooling
      #conv_1 = tf.nn.relu(tf.nn.conv2d(x, w, strides = [1, 1, 1, 1], padding = 'SAME') + b)
      #conv_1 = tf.nn.max_pool(conv_1, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = 'SAME')
      
    with tf.variable_scope('conv_2'):
      w = self.weight_variable('w', [filter_size, filter_size, n_filters_1, n_filters_2])
      b = self.bias_variable('b', [n_filters_2])
      
      # no max_pooling
      conv_2 = tf.nn.relu(tf.nn.conv2d(conv_1, w, strides = [1, 2, 2, 1], padding = 'SAME') + b)

      # with max_pooling
      #conv_2 = tf.nn.relu(tf.nn.conv2d(conv_1, w, strides = [1, 1, 1, 1], padding = 'SAME') + b)
      #conv_2 = tf.nn.max_pool(conv_2, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = 'SAME')

    conv_2_flat = tf.reshape(conv_2, [-1, 7 * 7 * n_filters_2])
    
    with tf.variable_scope('fc_1'):
      w = self.weight_variable('w', [7 * 7 * n_filters_2, n_fc])
      b = self.bias_variable('b', [n_fc])
      
      fc_1 = tf.nn.relu(tf.matmul(conv_2_flat, w) + b)
      
    fc_1_dropout = tf.nn.dropout(fc_1, keep_prob)
      
    with tf.variable_scope('fc_2'):
      w = self.weight_variable('w', [n_fc, 10])
      b = self.bias_variable('b', [10])
      
      fc_2 = tf.matmul(fc_1_dropout, w) + b
      
    output = tf.nn.softmax(fc_2, axis = 1)
      
    return output
          
  def loss(self, y, t):
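    # Softmax cross-entropy computed explicitly; clipping avoids log(0).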
    cross_entropy = - tf.reduce_mean(tf.reduce_sum(t * tf.log(tf.clip_by_value(y, 1e-10, 1.0)), axis = 1))
    return cross_entropy
  
  def accuracy(self, y, t):
    correct_preds = tf.equal(tf.argmax(y, axis = 1), tf.argmax(t, axis = 1))
    accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))
    return accuracy
  
  def training(self, loss, learning_rate):
    optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
    train_step = optimizer.minimize(loss)
    return train_step
  
  def training_clipped(self, loss, learning_rate, clip_norm):
    optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)

    grads_and_vars = optimizer.compute_gradients(loss)
    clipped_grads_and_vars = [(tf.clip_by_norm(grad, clip_norm = clip_norm), \
                             var) for grad, var in grads_and_vars]
    train_step = optimizer.apply_gradients(clipped_grads_and_vars)
    
    return train_step

  def fit(self, images_train, labels_train, images_test, labels_test, \
          n_loc_fc, filter_size, n_filters_1, n_filters_2, n_fc, \
          learning_rate, n_iter, batch_size, show_step, is_saving, model_path):
        
    tf.reset_default_graph()
    
    x = tf.placeholder(shape = [None, 28 * 28], dtype = tf.float32)
    t = tf.placeholder(shape = [None, 10], dtype = tf.float32)
    keep_prob = tf.placeholder(shape = (), dtype = tf.float32)
    
    # With Spatial Transformer
    #y = self.inference(x, 28, 28, n_loc_fc, filter_size, n_filters_1, n_filters_2, n_fc, keep_prob, batch_size)
    
    # Without Spatial Transformer
    y = self.inference_2(x, filter_size, n_filters_1, n_filters_2, n_fc, keep_prob)

    loss = self.loss(y, t)
    
    # Without Gradient Clipping
    train_step = self.training(loss, learning_rate)
    # With Gradient Clipping
    #train_step = self.training_clipped(loss, learning_rate, 0.1)

    acc = self.accuracy(y, t)
    
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
      
      sess.run(init)
      
      history_loss_train = []
      history_acc_train = []
      history_loss_test = []
      history_acc_test = []
      
      for i in range(n_iter):
        # Train
        rand_index = np.random.choice(len(images_train), size = batch_size)
        x_batch = images_train[rand_index]
        y_batch = labels_train[rand_index]

        feed_dict = {x: x_batch, t: y_batch, keep_prob: 0.7}

        sess.run(train_step, feed_dict = feed_dict)

        temp_loss = sess.run(loss, feed_dict = feed_dict)
        temp_acc = sess.run(acc, feed_dict = feed_dict)

        history_loss_train.append(temp_loss)
        history_acc_train.append(temp_acc)

        if (i + 1) % show_step == 0:
          print ('--------------------')
          print ('Iteration: ' + str(i + 1) + '  Loss: ' + str(temp_loss) + \
                '  Accuracy: ' + str(temp_acc))
                   
        # Test
        rand_index = np.random.choice(len(images_test), size = batch_size)
        x_batch = images_test[rand_index]
        y_batch = labels_test[rand_index]

        feed_dict = {x: x_batch, t: y_batch, keep_prob: 1.0}

        temp_loss = sess.run(loss, feed_dict = feed_dict)
        temp_acc = sess.run(acc, feed_dict = feed_dict)

        history_loss_test.append(temp_loss)
        history_acc_test.append(temp_acc)

      if is_saving:
        model_path = saver.save(sess, model_path)
        print ('done saving at ', model_path)

    fig = plt.figure(figsize = (10, 3))
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.plot(range(n_iter), history_loss_train, 'b-', label = 'Train')
    ax1.plot(range(n_iter), history_loss_test, 'r--', label = 'Test')
    ax1.set_title('Loss')
    ax1.legend(loc = 'upper right')

    ax2 = fig.add_subplot(1, 2, 2)
    ax2.plot(range(n_iter), history_acc_train, 'b-', label = 'Train')
    ax2.plot(range(n_iter), history_acc_test, 'r--', label = 'Test')
    ax2.set_title('Accuracy')
    ax2.legend(loc = 'lower right')

    plt.show()
    
  def check_results(self, u, batch_size, hight, width, n_loc_fc, model_path):    
    
    with tf.variable_scope('loc', reuse = True):
      theta = self.localization_net(u, n_loc_fc, 1.0)    
      grids_s = self.grid_generator(batch_size, hight, width, theta)
      v = self.sampler(u, grids_s, batch_size, hight, width)
      v_reshaped = tf.reshape(v, [-1, 28, 28, 1])
    
    saver = tf.train.Saver()

    with tf.Session() as sess:
      
      saver.restore(sess, model_path)
      
      ret_theta = sess.run(theta)
      ret_v = sess.run(v)
      
    return ret_theta, ret_v
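
For reference, a minimal usage sketch with the hyperparameters listed in the results section above. The data arrays (images_train, labels_train, images_test, labels_test) as well as n_iter, show_step and model_path are assumptions for illustration; the original article does not show the training call.

```python
# Hypothetical usage sketch; the data arrays, n_iter, show_step and
# model_path are assumptions, the other values follow the results section.
model = SpatialTransformer()
model.fit(images_train, labels_train, images_test, labels_test,
          n_loc_fc = 20, filter_size = 3, n_filters_1 = 16, n_filters_2 = 16,
          n_fc = 1024, learning_rate = 0.001, n_iter = 3000, batch_size = 64,
          show_step = 500, is_saving = False, model_path = './stn_model')
```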