1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Adversarial Discriminative Domain Adaptation (ADDA) の実装に関するメモ

Posted at

## Reference
E. Tzeng, J. Hoffman, K. Saenko and T. Darrell, "Adversarial Discriminative Domain Adaptation", CVPR 2017.

image.png

## Data

# Unpack the MNIST splits into flat image arrays and their label arrays.
# NOTE(review): `mnist` is not defined in this snippet; presumably it is the
# object returned by TensorFlow's `input_data.read_data_sets(..., one_hot=True)`
# (the model below feeds labels into 10-wide placeholders). Confirm.
images_train, labels_train = mnist.train.images, mnist.train.labels
images_test, labels_test = mnist.test.images, mnist.test.labels

import skimage.transform

# Source domain (train): a random subset of 10,000 training images, drawn
# without replacement. The population size comes from the data itself rather
# than a hard-coded 55000, so a differently sized training split still works.
indices = np.random.choice(len(images_train), 10000, replace = False)

images_train_0 = images_train[indices]
images_train_0_2d = np.reshape(images_train_0, (-1, 28, 28))
labels_train_0 = labels_train[indices]

# Target domain (train): the same images rotated by 60 degrees; labels unchanged.
# Other target domains can be built the same way, e.g. a horizontal flip
# (images_train_0_2d[:, :, ::-1]) or rotations by 30, 90 or 180 degrees.
images_train_60_2d = [skimage.transform.rotate(image, 60) for image in images_train_0_2d]
images_train_60 = np.reshape(images_train_60_2d, (-1, 28*28))
labels_train_60 = labels_train[indices]

import skimage.transform

# Source domain (test): the whole test split in a random order. The size is
# taken from the data rather than a hard-coded 10000.
n_test = len(images_test)
indices = np.random.choice(n_test, n_test, replace = False)

images_test_0 = images_test[indices]
images_test_0_2d = np.reshape(images_test_0, (-1, 28, 28))
labels_test_0 = labels_test[indices]

# Target domain (test): the same images rotated by 60 degrees; labels unchanged.
# Other target domains can be built the same way, e.g. a horizontal flip
# (images_test_0_2d[:, :, ::-1]) or rotations by 30, 90 or 180 degrees.
images_test_60_2d = [skimage.transform.rotate(image, 60) for image in images_test_0_2d]
images_test_60 = np.reshape(images_test_60_2d, (-1, 28*28))
labels_test_60 = labels_test[indices]

image.png

## Sample Code

# Adversarial Discriminative Domain Adaptation

