## Reference
Nicholas Frosst and Geoffrey Hinton, "Distilling a Neural Network Into a Soft Decision Tree", arXiv:1711.09784, 2017.
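For orientation before the code: the sample below builds a depth-2 soft tree with three inner (routing) nodes `n_0`-`n_2` and four leaves `n_3`-`n_6`. Each inner node gates its input with a single sigmoid unit, and each leaf's path probability is the product of the gate values along its path; the paper additionally scales the gate pre-activation by an inverse temperature, which this sample omits. In notation matching the `tree()` method below (writing $p_i$ for $p_i(\mathbf{x})$):

```latex
% Gate of inner node i and path probabilities of the four leaves,
% as computed in tree() below (\sigma is the logistic sigmoid).
p_i(\mathbf{x}) = \sigma(\mathbf{x}\,\mathbf{w}_i + b_i), \qquad i \in \{0, 1, 2\}
\\[4pt]
P^{n_3}(\mathbf{x}) = p_0\, p_1, \quad
P^{n_4}(\mathbf{x}) = p_0\,(1 - p_1), \quad
P^{n_5}(\mathbf{x}) = (1 - p_0)\, p_2, \quad
P^{n_6}(\mathbf{x}) = (1 - p_0)\,(1 - p_2)
```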
## Data
images_train = mnist.train.images
labels_train = mnist.train.labels
images_test = mnist.test.images
labels_test = mnist.test.labels
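The `mnist` object above is not defined in this snippet; a minimal sketch of how it could be loaded with the TensorFlow 1.x tutorial helper, together with the imports the sample code below relies on (the `'MNIST_data'` directory name is an assumption):

```python
# Assumed setup for the snippets in this document (TensorFlow 1.x).
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# one_hot = True yields 10-dimensional label vectors, as the leaf softmax expects.
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)
```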
## Sample Code
class DecisionTree():
    def __init__(self):
        pass

    def weight_variable(self, name, shape):
        initializer = tf.truncated_normal_initializer(mean = 0.0, stddev = 0.01, dtype = tf.float32)
        return tf.get_variable(name, shape, initializer = initializer)

    def bias_variable(self, name, shape):
        initializer = tf.constant_initializer(value = 0.0, dtype = tf.float32)
        return tf.get_variable(name, shape, initializer = initializer)
    def tree(self, x, y, n_in, batch_size, lam, reuse = False):
        # Depth-2 soft decision tree: inner (routing) nodes n_0 - n_2, leaf nodes n_3 - n_6.
        # The graph is built for a fixed batch size because the root path probability
        # is a tensor of ones of shape [batch_size, 1].
        pathprob = []       # probability of reaching each node
        pathprob_l = []     # leaf path probabilities (set below)
        prob_i = []         # gate value of each inner node
        prob_l = []         # class distribution of each leaf
        loss_i = []         # balance penalty of each inner node
        loss_l = []         # weighted cross-entropy of each leaf
        pathprob.append(tf.ones(shape = [batch_size, 1], dtype = tf.float32))
        # Inner node n_0 (root)
        with tf.variable_scope('n_0', reuse = reuse):
            w = self.weight_variable('w', [n_in, 1])
            b = self.bias_variable('b', [1])
            p = tf.nn.sigmoid(tf.matmul(x, w) + b)
            prob_i.append(p)
            pathprob.append(p)
            pathprob.append(1.0 - p)
            # Balance penalty: alpha is the path-probability-weighted mean gate value;
            # the penalty is minimized when alpha = 0.5, i.e. a balanced split.
            alpha = tf.reduce_mean(pathprob[0] * prob_i[0]) / (tf.reduce_mean(pathprob[0]) + 1e-10)
            loss = -0.5 * tf.reduce_mean(tf.log(tf.clip_by_value(alpha, 1e-10, 1.0))) \
                   -0.5 * tf.reduce_mean(tf.log(tf.clip_by_value(1.0 - alpha, 1e-10, 1.0)))
            loss_i.append(loss)
        # Inner node n_1 (left child of n_0)
        with tf.variable_scope('n_1', reuse = reuse):
            w = self.weight_variable('w', [n_in, 1])
            b = self.bias_variable('b', [1])
            p = tf.nn.sigmoid(tf.matmul(x, w) + b)
            prob_i.append(p)
            pathprob.append(pathprob[1] * p)
            pathprob.append(pathprob[1] * (1.0 - p))
            alpha = tf.reduce_mean(pathprob[1] * prob_i[1]) / (tf.reduce_mean(pathprob[1]) + 1e-10)
            loss = -0.5 * tf.reduce_mean(tf.log(tf.clip_by_value(alpha, 1e-10, 1.0))) \
                   -0.5 * tf.reduce_mean(tf.log(tf.clip_by_value(1.0 - alpha, 1e-10, 1.0)))
            loss_i.append(loss)
        # Inner node n_2 (right child of n_0)
        with tf.variable_scope('n_2', reuse = reuse):
            w = self.weight_variable('w', [n_in, 1])
            b = self.bias_variable('b', [1])
            p = tf.nn.sigmoid(tf.matmul(x, w) + b)
            prob_i.append(p)
            pathprob.append(pathprob[2] * p)
            pathprob.append(pathprob[2] * (1.0 - p))
            alpha = tf.reduce_mean(pathprob[2] * prob_i[2]) / (tf.reduce_mean(pathprob[2]) + 1e-10)
            loss = -0.5 * tf.reduce_mean(tf.log(tf.clip_by_value(alpha, 1e-10, 1.0))) \
                   -0.5 * tf.reduce_mean(tf.log(tf.clip_by_value(1.0 - alpha, 1e-10, 1.0)))
            loss_i.append(loss)
        # Leaf nodes n_3 - n_6: each leaf outputs a softmax over the 10 classes.
        # Its cross-entropy is weighted, per example, by the probability of reaching
        # that leaf (pathprob[3] .. pathprob[6]) and then averaged over the batch.
        with tf.variable_scope('n_3', reuse = reuse):
            w = self.weight_variable('w', [n_in, 10])
            b = self.bias_variable('b', [10])
            p = tf.nn.softmax(tf.matmul(x, w) + b, axis = 1)
            loss = - tf.reduce_mean(pathprob[3] * tf.reduce_sum(y * tf.log(tf.clip_by_value(p, 1e-10, 1.0)), axis = 1, keepdims = True))
            prob_l.append(p)
            loss_l.append(loss)
        with tf.variable_scope('n_4', reuse = reuse):
            w = self.weight_variable('w', [n_in, 10])
            b = self.bias_variable('b', [10])
            p = tf.nn.softmax(tf.matmul(x, w) + b, axis = 1)
            loss = - tf.reduce_mean(pathprob[4] * tf.reduce_sum(y * tf.log(tf.clip_by_value(p, 1e-10, 1.0)), axis = 1, keepdims = True))
            prob_l.append(p)
            loss_l.append(loss)
        with tf.variable_scope('n_5', reuse = reuse):
            w = self.weight_variable('w', [n_in, 10])
            b = self.bias_variable('b', [10])
            p = tf.nn.softmax(tf.matmul(x, w) + b, axis = 1)
            loss = - tf.reduce_mean(pathprob[5] * tf.reduce_sum(y * tf.log(tf.clip_by_value(p, 1e-10, 1.0)), axis = 1, keepdims = True))
            prob_l.append(p)
            loss_l.append(loss)
        with tf.variable_scope('n_6', reuse = reuse):
            w = self.weight_variable('w', [n_in, 10])
            b = self.bias_variable('b', [10])
            p = tf.nn.softmax(tf.matmul(x, w) + b, axis = 1)
            loss = - tf.reduce_mean(pathprob[6] * tf.reduce_sum(y * tf.log(tf.clip_by_value(p, 1e-10, 1.0)), axis = 1, keepdims = True))
            prob_l.append(p)
            loss_l.append(loss)
        # Total loss: sum of the leaf losses plus the balance penalty,
        # averaged over the three inner nodes and scaled by lam.
        loss_total = tf.reduce_sum(loss_l) + lam * tf.reduce_mean(loss_i)
        # Stack the per-node tensors into [batch, node, ...] layout and keep only
        # the four leaf path probabilities and leaf distributions.
        pathprob = tf.transpose(pathprob, [1, 0, 2])
        pathprob_l = pathprob[:, 3:, :]
        prob_l = tf.transpose(prob_l, [1, 0, 2])
        return loss_total, pathprob_l, prob_l
    def accuracy(self, pathprob, prob, t):
        # Mixture prediction: weight each leaf distribution by its path
        # probability and sum over the four leaves.
        pathprob = tf.tile(pathprob, [1, 1, 10])
        conditional = tf.multiply(pathprob, prob)
        split_1, split_2, split_3, split_4 = tf.split(conditional, 4, axis = 1)
        y = split_1 + split_2 + split_3 + split_4
        y = tf.squeeze(y, axis = 1)
        correct_preds = tf.equal(tf.argmax(y, axis = 1), tf.argmax(t, axis = 1))
        accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))
        return accuracy, y

    def training(self, loss, learning_rate):
        optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
        train_step = optimizer.minimize(loss)
        return train_step
    def fit(self, images_train, labels_train, images_test, labels_test, \
            lam, learning_rate, n_iter, batch_size, show_step, is_saving, model_path):
        tf.reset_default_graph()
        x = tf.placeholder(shape = [None, 28 * 28], dtype = tf.float32)
        y = tf.placeholder(shape = [None, 10], dtype = tf.float32)
        loss, pathprob_l, prob_l = self.tree(x, y, 28 * 28, batch_size, lam, reuse = False)
        train_step = self.training(loss, learning_rate)
        acc, _ = self.accuracy(pathprob_l, prob_l, y)
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(init)
            history_loss_train = []
            history_loss_test = []
            history_acc_train = []
            history_acc_test = []
            for i in range(n_iter):
                # Train on a random mini-batch
                rand_index = np.random.choice(len(images_train), size = batch_size)
                x_batch = images_train[rand_index]
                y_batch = labels_train[rand_index]
                feed_dict = {x: x_batch, y: y_batch}
                sess.run(train_step, feed_dict = feed_dict)
                temp_loss = sess.run(loss, feed_dict = feed_dict)
                temp_acc = sess.run(acc, feed_dict = feed_dict)
                history_loss_train.append(temp_loss)
                history_acc_train.append(temp_acc)
                if (i + 1) % show_step == 0:
                    print ('-' * 15)
                    print ('Iteration: ' + str(i + 1) + ' Loss: ' + str(temp_loss) + \
                           ' Accuracy: ' + str(temp_acc))
                # Evaluate on a random test mini-batch
                rand_index = np.random.choice(len(images_test), size = batch_size)
                x_batch = images_test[rand_index]
                y_batch = labels_test[rand_index]
                feed_dict = {x: x_batch, y: y_batch}
                temp_loss = sess.run(loss, feed_dict = feed_dict)
                temp_acc = sess.run(acc, feed_dict = feed_dict)
                history_loss_test.append(temp_loss)
                history_acc_test.append(temp_acc)
            if is_saving:
                model_path = saver.save(sess, model_path)
                print ('done saving at ', model_path)
        print ('-' * 15)
        # Plot the training curves
        fig = plt.figure(figsize = (10, 3))
        ax1 = fig.add_subplot(1, 2, 1)
        ax1.plot(range(n_iter), history_loss_train, 'b-', label = 'Train')
        ax1.plot(range(n_iter), history_loss_test, 'r--', label = 'Test')
        ax1.set_title('Loss')
        ax1.legend(loc = 'upper right')
        ax2 = fig.add_subplot(1, 2, 2)
        ax2.plot(range(n_iter), history_acc_train, 'b-', label = 'Train')
        ax2.plot(range(n_iter), history_acc_test, 'r--', label = 'Test')
        ax2.set_ylim(0.0, 1.0)
        ax2.set_title('Accuracy')
        ax2.legend(loc = 'lower right')
        plt.show()
    def predict(self, images, labels, batch_size, model_path):
        # Reuses the variables created by fit() in the current default graph,
        # so fit() (or an equivalent graph build) must have run first.
        x = tf.placeholder(shape = [None, 28 * 28], dtype = tf.float32)
        y = tf.placeholder(shape = [None, 10], dtype = tf.float32)
        _, pathprob_l, prob_l = self.tree(x, y, 28 * 28, batch_size, 1.0, reuse = True)
        _, y_hat = self.accuracy(pathprob_l, prob_l, y)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            saver.restore(sess, model_path)
            feed_dict = {x: images, y: labels}
            return sess.run([pathprob_l, prob_l, y_hat], feed_dict = feed_dict)
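As a summary of what `tree()` above minimizes (a restatement of the code, not necessarily the paper's exact formulation): writing the path probability of leaf l as P^l(x), the leaf's softmax output as Q^l(x), the one-hot target as y, and the path-probability-weighted mean gate value of inner node i as alpha_i, the total loss is

```latex
\mathcal{L} = \sum_{\ell \in \{n_3,\dots,n_6\}} \mathbb{E}_{\mathbf{x}}\!\left[ -P^{\ell}(\mathbf{x}) \sum_{k=1}^{10} y_k \log Q^{\ell}_k(\mathbf{x}) \right]
+ \frac{\lambda}{3} \sum_{i \in \{n_0, n_1, n_2\}} \left( -\tfrac{1}{2}\log\alpha_i - \tfrac{1}{2}\log(1 - \alpha_i) \right),
\qquad
\alpha_i = \frac{\mathbb{E}_{\mathbf{x}}\!\left[ P^{i}(\mathbf{x})\, p_i(\mathbf{x}) \right]}{\mathbb{E}_{\mathbf{x}}\!\left[ P^{i}(\mathbf{x}) \right]}
```

where the expectations are mini-batch averages and `lam` in the code plays the role of the regularization weight λ.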
## Parameters
lam = 10.0
learning_rate = 0.01
n_iter = 200
batch_size = 64
show_step = 100
is_saving = True          # fit() expects this flag; True is assumed here
model_path = 'datalab/model'
## Output
dt = DecisionTree()
dt.fit(images_train, labels_train, images_test, labels_test, \
       lam, learning_rate, n_iter, batch_size, show_step, is_saving, model_path)
# Predict on 10 randomly chosen training images; predict() builds the graph
# for a batch of exactly 10 examples here.
index = np.random.choice(10000, 10)
images = images_train[index]
labels = labels_train[index]
preds = dt.predict(images, labels, 10, model_path)
#print (np.shape(preds[0]))
#print (np.shape(preds[1]))
#print (np.shape(preds[2]))
print ('-' * 15)
print ('Prediction: ')
print (np.argmax(preds[2], axis = 1))
print ('True: ')
print (np.argmax(labels, axis = 1))
print ('Leaf: ')
# Index (0-3) of the leaf with the highest path probability for each image
print (np.argmax(np.reshape(preds[0], (-1, 4)), axis = 1))
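The learned routing filters of the inner nodes can also be read back from the saved checkpoint, which is one way to reproduce the kind of filter visualizations discussed in the paper. A minimal sketch, assuming `fit` was run with `is_saving = True` and the `model_path` defined above:

```python
# Read the inner-node weights n_0/w, n_1/w, n_2/w from the checkpoint and
# display each 784-dimensional filter as a 28 x 28 image.
reader = tf.train.NewCheckpointReader(model_path)
fig = plt.figure(figsize = (9, 3))
for i, name in enumerate(['n_0/w', 'n_1/w', 'n_2/w']):
    w = reader.get_tensor(name)              # shape [784, 1]
    ax = fig.add_subplot(1, 3, i + 1)
    ax.imshow(w.reshape(28, 28), cmap = 'gray')
    ax.set_title(name)
    ax.axis('off')
plt.show()
```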