
Notes on implementing a Neural (Decision) Tree

Posted at 2018-08-19

##Reference
Nicholas Frosst, Geoffrey Hinton, "Distilling a Neural Network Into a Soft Decision Tree" (arXiv:1711.09784, 2017)

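The code below implements a depth-2 soft decision tree: inner nodes n_0–n_2 route each input with a sigmoid gate $p_i(x) = \sigma(x w_i + b_i)$, and leaves n_3–n_6 hold softmax class distributions $Q^l$. As a rough summary of the loss built in tree() (my notation, not the paper's):

$$
L = -\frac{1}{B}\sum_{x}\sum_{l\in\text{leaves}} P^{l}(x)\sum_{k} y_{k}\log Q^{l}_{k}(x)
\;+\;\lambda\,\frac{1}{|I|}\sum_{i\in I}\left[-\tfrac{1}{2}\log\alpha_{i}-\tfrac{1}{2}\log(1-\alpha_{i})\right]
$$

where $P^{l}(x)$ is the product of gate probabilities along the path to leaf $l$, and $\alpha_i$ is the path-probability-weighted batch mean of $p_i$, so the second term pushes each inner node toward a balanced split, as in the paper.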

##Data

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot = True)
images_train = mnist.train.images
labels_train = mnist.train.labels
images_test = mnist.test.images
labels_test = mnist.test.labels
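
With one_hot = True the images come flattened to 784-dim vectors and the labels as 10-dim one-hot vectors:

print (images_train.shape)   # (55000, 784)
print (labels_train.shape)   # (55000, 10)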

##Sample Code

class DecisionTree():
  
  def __init__(self):
    pass
  
  def weight_variable(self, name, shape):
    initializer = tf.truncated_normal_initializer(mean = 0.0, stddev = 0.01, dtype = tf.float32)
    return tf.get_variable(name, shape, initializer = initializer)

  def bias_variable(self, name, shape):
    initializer = tf.constant_initializer(value = 0.0, dtype = tf.float32)
    return tf.get_variable(name, shape, initializer = initializer)
  
  def tree(self, x, y, n_in, batch_size, lam, reuse = False):
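    # Builds a depth-2 soft tree: inner nodes n_0..n_2 (sigmoid gates),
    # leaves n_3..n_6 (softmax class distributions). Returns the total loss,
    # each leaf's path probability, and each leaf's class distribution.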

    pathprob = []
    pathprob_l = []
    prob_i = []
    prob_l = []
    loss_i = []
    loss_l = []
    
    # every example reaches the root with path probability 1
    pathprob.append(tf.ones(shape = [batch_size, 1], dtype = tf.float32))
    
    with tf.variable_scope('n_0', reuse = reuse):
      w = self.weight_variable('w', [n_in, 1])
      b = self.bias_variable('b', [1])
      
      p = tf.nn.sigmoid(tf.matmul(x, w) + b)
      prob_i.append(p)
      pathprob.append(p)
      pathprob.append(1.0 - p)
      # alpha: path-probability-weighted batch mean of p; the penalty is the
      # cross entropy with (0.5, 0.5), pushing the node toward a balanced split
      alpha = tf.reduce_mean(pathprob[0] * prob_i[0]) / (tf.reduce_mean(pathprob[0]) + 1e-10)
      loss = -0.5 * tf.log(tf.clip_by_value(alpha, 1e-10, 1.0)) \
             -0.5 * tf.log(tf.clip_by_value(1.0 - alpha, 1e-10, 1.0))
      loss_i.append(loss)
      
    with tf.variable_scope('n_1', reuse = reuse):
      w = self.weight_variable('w', [n_in, 1])
      b = self.bias_variable('b', [1])
      
      p = tf.nn.sigmoid(tf.matmul(x, w) + b)
      prob_i.append(p)
      pathprob.append(pathprob[1] * p)
      pathprob.append(pathprob[1] * (1.0 - p))
      alpha = tf.reduce_mean(pathprob[1] * prob_i[1]) / (tf.reduce_mean(pathprob[1]) + 1e-10)
      loss = -0.5 * tf.log(tf.clip_by_value(alpha, 1e-10, 1.0)) \
             -0.5 * tf.log(tf.clip_by_value(1.0 - alpha, 1e-10, 1.0))
      loss_i.append(loss)
      
    with tf.variable_scope('n_2', reuse = reuse):
      w = self.weight_variable('w', [n_in, 1])
      b = self.bias_variable('b', [1])
      
      p = tf.nn.sigmoid(tf.matmul(x, w) + b)
      prob_i.append(p)
      pathprob.append(pathprob[2] * p)
      pathprob.append(pathprob[2] * (1.0 - p))
      alpha = tf.reduce_mean(pathprob[2] * prob_i[2]) / (tf.reduce_mean(pathprob[2]) + 1e-10)
      loss = -0.5 * tf.log(tf.clip_by_value(alpha, 1e-10, 1.0)) \
             -0.5 * tf.log(tf.clip_by_value(1.0 - alpha, 1e-10, 1.0))
      loss_i.append(loss)

    # leaf  
    with tf.variable_scope('n_3', reuse = reuse):
      w = self.weight_variable('w', [n_in, 10])
      b = self.bias_variable('b', [10])
      
      p = tf.nn.softmax(tf.matmul(x, w) + b, axis = 1)
      # per-example cross entropy, weighted by the probability of reaching this leaf
      loss = - tf.reduce_mean(pathprob[3] * tf.reduce_sum(y * tf.log(tf.clip_by_value(p, 1e-10, 1.0)), axis = 1, keepdims = True))
      prob_l.append(p)
      loss_l.append(loss)

    with tf.variable_scope('n_4', reuse = reuse):
      w = self.weight_variable('w', [n_in, 10])
      b = self.bias_variable('b', [10])

      p = tf.nn.softmax(tf.matmul(x, w) + b, axis = 1)
      loss = - tf.reduce_mean(pathprob[4] * tf.reduce_sum(y * tf.log(tf.clip_by_value(p, 1e-10, 1.0)), axis = 1, keepdims = True))
      prob_l.append(p)
      loss_l.append(loss)

    with tf.variable_scope('n_5', reuse = reuse):
      w = self.weight_variable('w', [n_in, 10])
      b = self.bias_variable('b', [10])
      
      p = tf.nn.softmax(tf.matmul(x, w) + b, axis = 1)
      loss = - tf.reduce_mean(pathprob[5] * tf.reduce_sum(y * tf.log(tf.clip_by_value(p, 1e-10, 1.0)), axis = 1, keepdims = True))
      prob_l.append(p)
      loss_l.append(loss)

    with tf.variable_scope('n_6', reuse = reuse):
      w = self.weight_variable('w', [n_in, 10])
      b = self.bias_variable('b', [10])

      p = tf.nn.softmax(tf.matmul(x, w) + b, axis = 1)
      loss = - tf.reduce_mean(pathprob[6] * tf.reduce_sum(y * tf.log(tf.clip_by_value(p, 1e-10, 1.0)), axis = 1, keepdims = True))
      prob_l.append(p)
      loss_l.append(loss)
      
    # weighted leaf cross entropies + lam * average balance penalty
    loss_total = tf.reduce_sum(loss_l) + lam * tf.reduce_mean(loss_i)
    
    # stack the per-node lists into tensors:
    # pathprob_l: [batch, 4, 1] (leaves only), prob_l: [batch, 4, 10]
    pathprob = tf.transpose(pathprob, [1, 0, 2])
    pathprob_l = pathprob[:, 3:, :]
    prob_l = tf.transpose(prob_l, [1, 0, 2])
    
    return loss_total, pathprob_l, prob_l
  
  def accuracy(self, pathprob, prob, t):
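    # prediction = mixture of the 4 leaf distributions, weighted by their path probabilities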
    pathprob = tf.tile(pathprob, [1, 1, 10])
    conditional = tf.multiply(pathprob, prob)
    split_1, split_2, split_3, split_4 = tf.split(conditional, 4, axis= 1)
    y = split_1 + split_2 + split_3 + split_4
    y = tf.squeeze(y)
    correct_preds = tf.equal(tf.argmax(y, axis = 1), tf.argmax(t, axis = 1))
    accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))
    return accuracy, y

  def training(self, loss, learning_rate):
    optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
    train_step = optimizer.minimize(loss)
    return train_step
  
  def fit(self, images_train, labels_train, images_test, labels_test, \
          lam, learning_rate, n_iter, batch_size, show_step, is_saving, model_path):
    
    tf.reset_default_graph()
    
    x = tf.placeholder(shape = [None, 28 * 28], dtype = tf.float32)
    y = tf.placeholder(shape = [None, 10], dtype = tf.float32)
    
    loss, pathprob_l, prob_l = self.tree(x, y, 28*28, batch_size, lam, reuse = False)
    train_step = self.training(loss, learning_rate)
    acc, _ = self.accuracy(pathprob_l, prob_l, y)
    
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:

      sess.run(init)
      
      history_loss_train = []
      history_loss_test = []
      history_acc_train = []
      history_acc_test = []

      for i in range(n_iter):
        # Train
        rand_index = np.random.choice(len(images_train), size = batch_size)
        x_batch = images_train[rand_index]
        y_batch = labels_train[rand_index]

        feed_dict = {x: x_batch, y: y_batch}

        sess.run(train_step, feed_dict = feed_dict)

        temp_loss = sess.run(loss, feed_dict = feed_dict)
        temp_acc = sess.run(acc, feed_dict = feed_dict)

        history_loss_train.append(temp_loss)
        history_acc_train.append(temp_acc)

        if (i + 1) % show_step == 0:
          print ('-' * 15)
          print ('Iteration: ' + str(i + 1) + '  Loss: ' + str(temp_loss) + \
                '  Accuracy: ' + str(temp_acc))

        # Test
        rand_index = np.random.choice(len(images_test), size = batch_size)
        x_batch = images_test[rand_index]
        y_batch = labels_test[rand_index]

        feed_dict = {x: x_batch, y: y_batch}

        temp_loss = sess.run(loss, feed_dict = feed_dict)
        temp_acc = sess.run(acc, feed_dict = feed_dict)

        history_loss_test.append(temp_loss)
        history_acc_test.append(temp_acc)

      if is_saving:
        model_path = saver.save(sess, model_path)
        print ('done saving at ', model_path)

      print ('-'* 15)    
      fig = plt.figure(figsize = (10, 3))
      ax1 = fig.add_subplot(1, 2, 1)
      ax1.plot(range(n_iter), history_loss_train, 'b-', label = 'Train')
      ax1.plot(range(n_iter), history_loss_test, 'r--', label = 'Test')
      ax1.set_title('Loss')
      ax1.legend(loc = 'upper right')

      ax2 = fig.add_subplot(1, 2, 2)
      ax2.plot(range(n_iter), history_acc_train, 'b-', label = 'Train')
      ax2.plot(range(n_iter), history_acc_test, 'r--', label = 'Test')
      ax2.set_ylim(0.0, 1.0)
      ax2.set_title('Accuracy')
      ax2.legend(loc = 'lower right')

      plt.show()
  
  def predict(self, images, labels, batch_size, model_path):
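    # NOTE: builds the tree with reuse = True, so this must run in the same
    # process (same default graph) in which fit() created the variables.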
        
    x = tf.placeholder(shape = [None, 28 * 28], dtype = tf.float32)
    y = tf.placeholder(shape = [None, 10], dtype = tf.float32)
    
    _, pathprob_l, prob_l = self.tree(x, y, 28*28, batch_size, 1.0, reuse = True)
    _, y_hat = self.accuracy(pathprob_l, prob_l, y)
    
    saver = tf.train.Saver()

    with tf.Session() as sess:

      saver.restore(sess, model_path)
      
      feed_dict = {x: images, y: labels}
      
      return sess.run([pathprob_l, prob_l, y_hat], feed_dict = feed_dict)

##Parameters

dt = DecisionTree()

lam = 10.0            # weight of the balance regularizer
learning_rate = 0.01
n_iter = 200
batch_size = 64
show_step = 100
is_saving = True
model_path = 'datalab/model'

##Output

dt.fit(images_train, labels_train, images_test, labels_test, \
       lam, learning_rate, n_iter, batch_size, show_step, is_saving, model_path)

(Figure: training and test loss/accuracy curves over the 200 iterations.)

index = np.random.choice(10000, 10)  # 10 random training samples
images = images_train[index]
labels = labels_train[index]

preds = dt.predict(images, labels, 10, model_path) 

#print (np.shape(preds[0]))
#print (np.shape(preds[1]))
#print (np.shape(preds[2]))

print ('-' * 15)
print ('Prediction: ')
print (np.argmax(preds[2],axis = 1))
print ('True: ')
print (np.argmax(labels, axis = 1))
print ('Leaf: ')
print (np.argmax(np.reshape(preds[0], (-1, 4)), axis = 1))
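
To see what each inner node has learned, the routing weights can be reloaded and viewed as 28x28 images, as the paper does in its visualizations. A minimal sketch, assuming it runs in the same graph and that model_path holds the checkpoint saved by fit():

with tf.Session() as sess:
  tf.train.Saver().restore(sess, model_path)
  fig = plt.figure(figsize = (9, 3))
  for i, scope in enumerate(['n_0', 'n_1', 'n_2']):
    with tf.variable_scope(scope, reuse = True):
      w = sess.run(tf.get_variable('w'))   # shape [784, 1]
    ax = fig.add_subplot(1, 3, i + 1)
    ax.imshow(w.reshape(28, 28), cmap = 'gray')
    ax.set_title(scope)
    ax.axis('off')
  plt.show()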

