Neural (Decision) Tree の実装に関するメモ

Posted on 2018-08-19 (last-updated date not recorded)

Reference: Nicholas Frosst and Geoffrey Hinton, "Distilling a Neural Network Into a Soft Decision Tree" (2017)



# Pull the train/test arrays out of the MNIST dataset object.
# NOTE(review): assumes `mnist` was loaded earlier (presumably with one-hot
# labels, since the model below uses 10-wide label placeholders) — TODO confirm.
images_train, labels_train = mnist.train.images, mnist.train.labels
images_test, labels_test = mnist.test.images, mnist.test.labels

## Sample Code

class DecisionTree():
    """Soft binary decision tree built with TensorFlow 1.x graph APIs.

    Nodes are numbered in heap order: node ``i`` has children ``2i + 1``
    (taken with probability ``p_i``) and ``2i + 2`` (probability ``1 - p_i``).
    Inner nodes are sigmoid gates on the input; every leaf is a softmax
    classifier on the input. Each node's variables live in
    ``tf.variable_scope('n_<i>')`` (e.g. ``n_0/w``, ``n_3/b``), so with the
    default depth the checkpoint layout is n_0..n_2 (inner) and n_3..n_6
    (leaves).
    """

    def __init__(self, depth=2, n_classes=10):
        """Configure the tree shape.

        Args:
            depth: number of gating levels (must be >= 1). The tree has
                ``2**depth - 1`` inner nodes and ``2**depth`` leaves; the
                default 2 gives 3 inner nodes and 4 leaves.
            n_classes: number of classes predicted by each leaf.
        """
        self.depth = depth
        self.n_classes = n_classes
        self.n_inner = 2 ** depth - 1
        self.n_leaves = 2 ** depth

    def weight_variable(self, name, shape):
        """Create/fetch a weight variable initialized from N(0, 0.01) truncated."""
        initializer = tf.truncated_normal_initializer(mean=0.0, stddev=0.01, dtype=tf.float32)
        return tf.get_variable(name, shape, initializer=initializer)

    def bias_variable(self, name, shape):
        """Create/fetch a bias variable initialized to zero."""
        initializer = tf.constant_initializer(value=0.0, dtype=tf.float32)
        return tf.get_variable(name, shape, initializer=initializer)

    def tree(self, x, y, n_in, batch_size, lam, reuse=False):
        """Build the tree graph; return its loss and per-leaf outputs.

        Args:
            x: float32 input tensor of shape [batch, n_in].
            y: float32 label tensor of shape [batch, n_classes], used as the
                target distribution in the leaf cross-entropy.
            n_in: number of input features.
            batch_size: kept for backward compatibility only. The root path
                probability is now sized from ``x``, so any batch size works.
            lam: weight of the inner-node balancing penalty.
            reuse: forwarded to every ``tf.variable_scope`` (``True`` or
                ``tf.AUTO_REUSE`` to share already-created variables).

        Returns:
            loss_total: scalar. The sum over leaves of the batch-mean
                path-probability-weighted cross-entropy, plus ``lam`` times
                the mean balancing penalty of the inner nodes.
            pathprob_l: [batch, n_leaves, 1], probability of reaching each leaf.
            prob_l: [batch, n_leaves, n_classes], class distribution per leaf.
        """
        # pathprob[i] = P(reach node i | x), shape [batch, 1]. The root is
        # always reached; ones_like keeps the shape tied to the fed batch.
        pathprob = [tf.ones_like(x[:, :1])]
        loss_i = []
        for i in range(self.n_inner):
            with tf.variable_scope('n_%d' % i, reuse=reuse):
                w = self.weight_variable('w', [n_in, 1])
                b = self.bias_variable('b', [1])
                p = tf.nn.sigmoid(tf.matmul(x, w) + b)
                # len(pathprob) == 2i + 1 here, so these land at 2i+1 and 2i+2.
                pathprob.append(pathprob[i] * p)
                pathprob.append(pathprob[i] * (1.0 - p))
                # alpha: path-weighted average gate value at this node; the
                # penalty is smallest when alpha == 0.5 (balanced split).
                alpha = tf.reduce_mean(pathprob[i] * p) / (tf.reduce_mean(pathprob[i]) + 1e-10)
                loss_i.append(-0.5 * tf.log(tf.clip_by_value(alpha, 1e-10, 1.0))
                              - 0.5 * tf.log(tf.clip_by_value(1.0 - alpha, 1e-10, 1.0)))

        loss_l = []
        prob_l = []
        for j in range(self.n_leaves):
            node = self.n_inner + j
            with tf.variable_scope('n_%d' % node, reuse=reuse):
                w = self.weight_variable('w', [n_in, self.n_classes])
                b = self.bias_variable('b', [self.n_classes])
                p = tf.nn.softmax(tf.matmul(x, w) + b, axis=1)
                prob_l.append(p)
                # Per-sample cross-entropy of this leaf, shape [batch].
                ce = -tf.reduce_sum(y * tf.log(tf.clip_by_value(p, 1e-10, 1.0)), axis=1)
                # Weight each sample by its OWN path probability, then average
                # over the batch (the original multiplied a batch mean by every
                # sample's path probability and summed over the batch).
                loss_l.append(tf.reduce_mean(pathprob[node][:, 0] * ce))

        loss_total = tf.reduce_sum(loss_l) + lam * tf.reduce_mean(loss_i)
        pathprob_l = tf.stack(pathprob[self.n_inner:], axis=1)  # [batch, n_leaves, 1]
        prob_l = tf.stack(prob_l, axis=1)                      # [batch, n_leaves, n_classes]
        return loss_total, pathprob_l, prob_l

    def accuracy(self, pathprob, prob, t):
        """Mix the leaf distributions by path probability and score them.

        Args:
            pathprob: [batch, n_leaves, 1] leaf path probabilities.
            prob: [batch, n_leaves, n_classes] leaf class distributions.
            t: [batch, n_classes] target labels (argmax is the true class).

        Returns:
            (accuracy, y): scalar accuracy and the [batch, n_classes] mixture.
        """
        # Broadcasting over the class axis replaces tile; summing over the
        # leaf axis works for any leaf count and keeps a batch of size 1.
        y = tf.reduce_sum(pathprob * prob, axis=1)
        correct_preds = tf.equal(tf.argmax(y, axis=1), tf.argmax(t, axis=1))
        accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))
        return accuracy, y

    def training(self, loss, learning_rate):
        """Return an Adam op that minimizes ``loss``."""
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        return optimizer.minimize(loss)

    def fit(self, images_train, labels_train, images_test, labels_test,
            lam, learning_rate, n_iter, batch_size, show_step, is_saving, model_path):
        """Train on random mini-batches, optionally save, then plot curves.

        Each iteration:
        - takes one optimizer step on a training batch sampled with replacement;
        - records loss and accuracy on that batch and on a random test batch.
        Training metrics are printed every ``show_step`` iterations. If
        ``is_saving`` is true, the session is saved to ``model_path``.
        """
        n_in = np.shape(images_train)[1]
        x = tf.placeholder(shape=[None, n_in], dtype=tf.float32)
        y = tf.placeholder(shape=[None, self.n_classes], dtype=tf.float32)
        loss, pathprob_l, prob_l = self.tree(x, y, n_in, batch_size, lam, reuse=False)
        train_step = self.training(loss, learning_rate)
        acc, _ = self.accuracy(pathprob_l, prob_l, y)
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

        history_loss_train = []
        history_loss_test = []
        history_acc_train = []
        history_acc_test = []

        with tf.Session() as sess:
            # Without this the first run fails on uninitialized variables.
            sess.run(init)

            for i in range(n_iter):
                # Train
                rand_index = np.random.choice(len(images_train), size=batch_size)
                feed_dict = {x: images_train[rand_index], y: labels_train[rand_index]}
                sess.run(train_step, feed_dict=feed_dict)
                temp_loss, temp_acc = sess.run([loss, acc], feed_dict=feed_dict)
                history_loss_train.append(temp_loss)
                history_acc_train.append(temp_acc)

                if (i + 1) % show_step == 0:
                    print('-' * 15)
                    print('Iteration: ' + str(i + 1) + '  Loss: ' + str(temp_loss) +
                          '  Accuracy: ' + str(temp_acc))

                # Test
                rand_index = np.random.choice(len(images_test), size=batch_size)
                feed_dict = {x: images_test[rand_index], y: labels_test[rand_index]}
                temp_loss, temp_acc = sess.run([loss, acc], feed_dict=feed_dict)
                history_loss_test.append(temp_loss)
                history_acc_test.append(temp_acc)

            if is_saving:
                model_path = saver.save(sess, model_path)
                print('done saving at ', model_path)

        print('-' * 15)
        fig = plt.figure(figsize=(10, 3))
        ax1 = fig.add_subplot(1, 2, 1)
        ax1.plot(range(n_iter), history_loss_train, 'b-', label='Train')
        ax1.plot(range(n_iter), history_loss_test, 'r--', label='Test')
        ax1.legend(loc='upper right')

        ax2 = fig.add_subplot(1, 2, 2)
        ax2.plot(range(n_iter), history_acc_train, 'b-', label='Train')
        ax2.plot(range(n_iter), history_acc_test, 'r--', label='Test')
        ax2.set_ylim(0.0, 1.0)
        ax2.legend(loc='lower right')

    def predict(self, images, labels, batch_size, model_path):
        """Restore weights from ``model_path`` and evaluate on ``images``.

        Returns:
            [pathprob_l, prob_l, y_hat] as arrays of shape [n, n_leaves, 1],
            [n, n_leaves, n_classes] and [n, n_classes].
        """
        n_in = np.shape(images)[1]
        x = tf.placeholder(shape=[None, n_in], dtype=tf.float32)
        y = tf.placeholder(shape=[None, self.n_classes], dtype=tf.float32)
        # AUTO_REUSE shares the variables fit() built in this graph, or creates
        # them in a fresh graph so saver.restore can fill them.
        _, pathprob_l, prob_l = self.tree(x, y, n_in, batch_size, 1.0, reuse=tf.AUTO_REUSE)
        _, y_hat = self.accuracy(pathprob_l, prob_l, y)
        saver = tf.train.Saver()

        with tf.Session() as sess:
            saver.restore(sess, model_path)
            feed_dict = {x: images, y: labels}
            return sess.run([pathprob_l, prob_l, y_hat], feed_dict=feed_dict)


# Hyperparameters for the MNIST run.
lam = 10.0
learning_rate = 0.01
n_iter = 200
batch_size = 64
show_step = 100
model_path = 'datalab/model'
# predict() below restores from model_path, so fit() must save the model.
is_saving = True

# The tree instance used below was never created in the original script.
dt = DecisionTree()
dt.fit(images_train, labels_train, images_test, labels_test,
       lam, learning_rate, n_iter, batch_size, show_step, is_saving, model_path)

# Inspect the trained tree on 10 random training images (any index in range).
index = np.random.choice(len(images_train), 10)
images = images_train[index]
labels = labels_train[index]

# preds = [leaf path probs [10, n_leaves, 1], leaf distributions, mixture [10, n_classes]]
preds = dt.predict(images, labels, 10, model_path)

print('-' * 15)
print('Prediction: ')
print(np.argmax(preds[2], axis=1))
print('True: ')
print(np.argmax(labels, axis=1))
print('Leaf: ')
# Most probable leaf per image; no hard-coded leaf count.
print(np.argmax(preds[0][:, :, 0], axis=1))



