
Notes on Implementing Cross-Gradient (CrossGrad) Training

Posted at 2018-08-26

Reference

Generalizing Across Domains via Cross-Gradient Training (Shankar et al., ICLR 2018)
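
CrossGrad trains a label classifier and a domain classifier on a shared feature extractor. Each classifier sees a mix of the clean loss and the loss on an input perturbed along the gradient of the other classifier's loss. With label loss $J_l$, domain loss $J_d$, step size $\epsilon$, and mixing weights $\alpha_l, \alpha_d$:

$$
\begin{aligned}
\tilde{J}_l &= (1 - \alpha_l)\, J_l(x) + \alpha_l\, J_l\big(x + \epsilon \nabla_x J_d(x)\big) \\
\tilde{J}_d &= (1 - \alpha_d)\, J_d(x) + \alpha_d\, J_d\big(x + \epsilon \nabla_x J_l(x)\big)
\end{aligned}
$$

(The code below additionally normalizes each input gradient to unit L2 norm before scaling by $\epsilon$.)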


Data

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

images_train = mnist.train.images
labels_train = mnist.train.labels
images_test = mnist.test.images
labels_test = mnist.test.labels

import skimage.transform

# for train data: subsample 10,000 of the 55,000 training images
indices = np.random.choice(55000, 10000, replace = False)

images_train_0 = images_train[indices]
images_train_0_2d = np.reshape(images_train_0, (-1, 28, 28))
labels_train_0 = labels_train[indices]

#images_train_flip_2d = images_train_0_2d[:, :, ::-1]
#images_train_flip = np.reshape(images_train_0_2d[:, :, ::-1], (-1, 28*28))
#labels_train_flip = labels_train[indices]

# rotate each training image by 30 degrees (skimage rotates counter-clockwise)
images_train_30_2d = []
for i in range(len(images_train_0)):
  images_train_30_2d.append(skimage.transform.rotate(images_train_0_2d[i], 30))
images_train_30 = np.reshape(images_train_30_2d, (-1, 28*28))
labels_train_30 = labels_train[indices]

# rotate each training image by 60 degrees
images_train_60_2d = []
for i in range(len(images_train_0)):
  images_train_60_2d.append(skimage.transform.rotate(images_train_0_2d[i], 60))
images_train_60 = np.reshape(images_train_60_2d, (-1, 28*28))
labels_train_60 = labels_train[indices]

#images_train_90_2d = []
#for i in range(len(images_train_0)):
#  images_train_90_2d.append(skimage.transform.rotate(images_train_0_2d[i], 90))
#images_train_90 = np.reshape(images_train_90_2d, (-1, 28*28))
#labels_train_90 = labels_train[indices] 

#images_train_180_2d = []
#for i in range(len(images_train_0)):
#  images_train_180_2d.append(skimage.transform.rotate(images_train_0_2d[i], 180))
#images_train_180 = np.reshape(images_train_180_2d, (-1, 28*28))
#labels_train_180 = labels_train[indices] 
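
As an aside, the per-image rotation loop can be collapsed into a single batched call. A sketch using scipy.ndimage (this assumes scipy is available; it is not used in the rest of the article, and the rotation direction should be checked against skimage's counter-clockwise convention):

import scipy.ndimage

# Rotate the whole (N, 28, 28) stack in one call; reshape = False keeps the
# 28x28 frame, and order = 1 (bilinear) roughly matches skimage's default.
images_train_30_2d = scipy.ndimage.rotate(images_train_0_2d, 30,
                                          axes = (2, 1), reshape = False, order = 1)
images_train_30 = np.reshape(images_train_30_2d, (-1, 28*28))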


# for test data: a random permutation of all 10,000 test images
indices = np.random.choice(10000, 10000, replace = False)

images_test_0 = images_test[indices]
images_test_0_2d = np.reshape(images_test_0, (-1, 28, 28))
labels_test_0 = labels_test[indices]

#images_test_flip_2d = images_test_0_2d[:, :, ::-1]
#images_test_flip = np.reshape(images_test_flip_2d, (-1, 28*28))
#labels_test_flip = labels_test[indices]

images_test_30_2d = []
for i in range(len(images_test_0)):
  images_test_30_2d.append(skimage.transform.rotate(images_test_0_2d[i], 30))
images_test_30 = np.reshape(images_test_30_2d, (-1, 28*28))
labels_test_30 = labels_test[indices] 

images_test_60_2d = []
for i in range(len(images_test_0)):
  images_test_60_2d.append(skimage.transform.rotate(images_test_0_2d[i], 60))
images_test_60 = np.reshape(images_test_60_2d, (-1, 28*28))
labels_test_60 = labels_test[indices] 

#images_test_90_2d = []
#for i in range(len(images_test_0)):
#  images_test_90_2d.append(skimage.transform.rotate(images_test_0_2d[i], 90))
#images_test_90 = np.reshape(images_test_90_2d, (-1, 28*28))
#labels_test_90 = labels_test[indices] 

#images_test_180_2d = []
#for i in range(len(images_test_0)):
#  images_test_180_2d.append(skimage.transform.rotate(images_test_0_2d[i], 180))
#images_test_180 = np.reshape(images_test_180_2d, (-1, 28*28))
#labels_test_180 = labels_test[indices] 

# pick a random sample to visualize
index = np.random.randint(1000)

# for train data
fig = plt.figure(figsize = (7, 5))

ax = fig.add_subplot(1, 6, 1)
ax.imshow(np.reshape(images_train_0[index], (28, 28)), cmap = 'gray')
ax.set_title('image_0')
ax.set_axis_off()

ax = fig.add_subplot(1, 6, 2)
ax.imshow(np.reshape(images_train_30[index], (28, 28)), cmap = 'gray')
ax.set_title('image_30')
ax.set_axis_off()

ax = fig.add_subplot(1, 6, 3)
ax.imshow(np.reshape(images_train_60[index], (28, 28)), cmap = 'gray')
ax.set_title('image_60')
ax.set_axis_off()

#ax = fig.add_subplot(1, 6, 4)
#ax.imshow(np.reshape(images_train_90[index], (28, 28)), cmap = 'gray')
#ax.set_title('image_90')
#ax.set_axis_off()

#ax = fig.add_subplot(1, 6, 5)
#ax.imshow(np.reshape(images_train_180[index], (28, 28)), cmap = 'gray')
#ax.set_title('image_180')
#ax.set_axis_off()

#ax = fig.add_subplot(1, 6, 6)
#ax.imshow(np.reshape(images_train_flip[index], (28, 28)), cmap = 'gray')
#ax.set_title('image_flip')
#ax.set_axis_off()

plt.show()

# for test data
fig = plt.figure(figsize = (7, 5))

ax = fig.add_subplot(1, 6, 1)
ax.imshow(np.reshape(images_test_0[index], (28, 28)), cmap = 'gray')
ax.set_title('image_0')
ax.set_axis_off()

ax = fig.add_subplot(1, 6, 2)
ax.imshow(np.reshape(images_test_30[index], (28, 28)), cmap = 'gray')
ax.set_title('image_30')
ax.set_axis_off()

ax = fig.add_subplot(1, 6, 3)
ax.imshow(np.reshape(images_test_60[index], (28, 28)), cmap = 'gray')
ax.set_title('image_60')
ax.set_axis_off()

#ax = fig.add_subplot(1, 6, 4)
#ax.imshow(np.reshape(images_test_90[index], (28, 28)), cmap = 'gray')
#ax.set_title('image_90')
#ax.set_axis_off()

#ax = fig.add_subplot(1, 6, 5)
#ax.imshow(np.reshape(images_test_180[index], (28, 28)), cmap = 'gray')
#ax.set_title('image_180')
#ax.set_axis_off()

#ax = fig.add_subplot(1, 6, 6)
#ax.imshow(np.reshape(images_test_flip[index], (28, 28)), cmap = 'gray')
#ax.set_title('image_flip')
#ax.set_axis_off()

plt.show()

(Figure: a sample digit shown at 0°, 30°, and 60° rotation, for both the train and test sets)

Sample Code

# Cross-Gradient Training

class CrossGrad():
  def __init__(self):
    pass

  def weight_variable(self, name, shape):
    initializer = tf.truncated_normal_initializer(mean = 0.0, stddev = 0.01, dtype = tf.float32)
    return tf.get_variable(name, shape, initializer = initializer)

  def bias_variable(self, name, shape):
    initializer = tf.constant_initializer(value = 0.0, dtype = tf.float32)
    return tf.get_variable(name, shape, initializer = initializer)

  def alpha_variable(self, name):
    # learnable mixing weight (defined here but not used in fit below,
    # which takes alpha_d / alpha_l as plain Python floats)
    initializer = tf.constant_initializer(value = 0.75, dtype = tf.float32)
    return tf.get_variable(name, shape = (), initializer = initializer)

  def generator(self, x, filter_size, n_filters_1, n_filters_2, n_units, keep_prob, reuse = False):
    x_reshaped = tf.reshape(x, [-1, 28, 28, 1])

    with tf.variable_scope('generator', reuse = reuse):
      w_1 = self.weight_variable('w_1', [filter_size, filter_size, 1, n_filters_1])
      b_1 = self.bias_variable('b_1', [n_filters_1])

      # conv
      conv = tf.nn.conv2d(x_reshaped, w_1, strides = [1, 2, 2, 1], padding = 'SAME') + b_1

      # batch norm
      #batch_mean, batch_var = tf.nn.moments(conv, [0, 1, 2])
      #conv = (conv - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      # relu
      conv = tf.nn.relu(conv)

      # max_pool
      #conv = tf.nn.max_pool(conv, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = 'SAME')

      w_2 = self.weight_variable('w_2', [filter_size, filter_size, n_filters_1, n_filters_2])
      b_2 = self.bias_variable('b_2', [n_filters_2])

      # conv
      conv = tf.nn.conv2d(conv, w_2, strides = [1, 2, 2, 1], padding = 'SAME') + b_2

      # batch norm
      #batch_mean, batch_var = tf.nn.moments(conv, [0, 1, 2])
      #conv = (conv - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      # relu
      conv = tf.nn.relu(conv)

      # max_pool
      #conv = tf.nn.max_pool(conv, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = 'SAME')

      conv_flat = tf.reshape(conv, [-1, 7 * 7 * n_filters_2])  # 28 -> 14 -> 7 after two stride-2 convs

      w_3 = self.weight_variable('w_3', [7 * 7 * n_filters_2, n_units])
      b_3 = self.bias_variable('b_3', [n_units])

      fc = tf.matmul(conv_flat, w_3) + b_3

      # batch norm
      #batch_mean, batch_var = tf.nn.moments(fc, [0])

      #fc = (fc - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      # dropout
      #fc = tf.nn.dropout(fc, keep_prob)

      # relu
      fc = tf.nn.relu(fc)

      # leaky relu
      #fc = tf.maximum(0.2 * fc, fc)

      feature = fc

    return feature

  def classifier_d(self, x, n_units_g, n_units_d, n_domains, keep_prob, reuse = False):

    with tf.variable_scope('classifier_d', reuse = reuse):
      w_1 = self.weight_variable('w_1', [n_units_g, n_units_d])
      b_1 = self.bias_variable('b_1', [n_units_d])

      fc = tf.matmul(x, w_1) + b_1

      # batch norm
      batch_mean, batch_var = tf.nn.moments(fc, [0])
      fc = (fc - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      # relu
      fc = tf.nn.relu(fc)

      # dropout
      #fc = tf.nn.dropout(fc, keep_prob)

      w_2 = self.weight_variable('w_2', [n_units_d, n_domains])
      b_2 = self.bias_variable('b_2', [n_domains])

      fc = tf.matmul(fc, w_2) + b_2
      logits = fc

    return logits

  def classifier_l(self, x, n_units_g, n_units_l, n_labels, keep_prob, reuse = False):

    with tf.variable_scope('classifier_l', reuse = reuse):
      w_1 = self.weight_variable('w_1', [n_units_g, n_units_l])
      b_1 = self.bias_variable('b_1', [n_units_l])

      fc = tf.matmul(x, w_1) + b_1

      # batch norm
      batch_mean, batch_var = tf.nn.moments(fc, [0])
      fc = (fc - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      # relu
      fc = tf.nn.relu(fc)

      # dropout
      #fc = tf.nn.dropout(fc, keep_prob)

      w_2 = self.weight_variable('w_2', [n_units_l, n_labels])
      b_2 = self.bias_variable('b_2', [n_labels])

      fc = tf.matmul(fc, w_2) + b_2
      logits = fc

    return logits

  def loss_cross_entropy(self, y, t):
    cross_entropy = - tf.reduce_mean(tf.reduce_sum(t * tf.log(tf.clip_by_value(y, 1e-10, 1.0)), axis = 1))
    return cross_entropy
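
  # Note: clipping log(softmax) as above works, but feeding logits to the fused
  # op is numerically safer. A drop-in alternative (an addition to the original
  # article; tf.nn.softmax_cross_entropy_with_logits_v2 exists in TF 1.x):
  def loss_cross_entropy_logits(self, logits, t):
    return tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels = t, logits = logits))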

  def loss_entropy(self, p):
    # conditional-entropy regularizer (defined but not used in fit below)
    entropy = - tf.reduce_mean(tf.reduce_sum(p * tf.log(tf.clip_by_value(p, 1e-10, 1.0)), axis = 1))
    return entropy

  def loss_mutual_information(self, p):
    # negative mutual information I(y; x) = H(y) - H(y|x) (defined but not used in fit below)
    p_ave = tf.reduce_mean(p, axis = 0)
    h_y = - tf.reduce_sum(p_ave * tf.log(p_ave + 1e-16))
    h_y_x = - tf.reduce_mean(tf.reduce_sum(p * tf.log(tf.clip_by_value(p, 1e-10, 1.0)), axis = 1))
    mutual_info = h_y - h_y_x
    return - mutual_info

  def accuracy(self, y, t):
    correct_preds = tf.equal(tf.argmax(y, axis = 1), tf.argmax(t, axis = 1))
    accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))
    return accuracy

  def training(self, loss, learning_rate, var_list):
    optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
    train_step = optimizer.minimize(loss, var_list = var_list)
    return train_step

  def training_2(self, loss, learning_rate, var_list):
    #optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
    train_step = optimizer.minimize(loss, var_list = var_list)
    return train_step

  def training_clipped(self, loss, learning_rate, clip_norm, var_list):
    # gradient-clipped variant (defined but not used in fit below)
    optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)

    grads_and_vars = optimizer.compute_gradients(loss, var_list = var_list)
    clipped_grads_and_vars = [(tf.clip_by_norm(grad, clip_norm = clip_norm), \
                             var) for grad, var in grads_and_vars]
    train_step = optimizer.apply_gradients(clipped_grads_and_vars)

    return train_step

  def fit(self, images_train_1, labels_train_1, images_test_1, labels_test_1, \
          images_train_2, labels_train_2, images_test_2, labels_test_2, \
          images_train_3, labels_train_3, images_test_3, labels_test_3, \
          filter_size, n_filters_1, n_filters_2, n_units_g, \
          n_units_d, n_domains, n_units_l, n_labels, eps, alpha_d, alpha_l, \
          learning_rate, n_iter_1, n_iter_2, batch_size, show_step_1, show_step_2, is_saving, model_path):

    tf.reset_default_graph()

    x = tf.placeholder(shape = [None, 28 * 28], dtype = tf.float32)
    y = tf.placeholder(shape = [None, n_labels], dtype = tf.float32)
    d = tf.placeholder(shape = [None, n_domains], dtype = tf.float32) 
    keep_prob = tf.placeholder(shape = (), dtype = tf.float32)

    # for initialization
    feat = self.generator(x, filter_size, n_filters_1, n_filters_2, n_units_g, \
                            keep_prob, reuse = False)

    logits_d = self.classifier_d(feat, n_units_g, n_units_d, n_domains, keep_prob, reuse = False)
    probs_d = tf.nn.softmax(logits_d)
    loss_d = self.loss_cross_entropy(probs_d, d)

    logits_l = self.classifier_l(feat, n_units_g, n_units_l, n_labels, keep_prob, reuse = False)
    probs_l = tf.nn.softmax(logits_l)
    loss_l = self.loss_cross_entropy(probs_l, y)

    var_list_g = tf.trainable_variables('generator')
    var_list_c_d = tf.trainable_variables('classifier_d')
    var_list_c_l = tf.trainable_variables('classifier_l')

    var_list_d = var_list_g + var_list_c_d
    var_list_l = var_list_g + var_list_c_l

    # Without Gradient Clipping
    train_step_d = self.training(loss_d, learning_rate, var_list_d)
    train_step_l = self.training(loss_l, learning_rate, var_list_l)

    acc_d =  self.accuracy(probs_d, d)
    acc_l =  self.accuracy(probs_l, y)

    # for training
    feat_1 = self.generator(x, filter_size, n_filters_1, n_filters_2, n_units_g, \
                            keep_prob, reuse = True)

    logits_d_1 = self.classifier_d(feat_1, n_units_g, n_units_d, n_domains, keep_prob, reuse = True)
    probs_d_1 = tf.nn.softmax(logits_d_1)
    loss_d_1 = self.loss_cross_entropy(probs_d_1, d)

    logits_l_1 = self.classifier_l(feat_1, n_units_g, n_units_l, n_labels, keep_prob, reuse = True)
    probs_l_1 = tf.nn.softmax(logits_l_1)
    loss_l_1 = self.loss_cross_entropy(probs_l_1, y)

    # CrossGrad perturbations: move x one step of size eps along the
    # L2-normalized input gradient of each loss. stop_gradient keeps the
    # perturbation step itself out of backprop.
    grad_d = tf.stop_gradient(tf.gradients(loss_d_1, [x])[0])
    grad_norm_d = grad_d / (tf.reshape(tf.sqrt(tf.reduce_sum(tf.pow(grad_d, 2), axis = 1)), [-1, 1]) + 1e-10)
    x_d = x + eps * grad_norm_d   # perturbed along the domain-loss gradient

    grad_l = tf.stop_gradient(tf.gradients(loss_l_1, [x])[0])
    grad_norm_l = grad_l / (tf.reshape(tf.sqrt(tf.reduce_sum(tf.pow(grad_l, 2), axis = 1)), [-1, 1]) + 1e-10)
    x_l = x + eps * grad_norm_l   # perturbed along the label-loss gradient

    feat_l = self.generator(x_l, filter_size, n_filters_1, n_filters_2, n_units_g, \
                      keep_prob, reuse = True)
    feat_d = self.generator(x_d, filter_size, n_filters_1, n_filters_2, n_units_g, \
                      keep_prob, reuse = True)

    # domain classifier trained on the label-perturbed input
    logits_d_2 = self.classifier_d(feat_l, n_units_g, n_units_d, n_domains, keep_prob, reuse = True)
    probs_d_2 = tf.nn.softmax(logits_d_2)
    loss_d_2 = self.loss_cross_entropy(probs_d_2, d)

    # label classifier trained on the domain-perturbed input
    logits_l_2 = self.classifier_l(feat_d, n_units_g, n_units_l, n_labels, keep_prob, reuse = True)
    probs_l_2 = tf.nn.softmax(logits_l_2)
    loss_l_2 = self.loss_cross_entropy(probs_l_2, y)

    loss_cg_d = (1.0 - alpha_d) * loss_d_1 + alpha_d * loss_d_2
    loss_cg_l = (1.0 - alpha_l) * loss_l_1 + alpha_l * loss_l_2

    #var_list_cg_d = var_list_g + var_list_c_d
    #var_list_cg_l = var_list_g + var_list_c_l

    # Without Gradient Clipping
    train_step_cg_d = self.training_2(loss_cg_d, learning_rate, var_list_d)
    train_step_cg_l = self.training_2(loss_cg_l, learning_rate, var_list_l)
    #train_step_cg_d = self.training(loss_cg_d, learning_rate, var_list_cg_d)
    #train_step_cg_l = self.training(loss_cg_l, learning_rate, var_list_cg_l)

    acc_cg_d_1 =  self.accuracy(probs_d_1, d)
    acc_cg_d_2 =  self.accuracy(probs_d_2, d)
    acc_cg_l_1 =  self.accuracy(probs_l_1, y)
    acc_cg_l_2 =  self.accuracy(probs_l_2, y)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:

      sess.run(init)

      ##########################################################################
      print ('-' * 30)
      print ('Initialization ') 
      print ('-' * 30)

      history_loss_d_train = []
      history_loss_d_test = []
      history_acc_d_train = []
      history_acc_d_test = []

      history_loss_l_train = []
      history_loss_l_test = []
      history_acc_l_train = []
      history_acc_l_test = []

      history_loss_l_3_test = []
      history_acc_l_3_test = []

      for i in range(n_iter_1):
        # Training data of 1 and 2
        rand_index = np.random.choice(len(images_train_1), size = batch_size//2)
        x_batch_1 = images_train_1[rand_index]
        y_batch_1 = labels_train_1[rand_index]
        d_1 = np.ones(shape = (batch_size//2), dtype = np.int32) * 0
        d_1_one_hot = np.identity(2)[d_1]

        rand_index = np.random.choice(len(images_train_2), size = batch_size//2)
        x_batch_2 = images_train_2[rand_index]
        y_batch_2 = labels_train_2[rand_index]
        d_2 = np.ones(shape = (batch_size//2), dtype = np.int32) * 1
        d_2_one_hot = np.identity(2)[d_2]

        x_batch = np.concatenate((x_batch_1, x_batch_2), axis = 0)
        y_batch = np.concatenate((y_batch_1, y_batch_2), axis = 0)
        d_batch = np.concatenate((d_1_one_hot, d_2_one_hot), axis = 0).astype(np.float32)

        perm = np.random.permutation(batch_size)
        x_batch_p = x_batch[perm]
        y_batch_p = y_batch[perm]
        d_batch_p = d_batch[perm]

        feed_dict = {x: x_batch_p, y: y_batch_p, d: d_batch_p, keep_prob: 1.0}

        sess.run(train_step_d, feed_dict = feed_dict)
        sess.run(train_step_l, feed_dict = feed_dict)

        temp_loss_d = sess.run(loss_d, feed_dict = feed_dict)
        temp_loss_l = sess.run(loss_l, feed_dict = feed_dict)
        temp_acc_d = sess.run(acc_d, feed_dict = feed_dict)
        temp_acc_l = sess.run(acc_l, feed_dict = feed_dict)

        history_loss_d_train.append(temp_loss_d)
        history_loss_l_train.append(temp_loss_l)
        history_acc_d_train.append(temp_acc_d)
        history_acc_l_train.append(temp_acc_l)

        if (i + 1) % show_step_1 == 0:
          print ('-' * 100)
          print ('Iteration: ' + str(i + 1) + \
                 '  Loss_d: ' + str(temp_loss_d) + '  Accuracy_d: ' + str(temp_acc_d) + \
                '  Loss_l: ' + str(temp_loss_l) + '  Accuracy_l: ' + str(temp_acc_l))

        # Test data of 1 and 2
        rand_index = np.random.choice(len(images_test_1), size = batch_size//2)
        x_batch_1 = images_test_1[rand_index]
        y_batch_1 = labels_test_1[rand_index]
        d_1 = np.ones(shape = (batch_size//2), dtype = np.int32) * 0
        d_1_one_hot = np.identity(2)[d_1]

        rand_index = np.random.choice(len(images_test_2), size = batch_size//2)
        x_batch_2 = images_test_2[rand_index]
        y_batch_2 = labels_test_2[rand_index]
        d_2 = np.ones(shape = (batch_size//2), dtype = np.int32) * 1
        d_2_one_hot = np.identity(2)[d_2]

        x_batch = np.concatenate((x_batch_1, x_batch_2), axis = 0)
        y_batch = np.concatenate((y_batch_1, y_batch_2), axis = 0)
        d_batch = np.concatenate((d_1_one_hot, d_2_one_hot), axis = 0).astype(np.float32)

        perm = np.random.permutation(batch_size)
        x_batch_p = x_batch[perm]
        y_batch_p = y_batch[perm]
        d_batch_p = d_batch[perm]

        feed_dict = {x: x_batch_p, y: y_batch_p, d: d_batch_p, keep_prob: 1.0}

        temp_loss_d = sess.run(loss_d, feed_dict = feed_dict)
        temp_loss_l = sess.run(loss_l, feed_dict = feed_dict)
        temp_acc_d = sess.run(acc_d, feed_dict = feed_dict)
        temp_acc_l = sess.run(acc_l, feed_dict = feed_dict)

        history_loss_d_test.append(temp_loss_d)
        history_loss_l_test.append(temp_loss_l)
        history_acc_d_test.append(temp_acc_d)
        history_acc_l_test.append(temp_acc_l)

        # Test data of 3
        rand_index = np.random.choice(len(images_test_3), size = batch_size)
        x_batch = images_test_3[rand_index]
        y_batch = labels_test_3[rand_index]

        feed_dict = {x: x_batch, y: y_batch, keep_prob: 1.0}

        temp_loss_l = sess.run(loss_l, feed_dict = feed_dict)
        temp_acc_l = sess.run(acc_l, feed_dict = feed_dict)

        history_loss_l_3_test.append(temp_loss_l)
        history_acc_l_3_test.append(temp_acc_l)

      print ('-' * 100)    
      fig = plt.figure(figsize = (10, 3))
      ax1 = fig.add_subplot(1, 2, 1)
      ax1.plot(range(n_iter_1), history_loss_d_train, 'b-', label = 'Training')
      ax1.plot(range(n_iter_1), history_loss_d_test, 'r-', label = 'Test')
      ax1.set_title('Loss_d')
      ax1.legend(loc = 'upper right')

      ax2 = fig.add_subplot(1, 2, 2)
      ax2.plot(range(n_iter_1), history_acc_d_train, 'b-', label = 'Training')
      ax2.plot(range(n_iter_1), history_acc_d_test, 'r-', label = 'Test')
      ax2.set_ylim(0.0, 1.0)
      ax2.set_title('Accuracy_d')
      ax2.legend(loc = 'lower right')

      fig = plt.figure(figsize = (10, 3))
      ax1 = fig.add_subplot(1, 2, 1)
      ax1.plot(range(n_iter_1), history_loss_l_train, 'b-', label = 'Training')
      ax1.plot(range(n_iter_1), history_loss_l_test, 'r-', label = 'Test')
      ax1.set_title('Loss_l')
      ax1.legend(loc = 'upper right')

      ax2 = fig.add_subplot(1, 2, 2)
      ax2.plot(range(n_iter_1), history_acc_l_train, 'b-', label = 'Training')
      ax2.plot(range(n_iter_1), history_acc_l_test, 'r-', label = 'Test')
      ax2.set_ylim(0.0, 1.0)
      ax2.set_title('Accuracy_l')
      ax2.legend(loc = 'lower right')

      fig = plt.figure(figsize = (10, 3))
      ax1 = fig.add_subplot(1, 2, 1)
      ax1.plot(range(n_iter_1), history_loss_l_3_test, 'r-', label = 'Test')
      ax1.set_title('Loss_l_3')
      ax1.legend(loc = 'upper right')

      ax2 = fig.add_subplot(1, 2, 2)
      ax2.plot(range(n_iter_1), history_acc_l_3_test, 'r-', label = 'Test')
      ax2.set_ylim(0.0, 1.0)
      ax2.set_title('Accuracy_l_3')
      ax2.legend(loc = 'lower right')

      plt.show()

      ##########################################################################
      print ('-' * 30)
      print ('Training ') 
      print ('-' * 30)

      history_loss_cg_d_train = []
      history_loss_cg_d_test = []
      history_acc_cg_d_1_train = []
      history_acc_cg_d_1_test = []
      history_acc_cg_d_2_train = []
      history_acc_cg_d_2_test = []

      history_loss_cg_l_train = []
      history_loss_cg_l_test = []
      history_acc_cg_l_1_train = []
      history_acc_cg_l_1_test = []
      history_acc_cg_l_2_train = []
      history_acc_cg_l_2_test = []

      history_loss_cg_l_3_test = []
      history_acc_cg_l_3_test = []

      for i in range(n_iter_2):
        # Training data of 1 and 2
        rand_index = np.random.choice(len(images_train_1), size = batch_size//2)
        x_batch_1 = images_train_1[rand_index]
        y_batch_1 = labels_train_1[rand_index]
        d_1 = np.ones(shape = (batch_size//2), dtype = np.int32) * 0
        d_1_one_hot = np.identity(2)[d_1]

        rand_index = np.random.choice(len(images_train_2), size = batch_size//2)
        x_batch_2 = images_train_2[rand_index]
        y_batch_2 = labels_train_2[rand_index]
        d_2 = np.ones(shape = (batch_size//2), dtype = np.int32) * 1
        d_2_one_hot = np.identity(2)[d_2]

        x_batch = np.concatenate((x_batch_1, x_batch_2), axis = 0)
        y_batch = np.concatenate((y_batch_1, y_batch_2), axis = 0)
        d_batch = np.concatenate((d_1_one_hot, d_2_one_hot), axis = 0).astype(np.float32)

        perm = np.random.permutation(batch_size)
        x_batch_p = x_batch[perm]
        y_batch_p = y_batch[perm]
        d_batch_p = d_batch[perm]

        feed_dict = {x: x_batch_p, y: y_batch_p, d: d_batch_p, keep_prob: 1.0}

        sess.run(train_step_cg_d, feed_dict = feed_dict)
        sess.run(train_step_cg_l, feed_dict = feed_dict)

        temp_loss_cg_d = sess.run(loss_cg_d, feed_dict = feed_dict)
        temp_loss_cg_l = sess.run(loss_cg_l, feed_dict = feed_dict)
        temp_acc_cg_d_1 = sess.run(acc_cg_d_1, feed_dict = feed_dict)
        temp_acc_cg_l_1 = sess.run(acc_cg_l_1, feed_dict = feed_dict)
        temp_acc_cg_d_2 = sess.run(acc_cg_d_2, feed_dict = feed_dict)
        temp_acc_cg_l_2 = sess.run(acc_cg_l_2, feed_dict = feed_dict)

        history_loss_cg_d_train.append(temp_loss_cg_d)
        history_loss_cg_l_train.append(temp_loss_cg_l)
        history_acc_cg_d_1_train.append(temp_acc_cg_d_1)
        history_acc_cg_d_2_train.append(temp_acc_cg_d_2)
        history_acc_cg_l_1_train.append(temp_acc_cg_l_1)
        history_acc_cg_l_2_train.append(temp_acc_cg_l_2)

        if (i + 1) % show_step_2 == 0:
          print ('-' * 100)
          #print ('Iteration: ' + str(i + 1) + \
          #       '  Loss_cg_d: ' + str(temp_loss_d) + '  Accuracy_cg_d: ' + str(temp_acc_d) + \
          #      '  Loss_cg_l: ' + str(temp_loss_l) + '  Accuracy_cg_l: ' + str(temp_acc_l))
          print ('Iteration: ' + str(i + 1) + \
                 '  Loss_cg_d: ' + str(temp_loss_cg_d) + '  Loss_cg_l: ' + str(temp_loss_cg_l))

        # Test data of 1 and 2
        rand_index = np.random.choice(len(images_test_1), size = batch_size//2)
        x_batch_1 = images_test_1[rand_index]
        y_batch_1 = labels_test_1[rand_index]
        d_1 = np.ones(shape = (batch_size//2), dtype = np.int32) * 0
        d_1_one_hot = np.identity(2)[d_1]

        rand_index = np.random.choice(len(images_test_2), size = batch_size//2)
        x_batch_2 = images_test_2[rand_index]
        y_batch_2 = labels_test_2[rand_index]
        d_2 = np.ones(shape = (batch_size//2), dtype = np.int32) * 1
        d_2_one_hot = np.identity(2)[d_2]

        x_batch = np.concatenate((x_batch_1, x_batch_2), axis = 0)
        y_batch = np.concatenate((y_batch_1, y_batch_2), axis = 0)
        d_batch = np.concatenate((d_1_one_hot, d_2_one_hot), axis = 0).astype(np.float32)

        perm = np.random.permutation(batch_size)
        x_batch_p = x_batch[perm]
        y_batch_p = y_batch[perm]
        d_batch_p = d_batch[perm]

        feed_dict = {x: x_batch_p, y: y_batch_p, d: d_batch_p, keep_prob: 1.0}

        temp_loss_cg_d = sess.run(loss_cg_d, feed_dict = feed_dict)
        temp_loss_cg_l = sess.run(loss_cg_l, feed_dict = feed_dict)
        temp_acc_cg_d_1 = sess.run(acc_cg_d_1, feed_dict = feed_dict)
        temp_acc_cg_d_2 = sess.run(acc_cg_d_2, feed_dict = feed_dict)
        temp_acc_cg_l_1 = sess.run(acc_cg_l_1, feed_dict = feed_dict)
        temp_acc_cg_l_2 = sess.run(acc_cg_l_2, feed_dict = feed_dict)

        history_loss_cg_d_test.append(temp_loss_cg_d)
        history_loss_cg_l_test.append(temp_loss_cg_l)
        history_acc_cg_d_1_test.append(temp_acc_cg_d_1)
        history_acc_cg_d_2_test.append(temp_acc_cg_d_2)
        history_acc_cg_l_1_test.append(temp_acc_cg_l_1)
        history_acc_cg_l_2_test.append(temp_acc_cg_l_2)

        # Test data of 3
        rand_index = np.random.choice(len(images_test_3), size = batch_size)
        x_batch = images_test_3[rand_index]
        y_batch = labels_test_3[rand_index]

        feed_dict = {x: x_batch, y: y_batch, keep_prob: 1.0}

        temp_loss_l = sess.run(loss_l, feed_dict = feed_dict)
        temp_acc_l = sess.run(acc_l, feed_dict = feed_dict)

        history_loss_cg_l_3_test.append(temp_loss_l)
        history_acc_cg_l_3_test.append(temp_acc_l)

      print ('-' * 100)    
      fig = plt.figure(figsize = (10, 3))
      ax1 = fig.add_subplot(1, 2, 1)
      ax1.plot(range(n_iter_2), history_loss_cg_d_train, 'b-', label = 'Training')
      ax1.plot(range(n_iter_2), history_loss_cg_d_test, 'r-', label = 'Test')
      ax1.set_title('Loss_cg_d')
      ax1.legend(loc = 'upper right')

      ax2 = fig.add_subplot(1, 2, 2)
      ax2.plot(range(n_iter_2), history_loss_cg_l_train, 'b-', label = 'Training')
      ax2.plot(range(n_iter_2), history_loss_cg_l_test, 'r-', label = 'Test')
      ax2.set_title('Loss_cg_l')
      ax2.legend(loc = 'upper right')

      fig = plt.figure(figsize = (10, 3))
      ax1 = fig.add_subplot(1, 2, 1)
      ax1.plot(range(n_iter_2), history_acc_cg_d_1_train, 'b-', label = 'Training')
      ax1.plot(range(n_iter_2), history_acc_cg_d_1_test, 'r-', label = 'Test')
      ax1.set_ylim(0.0, 1.0)
      ax1.set_title('Accuracy_cg_d_1')
      ax1.legend(loc = 'lower right')

      ax2 = fig.add_subplot(1, 2, 2)
      ax2.plot(range(n_iter_2), history_acc_cg_l_1_train, 'b-', label = 'Training')
      ax2.plot(range(n_iter_2), history_acc_cg_l_1_test, 'r-', label = 'Test')
      ax2.set_ylim(0.0, 1.0)
      ax2.set_title('Accuracy_cg_l_1')
      ax2.legend(loc = 'lower right')

      fig = plt.figure(figsize = (10, 3))
      ax1 = fig.add_subplot(1, 2, 1)
      ax1.plot(range(n_iter_2), history_acc_cg_d_2_train, 'b-', label = 'Training')
      ax1.plot(range(n_iter_2), history_acc_cg_d_2_test, 'r-', label = 'Test')
      ax1.set_ylim(0.0, 1.0)
      ax1.set_title('Accuracy_cg_d_2')
      ax1.legend(loc = 'lower right')

      ax2 = fig.add_subplot(1, 2, 2)
      ax2.plot(range(n_iter_2), history_acc_cg_l_2_train, 'b-', label = 'Training')
      ax2.plot(range(n_iter_2), history_acc_cg_l_2_test, 'r-', label = 'Test')
      ax2.set_ylim(0.0, 1.0)
      ax2.set_title('Accuracy_cg_l_2')
      ax2.legend(loc = 'lower right')

      fig = plt.figure(figsize = (10, 3))
      ax1 = fig.add_subplot(1, 2, 1)
      ax1.plot(range(n_iter_2), history_loss_cg_l_3_test, 'r-', label = 'Test')
      ax1.set_title('Loss_cg_l_3')
      ax1.legend(loc = 'upper right')

      ax2 = fig.add_subplot(1, 2, 2)
      ax2.plot(range(n_iter_2), history_acc_cg_l_3_test, 'r-', label = 'Test')
      ax2.set_ylim(0.0, 1.0)
      ax2.set_title('Accuracy_cg_l_3')
      ax2.legend(loc = 'lower right')

      plt.show()

      if is_saving:
        model_path = saver.save(sess, model_path)
        print ('done saving at ', model_path)

Parameters

filter_size = 3
n_filters_1 = 64
n_filters_2 = 64
n_units_g = 128
n_units_d = 128
n_domains = 2
n_units_l = 128
n_labels = 10
eps = 1.0
alpha_d = 0.5
alpha_l = 0.5
learning_rate = 0.001
n_iter_1 = 300
n_iter_2 = 300
batch_size = 64
show_step_1 = 100
show_step_2 = 100
model_path = 'datalab/model'

Output

images_train_1 = images_train_0
labels_train_1 = labels_train_0
images_test_1 = images_test_0
labels_test_1 = labels_test_0

images_train_2 = images_train_60
labels_train_2 = labels_train_60
images_test_2 = images_test_60
labels_test_2 = labels_test_60

images_train_3 = images_train_30
labels_train_3 = labels_train_30
images_test_3 = images_test_30
labels_test_3 = labels_test_30

is_saving = False

cg = CrossGrad()

cg.fit(images_train_1, labels_train_1, images_test_1, labels_test_1, \
       images_train_2, labels_train_2, images_test_2, labels_test_2, \
       images_train_3, labels_train_3, images_test_3, labels_test_3, \
       filter_size, n_filters_1, n_filters_2, n_units_g, \
       n_units_d, n_domains, n_units_l, n_labels, eps, alpha_d, alpha_l, \
       learning_rate, n_iter_1, n_iter_2, batch_size, show_step_1, show_step_2, is_saving, model_path)
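
If fit is run with is_saving = True, the checkpoint can be restored later. A minimal sketch (an addition to the article, assuming the hyperparameters above are still in scope):

# rebuild the inference part of the graph, then restore the checkpoint
tf.reset_default_graph()
cg = CrossGrad()

x = tf.placeholder(shape = [None, 28 * 28], dtype = tf.float32)
keep_prob = tf.placeholder(shape = (), dtype = tf.float32)

feat = cg.generator(x, filter_size, n_filters_1, n_filters_2, n_units_g, keep_prob)
logits_l = cg.classifier_l(feat, n_units_g, n_units_l, n_labels, keep_prob)
probs_l = tf.nn.softmax(logits_l)

saver = tf.train.Saver()
with tf.Session() as sess:
  saver.restore(sess, model_path)
  preds = sess.run(probs_l, feed_dict = {x: images_test_3[:10], keep_prob: 1.0})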

(Figures: training/test loss and accuracy curves for the initialization phase and the CrossGrad phase, including loss and accuracy on the held-out 30° domain)