class ADDA():
  """Adversarial Discriminative Domain Adaptation (ADDA) for 28x28 single-channel images.

  Written against the TensorFlow 1.x graph/session API (tf.placeholder,
  tf.variable_scope, tf.Session). Four sub-networks live in separate
  variable scopes:

    generator_1   -- feature extractor for the source domain
    generator_t   -- feature extractor for the target domain
    classifier    -- 10-class head shared by both feature extractors
    discriminator -- probability that a feature vector came from the source domain

  `fit` builds the graph, trains it and plots learning curves with matplotlib.
  The module-level names `tf`, `np` and `plt` are presumably imported elsewhere.
  """

  def __init__(self):
    pass

  # ----------------------------------------------------------------- variables

  def weight_variable(self, name, shape):
    """Create (or reuse) a float32 weight initialised from a truncated normal (mean 0, stddev 0.01)."""
    initializer = tf.truncated_normal_initializer(mean = 0.0, stddev = 0.01, dtype = tf.float32)
    return tf.get_variable(name, shape, initializer = initializer)

  def bias_variable(self, name, shape):
    """Create (or reuse) a float32 bias initialised to zero."""
    initializer = tf.constant_initializer(value = 0.0, dtype = tf.float32)
    return tf.get_variable(name, shape, initializer = initializer)

  def alpha_variable(self, name):
    """Create (or reuse) a scalar float32 variable initialised to 0.75. Not used by `fit`."""
    initializer = tf.constant_initializer(value = 0.75, dtype = tf.float32)
    return tf.get_variable(name, shape = (), initializer = initializer)

  # ------------------------------------------------------------------ networks

  def _generator(self, scope, x, filter_size, n_filters_1, n_filters_2, n_units, reuse):
    """Convolutional feature extractor shared by generator_1 and generator_t.

    Two stride-2 'SAME' convolutions (28x28 -> 14x14 -> 7x7) and one dense
    layer, each followed by ReLU. Variables are created under `scope`
    (e.g. 'generator_1/w_1'), so each caller keeps its own weights.
    Returns a (batch, n_units) tensor.
    """
    x_reshaped = tf.reshape(x, [-1, 28, 28, 1])

    with tf.variable_scope(scope, reuse = reuse):
      w_1 = self.weight_variable('w_1', [filter_size, filter_size, 1, n_filters_1])
      b_1 = self.bias_variable('b_1', [n_filters_1])
      conv = tf.nn.conv2d(x_reshaped, w_1, strides = [1, 2, 2, 1], padding = 'SAME') + b_1
      conv = tf.nn.relu(conv)

      w_2 = self.weight_variable('w_2', [filter_size, filter_size, n_filters_1, n_filters_2])
      b_2 = self.bias_variable('b_2', [n_filters_2])
      conv = tf.nn.conv2d(conv, w_2, strides = [1, 2, 2, 1], padding = 'SAME') + b_2
      conv = tf.nn.relu(conv)

      conv_flat = tf.reshape(conv, [-1, 7 * 7 * n_filters_2])

      w_3 = self.weight_variable('w_3', [7 * 7 * n_filters_2, n_units])
      b_3 = self.bias_variable('b_3', [n_units])
      feature = tf.nn.relu(tf.matmul(conv_flat, w_3) + b_3)

      # Batch-norm, max-pool, dropout and leaky-ReLU variants were tried and are disabled.

    return feature

  def generator_1(self, x, filter_size, n_filters_1, n_filters_2, n_units, keep_prob, reuse = False):
    """Source-domain feature extractor (variable scope 'generator_1').

    `keep_prob` is accepted for interface compatibility only; dropout is disabled.
    """
    return self._generator('generator_1', x, filter_size, n_filters_1, n_filters_2, n_units, reuse)

  def generator_t(self, x, filter_size, n_filters_1, n_filters_2, n_units, keep_prob, reuse = False):
    """Target-domain feature extractor (variable scope 'generator_t').

    Same architecture as generator_1 but with its own, independently
    initialised weights. `keep_prob` is unused (dropout is disabled).
    """
    return self._generator('generator_t', x, filter_size, n_filters_1, n_filters_2, n_units, reuse)

  def classifier(self, x, n_units_1, n_units_2, keep_prob, reuse = False):
    """Shared classifier head: dense -> batch-norm -> ReLU -> dense; returns 10-class logits.

    Normalisation uses the statistics of the current batch only (no moving
    averages and no learned scale/shift), so evaluation results depend on the
    composition of the evaluation batch. `keep_prob` is unused.
    """
    with tf.variable_scope('classifier', reuse = reuse):
      w_1 = self.weight_variable('w_1', [n_units_1, n_units_2])
      b_1 = self.bias_variable('b_1', [n_units_2])
      fc = tf.matmul(x, w_1) + b_1

      # The epsilon belongs inside the square root. With it outside, a unit whose
      # batch variance is exactly 0 makes d sqrt(v)/dv infinite and the backward
      # pass yields NaN (0 * inf).
      batch_mean, batch_var = tf.nn.moments(fc, [0])
      fc = (fc - batch_mean) / tf.sqrt(batch_var + 1e-10)
      fc = tf.nn.relu(fc)

      w_2 = self.weight_variable('w_2', [n_units_2, 10])
      b_2 = self.bias_variable('b_2', [10])
      logits = tf.matmul(fc, w_2) + b_2

    return logits

  def discriminator(self, x, n_units_1, n_units_2, n_units_3, keep_prob, reuse = False):
    """Domain discriminator: two leaky-ReLU (slope 0.2) dense layers, then a sigmoid unit.

    Returns a (batch, 1) tensor, trained (see loss_discriminator) to be close
    to 1 for source features and 0 for target features. `keep_prob` is unused.
    """
    with tf.variable_scope('discriminator', reuse = reuse):
      w_1 = self.weight_variable('w_1', [n_units_1, n_units_2])
      b_1 = self.bias_variable('b_1', [n_units_2])
      fc = tf.matmul(x, w_1) + b_1
      fc = tf.maximum(fc * 0.2, fc)  # leaky ReLU

      w_2 = self.weight_variable('w_2', [n_units_2, n_units_3])
      b_2 = self.bias_variable('b_2', [n_units_3])
      fc = tf.matmul(fc, w_2) + b_2
      fc = tf.maximum(fc * 0.2, fc)  # leaky ReLU

      w_3 = self.weight_variable('w_3', [n_units_3, 1])
      b_3 = self.bias_variable('b_3', [1])
      fc = tf.matmul(fc, w_3) + b_3

      prob = tf.nn.sigmoid(fc)

    return prob

  # -------------------------------------------------------------------- losses

  def loss_cross_entropy(self, y, t):
    """Mean categorical cross-entropy of probabilities `y` against targets `t` (log clipped at 1e-10)."""
    log_y = tf.log(tf.clip_by_value(y, 1e-10, 1.0))
    return - tf.reduce_mean(tf.reduce_sum(t * log_y, axis = 1))

  def loss_discriminator(self, prob_1, prob_2):
    """GAN discriminator loss -mean(log p_1 + log(1 - p_2)).

    Pushes `prob_1` (source features) towards 1 and `prob_2` (target features)
    towards 0. A least-squares variant was also tried and is disabled.
    """
    log_source = tf.log(tf.clip_by_value(prob_1, 1e-10, 1.0))
    log_target = tf.log(tf.clip_by_value(1.0 - prob_2, 1e-10, 1.0))
    return - tf.reduce_mean(log_source + log_target)

  def loss_generator(self, prob):
    """Inverted-label generator loss -mean(log p): rewards features the discriminator calls source."""
    return - tf.reduce_mean(tf.log(tf.clip_by_value(prob, 1e-10, 1.0)))

  def loss_entropy(self, p):
    """Mean per-sample entropy of the class distributions `p`. Not used by `fit`."""
    entropy = - tf.reduce_mean(tf.reduce_sum(p * tf.log(tf.clip_by_value(p, 1e-10, 1.0)), axis = 1))
    return entropy

  def loss_mutual_information(self, p):
    """Negative mutual-information estimate -(H(mean_x p) - mean_x H(p)). Not used by `fit`.

    Minimising it raises the entropy of the average prediction (balanced
    classes) while lowering per-sample entropy (confident predictions).
    """
    p_ave = tf.reduce_mean(p, axis = 0)
    h_y = -tf.reduce_sum(p_ave * tf.log(p_ave + 1e-16))
    h_y_x = - tf.reduce_mean(tf.reduce_sum(p * tf.log(tf.clip_by_value(p, 1e-10, 1.0)), axis = 1))
    mutual_info = h_y - h_y_x
    return -mutual_info

  def accuracy(self, y, t):
    """Fraction of rows whose argmax in `y` matches the argmax in `t`."""
    correct_preds = tf.equal(tf.argmax(y, axis = 1), tf.argmax(t, axis = 1))
    return tf.reduce_mean(tf.cast(correct_preds, tf.float32))

  # ---------------------------------------------------------------- optimisers

  def training(self, loss, learning_rate, var_list):
    """Adam step minimising `loss` with respect to `var_list` only."""
    optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
    return optimizer.minimize(loss, var_list = var_list)

  def training_clipped(self, loss, learning_rate, clip_norm, var_list):
    """Adam step with every gradient clipped to L2 norm `clip_norm`.

    Variables that `loss` does not depend on get a None gradient from
    compute_gradients; those are skipped because tf.clip_by_norm cannot take None.
    """
    optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
    grads_and_vars = optimizer.compute_gradients(loss, var_list = var_list)
    clipped_grads_and_vars = [(tf.clip_by_norm(grad, clip_norm = clip_norm), var)
                              for grad, var in grads_and_vars if grad is not None]
    return optimizer.apply_gradients(clipped_grads_and_vars)

  # ----------------------------------------------------------- training helpers

  @staticmethod
  def _sample_batch(images, labels, batch_size):
    """Return `batch_size` (images, labels) rows drawn uniformly at random, with replacement."""
    rand_index = np.random.choice(len(images), size = batch_size)
    return images[rand_index], labels[rand_index]

  @staticmethod
  def _plot_panels(n_iter, panels):
    """Print a separator line, then show a 1x2 figure of per-iteration curves.

    `panels` holds two (title, train_history, test_history, legend_loc, ylim)
    tuples; `train_history` and `ylim` may be None to omit that curve / limit.
    """
    print('-' * 15)
    fig = plt.figure(figsize = (10, 3))
    for position, (title, train_history, test_history, legend_loc, ylim) in enumerate(panels, start = 1):
      ax = fig.add_subplot(1, 2, position)
      if train_history is not None:
        ax.plot(range(n_iter), train_history, 'b-', label = 'Train')
      ax.plot(range(n_iter), test_history, 'r--', label = 'Test')
      if ylim is not None:
        ax.set_ylim(*ylim)
      ax.set_title(title)
      ax.legend(loc = legend_loc)
    plt.show()

  def _run_supervised(self, sess, train_step, loss, acc, x, y, keep_prob,
                      images_train, labels_train, images_test, labels_test,
                      n_iter, batch_size, show_step, suffix):
    """Run `n_iter` supervised steps; after each, also score one random test batch.

    Prints progress every `show_step` iterations, using `suffix` in the labels
    ('Loss_<suffix>', 'Accuracy_<suffix>'). Returns the per-iteration
    histories (loss_train, loss_test, acc_train, acc_test).
    """
    history_loss_train, history_loss_test = [], []
    history_acc_train, history_acc_test = [], []

    for i in range(n_iter):
      # Train: one update, then measure on the same batch.
      x_batch, y_batch = self._sample_batch(images_train, labels_train, batch_size)
      feed_dict = {x: x_batch, y: y_batch, keep_prob: 1.0}
      sess.run(train_step, feed_dict = feed_dict)
      temp_loss, temp_acc = sess.run([loss, acc], feed_dict = feed_dict)
      history_loss_train.append(temp_loss)
      history_acc_train.append(temp_acc)

      if (i + 1) % show_step == 0:
        print('-' * 15)
        print('Iteration: ' + str(i + 1) + '  Loss_' + suffix + ': ' + str(temp_loss) +
              '  Accuracy_' + suffix + ': ' + str(temp_acc))

      # Test: evaluation only, no update.
      x_batch, y_batch = self._sample_batch(images_test, labels_test, batch_size)
      feed_dict = {x: x_batch, y: y_batch, keep_prob: 1.0}
      temp_loss, temp_acc = sess.run([loss, acc], feed_dict = feed_dict)
      history_loss_test.append(temp_loss)
      history_acc_test.append(temp_acc)

    return history_loss_train, history_loss_test, history_acc_train, history_acc_test

  def _evaluate(self, sess, loss, acc, x, y, keep_prob, images, labels, n_iter, batch_size):
    """Score `n_iter` random batches without updating; returns (loss_history, acc_history)."""
    history_loss, history_acc = [], []
    for _ in range(n_iter):
      x_batch, y_batch = self._sample_batch(images, labels, batch_size)
      temp_loss, temp_acc = sess.run([loss, acc], feed_dict = {x: x_batch, y: y_batch, keep_prob: 1.0})
      history_loss.append(temp_loss)
      history_acc.append(temp_acc)
    return history_loss, history_acc

  # ---------------------------------------------------------------------- fit

  def fit(self, images_train_1, labels_train_1, images_test_1, labels_test_1,
          images_train_t, labels_train_t, images_test_t, labels_test_t,
          filter_size, n_filters_1, n_filters_2, n_units_g, n_units_c,
          n_units_d_1, n_units_d_2, learning_rate, n_iter_1, n_iter_2,
          batch_size, show_step_1, show_step_2, is_saving, model_path):
    """Build the ADDA graph, train it in three phases, plot curves and optionally save.

    Phases (each prints a header):
      1. Supervised pre-training of generator_1 + classifier on source data,
         then of generator_t + classifier on target data, n_iter_1 steps each.
      2. Adversarial adaptation for n_iter_2 steps. generator_t is trained so
         that the discriminator scores target features as source, and the
         discriminator is trained to tell the two domains apart.
      3. Evaluation of generator_t + classifier on n_iter_1 random target test batches.

    Images are flattened 28*28 float arrays. Labels are 10-wide arrays,
    presumably one-hot. Every batch is sampled with replacement. If
    `is_saving` is true, the session is saved to `model_path`.

    NOTE(review): published ADDA keeps target labels unseen and initialises the
    target encoder from the source encoder. Here phase 1 trains generator_t and
    the shared classifier with `labels_train_t`, so the phase-3 accuracy is not
    an unsupervised adaptation result. This is kept as-is to preserve the
    experiment's behaviour.
    """
    tf.reset_default_graph()

    # ---- graph ----
    x_1 = tf.placeholder(shape = [None, 28 * 28], dtype = tf.float32)
    y_1 = tf.placeholder(shape = [None, 10], dtype = tf.float32)
    x_t = tf.placeholder(shape = [None, 28 * 28], dtype = tf.float32)
    y_t = tf.placeholder(shape = [None, 10], dtype = tf.float32)
    keep_prob = tf.placeholder(shape = (), dtype = tf.float32)  # fed 1.0 everywhere; dropout is disabled

    feat_1 = self.generator_1(x_1, filter_size, n_filters_1, n_filters_2, n_units_g, keep_prob, reuse = False)
    feat_t = self.generator_t(x_t, filter_size, n_filters_1, n_filters_2, n_units_g, keep_prob, reuse = False)

    # The classifier and discriminator are created on first use, then shared (reuse = True).
    probs_l_1 = tf.nn.softmax(self.classifier(feat_1, n_units_g, n_units_c, keep_prob, reuse = False))
    loss_1 = self.loss_cross_entropy(probs_l_1, y_1)

    probs_l_t = tf.nn.softmax(self.classifier(feat_t, n_units_g, n_units_c, keep_prob, reuse = True))
    loss_t = self.loss_cross_entropy(probs_l_t, y_t)

    prob_d_1 = self.discriminator(feat_1, n_units_g, n_units_d_1, n_units_d_2, keep_prob, reuse = False)
    prob_d_t = self.discriminator(feat_t, n_units_g, n_units_d_1, n_units_d_2, keep_prob, reuse = True)
    prob_d_1_ave = tf.reduce_mean(prob_d_1)
    prob_d_t_ave = tf.reduce_mean(prob_d_t)
    loss_d = self.loss_discriminator(prob_d_1, prob_d_t)
    loss_g_t = self.loss_generator(prob_d_t)

    var_list_g_1 = tf.trainable_variables('generator_1')
    var_list_g_t = tf.trainable_variables('generator_t')
    var_list_c = tf.trainable_variables('classifier')
    var_list_d = tf.trainable_variables('discriminator')

    # Each optimiser updates only its own sub-networks.
    # (self.training_clipped can be substituted to clip gradients by norm.)
    train_step_g_1_c = self.training(loss_1, learning_rate, var_list_g_1 + var_list_c)
    train_step_g_t_c = self.training(loss_t, learning_rate, var_list_g_t + var_list_c)
    train_step_d = self.training(loss_d, learning_rate, var_list_d)
    train_step_g_t = self.training(loss_g_t, learning_rate, var_list_g_t)

    acc_1 = self.accuracy(probs_l_1, y_1)
    acc_t = self.accuracy(probs_l_t, y_t)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:
      sess.run(init)

      # ---- phase 1: supervised pre-training ----
      print('-' * 45)
      print('Pre-training for generator_1, generator_t and classifier')
      print('-' * 45)

      loss_tr, loss_te, acc_tr, acc_te = self._run_supervised(
          sess, train_step_g_1_c, loss_1, acc_1, x_1, y_1, keep_prob,
          images_train_1, labels_train_1, images_test_1, labels_test_1,
          n_iter_1, batch_size, show_step_1, '1')
      self._plot_panels(n_iter_1, [('Loss_1', loss_tr, loss_te, 'upper right', None),
                                   ('Accuracy_1', acc_tr, acc_te, 'lower right', (0.0, 1.0))])

      # Uses target labels -- see the NOTE(review) in this docstring.
      loss_tr, loss_te, acc_tr, acc_te = self._run_supervised(
          sess, train_step_g_t_c, loss_t, acc_t, x_t, y_t, keep_prob,
          images_train_t, labels_train_t, images_test_t, labels_test_t,
          n_iter_1, batch_size, show_step_1, 't')
      self._plot_panels(n_iter_1, [('Loss_t', loss_tr, loss_te, 'upper right', None),
                                   ('Accuracy_t', acc_tr, acc_te, 'lower right', (0.0, 1.0))])

      # ---- phase 2: adversarial adaptation ----
      print('-' * 45)
      print('Adversarial adaptation for generator_t and discriminator')
      print('-' * 45)

      monitored = [prob_d_1_ave, prob_d_t_ave, loss_g_t, loss_d]
      history_prob_d_1_train, history_prob_d_t_train = [], []
      history_loss_g_t_train, history_loss_d_train = [], []
      history_prob_d_1_test, history_prob_d_t_test = [], []
      history_loss_g_t_test, history_loss_d_test = [], []

      for i in range(n_iter_2):
        # Train
        x_batch_1, y_batch_1 = self._sample_batch(images_train_1, labels_train_1, batch_size)
        x_batch_t, y_batch_t = self._sample_batch(images_train_t, labels_train_t, batch_size)
        feed_dict = {x_1: x_batch_1, y_1: y_batch_1, x_t: x_batch_t, y_t: y_batch_t, keep_prob: 1.0}

        # Balancing heuristic: generator_t always takes two steps. The
        # discriminator steps only when its loss is at least 10x the generator's.
        # Written as `not (... < ...)` so that a NaN loss still trains the
        # discriminator.
        temp_loss_g_t, temp_loss_d = sess.run([loss_g_t, loss_d], feed_dict = feed_dict)
        train_d = not (temp_loss_d * 0.1 < temp_loss_g_t)

        sess.run(train_step_g_t, feed_dict = feed_dict)
        sess.run(train_step_g_t, feed_dict = feed_dict)
        if train_d:
          sess.run(train_step_d, feed_dict = feed_dict)

        temp_prob_d_1, temp_prob_d_t, temp_loss_g_t, temp_loss_d = sess.run(monitored, feed_dict = feed_dict)
        history_prob_d_1_train.append(temp_prob_d_1)
        history_prob_d_t_train.append(temp_prob_d_t)
        history_loss_g_t_train.append(temp_loss_g_t)
        history_loss_d_train.append(temp_loss_d)

        if (i + 1) % show_step_2 == 0:
          print('-' * 15)
          print('Iteration: ' + str(i + 1) + '  Loss_train_g_t: ' + str(temp_loss_g_t) +
                '  Loss_train_d: ' + str(temp_loss_d))

        # Test: evaluation only, no update.
        x_batch_1, y_batch_1 = self._sample_batch(images_test_1, labels_test_1, batch_size)
        x_batch_t, y_batch_t = self._sample_batch(images_test_t, labels_test_t, batch_size)
        feed_dict = {x_1: x_batch_1, y_1: y_batch_1, x_t: x_batch_t, y_t: y_batch_t, keep_prob: 1.0}

        temp_prob_d_1, temp_prob_d_t, temp_loss_g_t, temp_loss_d = sess.run(monitored, feed_dict = feed_dict)
        history_prob_d_1_test.append(temp_prob_d_1)
        history_prob_d_t_test.append(temp_prob_d_t)
        history_loss_g_t_test.append(temp_loss_g_t)
        history_loss_d_test.append(temp_loss_d)

      self._plot_panels(n_iter_2, [
          ('Loss_generator_t', history_loss_g_t_train, history_loss_g_t_test, 'upper right', None),
          ('Loss_discriminator', history_loss_d_train, history_loss_d_test, 'upper right', None)])
      self._plot_panels(n_iter_2, [
          ('Probability_d_1', history_prob_d_1_train, history_prob_d_1_test, 'lower right', None),
          ('Probability_d_t', history_prob_d_t_train, history_prob_d_t_test, 'lower right', None)])

      # ---- phase 3: evaluation on the target domain ----
      print('-' * 45)
      print('Testing')
      print('-' * 45)

      history_loss_test_t_2, history_acc_test_t_2 = self._evaluate(
          sess, loss_t, acc_t, x_t, y_t, keep_prob,
          images_test_t, labels_test_t, n_iter_1, batch_size)
      self._plot_panels(n_iter_1, [('Loss_t', None, history_loss_test_t_2, 'upper right', None),
                                   ('Accuracy_t', None, history_acc_test_t_2, 'lower right', (0.0, 1.0))])

      if is_saving:
        model_path = saver.save(sess, model_path)
        print('done saving at ', model_path)

