ポイント
- Spatial Transformer Network を実装し、具体的な数値で確認。
- パフォーマンス検証は今後実施予定。
リファレンス
1. M. Jaderberg, K. Simonyan, A. Zisserman, K. Kavukcuoglu, "Spatial Transformer Networks", NIPS 2015 (arXiv:1506.02025)
(参照論文より引用)
検証方法
- Distorted MNIST を使用。
- Spatial Transformer の有無で Loss と Accuracy を比較。
データ
検証結果
# Hyperparameters shared by the experiments (passed to SpatialTransformer.fit).
n_loc_fc = 20                      # hidden units of the localization network
filter_size = 3                    # side length of the square convolution kernels
n_filters_1, n_filters_2 = 16, 16  # output channels of conv_1 / conv_2
n_fc = 1024                        # units of the fully connected hidden layer fc_1
learning_rate = 1e-3               # Adam step size
batch_size = 64                    # mini-batch size for training and each test evaluation
サンプルコード
class SpatialTransformer():
    """CNN classifier for flattened square images with an optional spatial transformer.

    Implements the Spatial Transformer Network (Jaderberg et al., 2015) with the
    TensorFlow 1.x graph API (``tf.variable_scope``, ``tf.get_variable``,
    ``tf.Session``):

    * a localization net predicts a 2x3 affine matrix ``theta`` per image,
    * a grid generator maps every output pixel through ``theta``,
    * a bilinear sampler resamples the input image at those source coordinates,

    and the (optionally transformed) image is classified by a small CNN into
    10 classes. ``fit`` trains either variant and plots the loss / accuracy
    history; ``check_results`` inspects a trained transformer.
    """

    def __init__(self):
        pass

    def weight_variable(self, name, shape):
        """Get a float32 variable initialised from a truncated N(0, 0.01^2)."""
        initializer = tf.truncated_normal_initializer(mean=0.0, stddev=0.01, dtype=tf.float32)
        return tf.get_variable(name, shape, initializer=initializer)

    def bias_variable(self, name, shape):
        """Get a float32 variable initialised to zeros."""
        initializer = tf.constant_initializer(value=0.0, dtype=tf.float32)
        return tf.get_variable(name, shape, initializer=initializer)

    def localization_net(self, x, n_units, keep_prob, n_inputs=28 * 28):
        """Predict the six affine parameters ``theta`` for each image.

        Args:
            x: flattened images, [batch, n_inputs].
            n_units: hidden units of the single tanh hidden layer.
            keep_prob: dropout keep probability (tensor or float).
            n_inputs: pixels per image; defaults to 28 * 28 (MNIST).

        Returns:
            Tensor [batch, 6]: theta in row-major order (a11, a12, tx, a21, a22, ty),
            squashed into (-1, 1) by tanh.
        """
        w_1 = self.weight_variable('w_1', [n_inputs, n_units])
        b_1 = self.bias_variable('b_1', [n_units])
        loc_1 = tf.nn.tanh(tf.matmul(x, w_1) + b_1)
        loc_1_dropout = tf.nn.dropout(loc_1, keep_prob)

        w_2 = self.weight_variable('w_2', [n_units, 6])
        # The bias starts at the identity transform [[1, 0, 0], [0, 1, 0]]; with the
        # tiny random w_2, training therefore starts close to "no transform".
        # NOTE(review): the tanh below maps that bias to tanh(1) ~= 0.76, so the
        # initial transform is a ~0.76x scaling (a zoom-in), not exactly identity.
        identity = np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32).flatten()
        init = tf.constant_initializer(value=identity, dtype=tf.float32)
        b_2 = tf.get_variable('b_2', shape=[6], initializer=init)
        return tf.nn.tanh(tf.matmul(loc_1_dropout, w_2) + b_2)

    def grid_generator(self, batch_size, hight, width, theta):
        """Map every target pixel through the affine transform ``theta``.

        Target coordinates are normalised as ``-1 + 2 * i / size`` (so they lie
        in [-1, 1) and ``sampler`` maps them back to exactly pixel ``i``).

        Args:
            batch_size: static batch size (theta is reshaped to [batch_size, 2, 3]).
            hight, width: output image size in pixels.
            theta: affine parameters, [batch_size, 6].

        Returns:
            Tensor [batch_size, 2, hight * width]: normalised source (x, y)
            coordinates for every target pixel, in row-major pixel order.
        """
        xs = tf.cast(tf.range(width), tf.float32) * (2.0 / width) - 1.0
        ys = tf.cast(tf.range(hight), tf.float32) * (2.0 / hight) - 1.0
        # meshgrid ('xy' indexing) yields [hight, width] grids with x varying
        # along columns, so flattening gives row-major pixel order.
        x_t, y_t = tf.meshgrid(xs, ys)
        x_t = tf.reshape(x_t, [-1])
        y_t = tf.reshape(y_t, [-1])
        # Homogeneous target coordinates [x; y; 1], shared by every image.
        grid_t = tf.stack([x_t, y_t, tf.ones_like(x_t)], axis=0)
        grids_t = tf.tile(tf.expand_dims(grid_t, 0), [batch_size, 1, 1])
        theta = tf.reshape(theta, [batch_size, 2, 3])
        return tf.matmul(theta, grids_t)

    def sampler(self, u, grids, batch_size, hight, width):
        """Bilinearly sample the input images at the given source coordinates.

        Coordinates outside the image are clamped to the border pixels.

        Args:
            u: flattened input images, reshapeable to [-1, hight, width].
            grids: normalised source coordinates, [batch_size, 2, hight * width].
            batch_size: static batch size.
            hight, width: image size in pixels.

        Returns:
            Tensor [batch_size, hight * width]: the transformed, flattened images.
        """
        img_in = tf.reshape(u, [-1, hight, width])
        # Undo the grid normalisation: -1 + 2 * i / size  ->  i.
        x = (1.0 + grids[:, 0]) * width / 2.0
        y = (1.0 + grids[:, 1]) * hight / 2.0

        x0 = tf.cast(tf.floor(x), tf.int32)
        y0 = tf.cast(tf.floor(y), tf.int32)
        x1 = x0 + 1
        y1 = y0 + 1

        # Bilinear weights use the unclamped corners so they stay in [0, 1]
        # and sum to 1; only the gather indices are clamped.
        x0_f = tf.cast(x0, tf.float32)
        x1_f = tf.cast(x1, tf.float32)
        y0_f = tf.cast(y0, tf.float32)
        y1_f = tf.cast(y1, tf.float32)
        w_ul = (x1_f - x) * (y1_f - y)
        w_ll = (x1_f - x) * (y - y0_f)
        w_ur = (x - x0_f) * (y1_f - y)
        w_lr = (x - x0_f) * (y - y0_f)

        batch_idx = tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, hight * width])

        def gather(yi, xi):
            # Read img_in[b, yi, xi] for every (b, pixel), clamping to the border.
            yi = tf.clip_by_value(yi, 0, hight - 1)
            xi = tf.clip_by_value(xi, 0, width - 1)
            return tf.gather_nd(img_in, tf.stack([batch_idx, yi, xi], axis=2))

        return (w_ul * gather(y0, x0) + w_ll * gather(y1, x0)
                + w_ur * gather(y0, x1) + w_lr * gather(y1, x1))

    def _classifier(self, x, hight, width, filter_size, n_filters_1, n_filters_2, n_fc, keep_prob):
        """Two stride-2 conv layers + two FC layers; returns softmax probabilities [batch, 10].

        Variable scopes ('conv_1', 'conv_2', 'fc_1', 'fc_2') are shared by both
        inference variants so their checkpoints use the same variable names.
        """
        x_image = tf.reshape(x, [-1, hight, width, 1])

        # Strided convolutions are used instead of max pooling; a stride-1 conv
        # followed by 2x2/2 max pooling would produce the same spatial size.
        with tf.variable_scope('conv_1'):
            w = self.weight_variable('w', [filter_size, filter_size, 1, n_filters_1])
            b = self.bias_variable('b', [n_filters_1])
            conv_1 = tf.nn.relu(tf.nn.conv2d(x_image, w, strides=[1, 2, 2, 1], padding='SAME') + b)

        with tf.variable_scope('conv_2'):
            w = self.weight_variable('w', [filter_size, filter_size, n_filters_1, n_filters_2])
            b = self.bias_variable('b', [n_filters_2])
            conv_2 = tf.nn.relu(tf.nn.conv2d(conv_1, w, strides=[1, 2, 2, 1], padding='SAME') + b)

        # Two 'SAME' stride-2 layers give ceil(size / 4) per side (7 for 28).
        n_flat = ((hight + 3) // 4) * ((width + 3) // 4) * n_filters_2
        conv_2_flat = tf.reshape(conv_2, [-1, n_flat])

        with tf.variable_scope('fc_1'):
            w = self.weight_variable('w', [n_flat, n_fc])
            b = self.bias_variable('b', [n_fc])
            fc_1 = tf.nn.relu(tf.matmul(conv_2_flat, w) + b)
            fc_1_dropout = tf.nn.dropout(fc_1, keep_prob)

        with tf.variable_scope('fc_2'):
            w = self.weight_variable('w', [n_fc, 10])
            b = self.bias_variable('b', [10])
            fc_2 = tf.matmul(fc_1_dropout, w) + b

        return tf.nn.softmax(fc_2, axis=1)

    # With Spatial Transformer
    def inference(self, u, hight, width, n_loc_fc, filter_size, n_filters_1, n_filters_2, n_fc, keep_prob, batch_size):
        """Classify ``u`` after warping it with a learned spatial transformer.

        ``batch_size`` must be the static number of rows in ``u``. Returns
        softmax probabilities [batch, 10].
        """
        with tf.variable_scope('loc'):
            theta = self.localization_net(u, n_loc_fc, keep_prob, n_inputs=hight * width)
        grids_s = self.grid_generator(batch_size, hight, width, theta)
        v = self.sampler(u, grids_s, batch_size, hight, width)
        return self._classifier(v, hight, width, filter_size, n_filters_1, n_filters_2, n_fc, keep_prob)

    # Without Spatial Transformer
    def inference_2(self, x, filter_size, n_filters_1, n_filters_2, n_fc, keep_prob):
        """Classify flattened 28x28 images ``x`` directly (baseline without transformer)."""
        return self._classifier(x, 28, 28, filter_size, n_filters_1, n_filters_2, n_fc, keep_prob)

    def loss(self, y, t):
        """Mean cross-entropy between predicted probabilities ``y`` and targets ``t``.

        ``y`` is clipped to [1e-10, 1] before the log to avoid log(0).
        """
        return -tf.reduce_mean(tf.reduce_sum(t * tf.log(tf.clip_by_value(y, 1e-10, 1.0)), axis=1))

    def accuracy(self, y, t):
        """Fraction of rows whose argmax of ``y`` equals the argmax of ``t``."""
        correct_preds = tf.equal(tf.argmax(y, axis=1), tf.argmax(t, axis=1))
        return tf.reduce_mean(tf.cast(correct_preds, tf.float32))

    def training(self, loss, learning_rate):
        """Return an Adam op that minimises ``loss``."""
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        return optimizer.minimize(loss)

    def training_clipped(self, loss, learning_rate, clip_norm):
        """Return an Adam op whose per-variable gradients are clipped to ``clip_norm``."""
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        grads_and_vars = optimizer.compute_gradients(loss)
        # compute_gradients yields None for variables the loss does not depend on;
        # tf.clip_by_norm cannot take None, so those pairs are skipped.
        clipped_grads_and_vars = [(tf.clip_by_norm(grad, clip_norm=clip_norm), var)
                                  for grad, var in grads_and_vars if grad is not None]
        return optimizer.apply_gradients(clipped_grads_and_vars)

    def _plot_history(self, n_iter, loss_train, loss_test, acc_train, acc_test):
        """Plot per-iteration train/test loss and accuracy side by side."""
        fig = plt.figure(figsize=(10, 3))
        ax1 = fig.add_subplot(1, 2, 1)
        ax1.plot(range(n_iter), loss_train, 'b-', label='Train')
        ax1.plot(range(n_iter), loss_test, 'r--', label='Test')
        ax1.set_title('Loss')
        ax1.legend(loc='upper right')
        ax2 = fig.add_subplot(1, 2, 2)
        ax2.plot(range(n_iter), acc_train, 'b-', label='Train')
        ax2.plot(range(n_iter), acc_test, 'r--', label='Test')
        ax2.set_title('Accuracy')
        ax2.legend(loc='lower right')
        plt.show()

    def fit(self, images_train, labels_train, images_test, labels_Test,
            n_loc_fc, filter_size, n_filters_1, n_filters_2, n_fc,
            learning_rate, n_iter, batch_size, show_step, is_saving, model_path,
            use_spatial_transformer=False, clip_norm=None):
        """Build a fresh graph, train for ``n_iter`` mini-batches and plot the history.

        Args:
            images_train, images_test: flattened 28*28 images, indexable by an
                integer index array.
            labels_train, labels_Test: matching targets with 10 columns (one-hot).
            n_loc_fc, filter_size, n_filters_1, n_filters_2, n_fc: model sizes.
            learning_rate: Adam step size.
            n_iter: number of training iterations.
            batch_size: mini-batch size, sampled with replacement; also the size
                of each random test batch.
            show_step: print the train metrics every ``show_step`` iterations.
            is_saving: save the trained variables to ``model_path`` if true.
            model_path: checkpoint path used when saving.
            use_spatial_transformer: use ``inference`` (with transformer) instead
                of the baseline ``inference_2``. Default False.
            clip_norm: if not None, clip each gradient to this norm.
        """
        tf.reset_default_graph()
        x = tf.placeholder(shape=[None, 28 * 28], dtype=tf.float32)
        t = tf.placeholder(shape=[None, 10], dtype=tf.float32)
        keep_prob = tf.placeholder(shape=(), dtype=tf.float32)

        if use_spatial_transformer:
            # The grid generator / sampler need a static batch size, so every
            # fed batch must contain exactly batch_size rows.
            y = self.inference(x, 28, 28, n_loc_fc, filter_size, n_filters_1, n_filters_2,
                               n_fc, keep_prob, batch_size)
        else:
            y = self.inference_2(x, filter_size, n_filters_1, n_filters_2, n_fc, keep_prob)
        loss = self.loss(y, t)
        if clip_norm is None:
            train_step = self.training(loss, learning_rate)
        else:
            train_step = self.training_clipped(loss, learning_rate, clip_norm)
        acc = self.accuracy(y, t)

        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

        history_loss_train = []
        history_acc_train = []
        history_loss_test = []
        history_acc_test = []

        with tf.Session() as sess:
            sess.run(init)
            for i in range(n_iter):
                # Train
                rand_index = np.random.choice(len(images_train), size=batch_size)
                feed_dict = {x: images_train[rand_index], t: labels_train[rand_index], keep_prob: 0.7}
                sess.run(train_step, feed_dict=feed_dict)
                # One run for both metrics so they come from the same forward
                # pass (and the same dropout mask).
                temp_loss, temp_acc = sess.run([loss, acc], feed_dict=feed_dict)
                history_loss_train.append(temp_loss)
                history_acc_train.append(temp_acc)

                if (i + 1) % show_step == 0:
                    print('--------------------')
                    print('Iteration: ' + str(i + 1) + ' Loss: ' + str(temp_loss) +
                          ' Accuracy: ' + str(temp_acc))

                # Test
                rand_index = np.random.choice(len(images_test), size=batch_size)
                feed_dict = {x: images_test[rand_index], t: labels_Test[rand_index], keep_prob: 1.0}
                temp_loss, temp_acc = sess.run([loss, acc], feed_dict=feed_dict)
                history_loss_test.append(temp_loss)
                history_acc_test.append(temp_acc)

            if is_saving:
                model_path = saver.save(sess, model_path)
                print('done saving at ', model_path)

        self._plot_history(n_iter, history_loss_train, history_loss_test,
                           history_acc_train, history_acc_test)

    def check_results(self, u, batch_size, hight, width, n_loc_fc, model_path):
        """Restore a trained transformer and return ``(theta, transformed images)`` for ``u``.

        Reuses the existing 'loc' variables of the current default graph, so a
        graph containing them (e.g. built by ``fit`` with
        ``use_spatial_transformer=True``) must already exist. ``u`` is evaluated
        without a feed_dict, so it must not depend on placeholders.

        Returns:
            (ret_theta, ret_v): numpy arrays of shape [batch_size, 6] and
            [batch_size, hight * width].
        """
        with tf.variable_scope('loc', reuse=True):
            theta = self.localization_net(u, n_loc_fc, 1.0, n_inputs=hight * width)
        grids_s = self.grid_generator(batch_size, hight, width, theta)
        v = self.sampler(u, grids_s, batch_size, hight, width)

        saver = tf.train.Saver()
        with tf.Session() as sess:
            saver.restore(sess, model_path)
            ret_theta, ret_v = sess.run([theta, v])
        return ret_theta, ret_v