## Parameters

# Feature extractors (generator_1 / generator_t): kernel size and channel counts.
filter_size = 3
n_filters_1, n_filters_2 = 32, 32

# Widths of the feature vector, the classifier hidden layer and the two
# discriminator hidden layers.
n_units_g, n_units_c = 256, 128
n_units_d_1, n_units_d_2 = 64, 64

# Optimisation schedule: pre-training steps, adversarial steps, batch size,
# and how often each phase prints progress.
learning_rate = 0.001
n_iter_1, n_iter_2 = 100, 300
batch_size = 64
show_step_1, show_step_2 = 50, 100

# Checkpoint destination (used only when saving is enabled).
model_path = 'datalab/model'

## Output

# Source domain: the original (unrotated) digits.
images_train_1 = images_train_0
labels_train_1 = labels_train_0
images_test_1 = images_test_0
labels_test_1 = labels_test_0

# Target domain: the same digits rotated by 60 degrees.
images_train_t = images_train_60
labels_train_t = labels_train_60
images_test_t = images_test_60
labels_test_t = labels_test_60

is_saving = False

# `fit` is an instance method, so an ADDA instance must exist before it is
# called (without this line `adda` is undefined and the call raises NameError).
adda = ADDA()
adda.fit(images_train_1, labels_train_1, images_test_1, labels_test_1,
         images_train_t, labels_train_t, images_test_t, labels_test_t,
         filter_size, n_filters_1, n_filters_2, n_units_g, n_units_c, n_units_d_1, n_units_d_2,
         learning_rate, n_iter_1, n_iter_2, batch_size, show_step_1, show_step_2, is_saving, model_path)

image.png

image.png

image.png

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?