Notes on implementing Semi-supervised Domain Adaptation (Adversarial Residual Transform Network)

Posted at 2018-09-01

Reference

Notes on implementing the Adversarial Residual Transform Network (ARTN)
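
For orientation, a sketch of the objective as I read it from the sample code below (this mapping is my own assumption; the memo itself gives only code). Let F be the feature extractor (f_extractor), C the source classifier (classifier_c), D the domain classifier (classifier_d), P the target predictor (predictor), and let \lambda correspond to lam:

\min_{F,C}\ \mathcal{L}_{ce}\big(C(F(x_s)),\,y_s\big) \qquad \text{(train\_step\_f\_c)}

\min_{D}\ \mathcal{L}_{ce}\big(D(F(x)),\,d\big), \qquad \min_{F}\ -\lambda\,\mathcal{L}_{ce}\big(D(F(x)),\,d\big) \qquad \text{(train\_step\_d, train\_step\_f)}

\min_{P}\ \mathcal{L}_{ce}\big(P(F(x_t^{l})),\,y_t^{l}\big) + H\big(P(F(x_t))\big) \qquad \text{(train\_step\_c\_p)}

where \mathcal{L}_{ce} is the cross-entropy loss and H is the prediction entropy on unlabeled target data.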

Data

# generate toy data
import numpy as np
import matplotlib.pyplot as plt

n_s = 1000
n_t = 1000
n_classes = 3

p_s = [0.4, 0.3, 0.3]
p_t = [0.4, 0.3, 0.3]

mu_s_1 = [0.0, 0.0]
sig_s_1 = [0.2, 0.2]
mu_s_2 = [0.0, 1.0]
sig_s_2 = sig_s_1
mu_s_3 = [0.5, 0.5]
sig_s_3 = sig_s_1

mu_t_1 = [2.8, 1.3]
sig_t_1 = sig_s_1
mu_t_2 = [1.7, 0.3]
sig_t_2 = sig_s_1
mu_t_3 = [2.0, 1.0]
sig_t_3 = sig_s_1

y_s = np.random.choice(n_classes, n_s, p = p_s)
y_s_one_hot = np.identity(n_classes)[y_s].astype(np.int32)
y_t = np.random.choice(n_classes, n_t, p = p_t)
y_t_one_hot = np.identity(n_classes)[y_t].astype(np.int32)

x_s = []
for label in y_s:
  if label == 0:
    x = np.random.normal(loc = mu_s_1, scale = sig_s_1)
  elif label == 1:
    x = np.random.normal(loc = mu_s_2, scale = sig_s_2)
  else:
    x = np.random.normal(loc = mu_s_3, scale = sig_s_3)
  x_s.append(x)
x_s = np.reshape(x_s, (-1, 2))

x_t = []
for label in y_t:
  if label == 0:
    x = np.random.normal(loc = mu_t_1, scale = sig_t_1)
  elif label == 1:
    x = np.random.normal(loc = mu_t_2, scale = sig_t_2)
  else:
    x = np.random.normal(loc = mu_t_3, scale = sig_t_3)
  x_t.append(x)
x_t = np.reshape(x_t, (-1, 2))

plt.scatter(x_s[:,0], x_s[:,1], c=y_s, cmap = 'RdYlGn', alpha = 0.5)
plt.scatter(x_t[:,0], x_t[:,1], c=y_t, cmap = 'RdYlGn')
plt.show()

n_domains = 2
d_s = np.ones(shape = [n_s], dtype = np.int32) * 0
d_s_one_hot = np.identity(n_domains)[d_s]
d_t = np.ones(shape = [n_t], dtype = np.int32) * 1
d_t_one_hot = np.identity(n_domains)[d_t]

# labeled target data: just one sample per class (the semi-supervised part)
y_t_l = np.array([0, 1, 2])
y_t_l_one_hot = np.identity(n_classes)[y_t_l].astype(np.int32)

x_t_l = []
for label in y_t_l:
  if label == 0:
    x = np.random.normal(loc = mu_t_1, scale = sig_t_1)
  elif label == 1:
    x = np.random.normal(loc = mu_t_2, scale = sig_t_2)
  else:
    x = np.random.normal(loc = mu_t_3, scale = sig_t_3)
  x_t_l.append(x)
x_t_l = np.reshape(x_t_l, (-1, 2))

plt.scatter(x_t_l[:,0], x_t_l[:,1], c=y_t_l, cmap = 'RdYlGn')
plt.xlim(0.5, 3.5)
plt.ylim(-0.5, 2.0)
plt.show()

n_s_train = int(n_s * 0.8)
n_t_train = int(n_t * 0.8)

x_s_train = x_s[:n_s_train]
x_s_test = x_s[n_s_train:]
y_s_train = y_s_one_hot[:n_s_train]
y_s_test = y_s_one_hot[n_s_train:]
d_s_train = d_s_one_hot[:n_s_train]
d_s_test = d_s_one_hot[n_s_train:]

x_t_train = x_t[:n_t_train]
x_t_test = x_t[n_t_train:]
y_t_train = y_t_one_hot[:n_t_train]
y_t_test = y_t_one_hot[n_t_train:]
d_t_train = d_t_one_hot[:n_t_train]
d_t_test = d_t_one_hot[n_t_train:]

x_t_l_train = x_t_l[:]
y_t_l_train = y_t_l_one_hot[:] 

(Figure: scatter plot of the source and target toy data)

(Figure: the three labeled target samples)

Sample Code

# Semi-supervised Domain Adaptation
import tensorflow as tf  # TensorFlow 1.x API

class SemiDA():
  def __init__(self):
    pass

  def weight_variable(self, name, shape):
    initializer = tf.truncated_normal_initializer(mean = 0.0, stddev = 0.01, dtype = tf.float32)
    return tf.get_variable(name, shape, initializer = initializer)

  def bias_variable(self, name, shape):
    initializer = tf.constant_initializer(value = 0.0, dtype = tf.float32)
    return tf.get_variable(name, shape, initializer = initializer)

  def f_extractor(self, x, n_in, n_units, keep_prob, reuse = False):
    with tf.variable_scope('f_extractor', reuse = reuse):
      w = self.weight_variable('w', [n_in, n_units])
      b = self.bias_variable('b', [n_units])

      f = tf.matmul(x, w) + b

      # batch norm
      batch_mean, batch_var = tf.nn.moments(f, [0])
      f = (f - batch_mean) / (tf.sqrt(batch_var) + 1e-10)
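      # note: normalization uses the current batch's statistics only
      # (no learned scale/offset, no moving averages), so behavior on small
      # evaluation batches differs from standard batch norm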

      # dropout
      #f = tf.nn.dropout(f, keep_prob)

      # relu
      f = tf.nn.relu(f)

      # leaky relu
      #f = tf.maximum(0.2 * f, f)

    return f

  def transform(self, x, n_in, n_units, keep_prob, reuse = False):
    f_s = self.f_extractor(x, n_in, n_units, keep_prob, reuse = reuse)

    with tf.variable_scope('transform', reuse = reuse):
      w = self.weight_variable('w', [n_units, n_units])
      b = self.bias_variable('b', [n_units])

      f = tf.matmul(f_s, w) + b

      # residual connection
      f += f_s 

      # batch norm
      #batch_mean, batch_var = tf.nn.moments(f, [0])
      #f = (f - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      # dropout
      #f = tf.nn.dropout(f, keep_prob)

      # relu
      f = tf.nn.relu(f)

      # leaky relu
      #f = tf.maximum(0.2 * f, f)

    return f

  def classifier_d(self, x, n_units_f, n_units_d, n_domains, keep_prob, reuse = False):

    with tf.variable_scope('classifier_d', reuse = reuse):
      w_1 = self.weight_variable('w_1', [n_units_f, n_units_d])
      b_1 = self.bias_variable('b_1', [n_units_d])

      d = tf.matmul(x, w_1) + b_1

      # batch norm
      #batch_mean, batch_var = tf.nn.moments(d, [0])
      #d = (d - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      # relu
      d = tf.nn.relu(d)

      # dropout
      #d = tf.nn.dropout(d, keep_prob)

      w_2 = self.weight_variable('w_2', [n_units_d, n_domains])
      b_2 = self.bias_variable('b_2', [n_domains])

      d = tf.matmul(d, w_2) + b_2
      logits = d

    return logits

  def classifier_c(self, x, n_units_f, n_units_c, n_classes, keep_prob, reuse = False):

    with tf.variable_scope('classifier_c', reuse = reuse):
      w_1 = self.weight_variable('w_1', [n_units_f, n_units_c])
      b_1 = self.bias_variable('b_1', [n_units_c])

      l = tf.matmul(x, w_1) + b_1

      # batch norm
      #batch_mean, batch_var = tf.nn.moments(l, [0])
      #l = (l - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      # relu
      l = tf.nn.relu(l)

      # dropout
      #l = tf.nn.dropout(l, keep_prob)

      w_2 = self.weight_variable('w_2', [n_units_c, n_classes])
      b_2 = self.bias_variable('b_2', [n_classes])

      l = tf.matmul(l, w_2) + b_2
      logits = l

    return logits

  def predictor(self, x, n_units_f, n_units_p, n_classes, keep_prob, reuse = False):

    with tf.variable_scope('predictor', reuse = reuse):
      w_1 = self.weight_variable('w_1', [n_units_f, n_units_p])
      b_1 = self.bias_variable('b_1', [n_units_p])

      l = tf.matmul(x, w_1) + b_1

      # batch norm
      #batch_mean, batch_var = tf.nn.moments(l, [0])
      #l = (l - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      # relu
      l = tf.nn.relu(l)

      # dropout
      #l = tf.nn.dropout(l, keep_prob)

      w_2 = self.weight_variable('w_2', [n_units_p, n_classes])
      b_2 = self.bias_variable('b_2', [n_classes])

      l = tf.matmul(l, w_2) + b_2
      logits = l

    return logits

  def gradient_reversal(self, f, lam, n_units_f, batch_size):
    # identity in the forward pass (-lam * f + (1 + lam) * f = f);
    # in the backward pass the stop_gradient branch contributes nothing,
    # so the gradient w.r.t. f is scaled by -lam
    return - lam * f + tf.stop_gradient((1.0 + lam) * f)

  def loss_cross_entropy(self, y, t):
    cross_entropy = - tf.reduce_mean(tf.reduce_sum(t * tf.log(tf.clip_by_value(y, 1e-10, 1.0)), axis = 1))
    return cross_entropy

  def loss_entropy(self, p):
    # entropy minimization for unlabeled target predictions: H(p) = -sum p log p
    entropy = - tf.reduce_mean(tf.reduce_sum(p * tf.log(tf.clip_by_value(p, 1e-10, 1.0)), axis = 1))
    return entropy

  def accuracy(self, y, t):
    correct_preds = tf.equal(tf.argmax(y, axis = 1), tf.argmax(t, axis = 1))
    accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))
    return accuracy

  def training(self, loss, learning_rate, var_list):
    optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
    train_step = optimizer.minimize(loss, var_list = var_list)
    return train_step

  def training_gd(self, loss, learning_rate, var_list):
    #optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
    train_step = optimizer.minimize(loss, var_list = var_list)
    return train_step

  def training_clipped(self, loss, learning_rate, clip_norm, var_list):
    optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)

    grads_and_vars = optimizer.compute_gradients(loss, var_list = var_list)
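    # assumes every variable in var_list lies on the path to loss;
    # tf.clip_by_norm would fail on a None gradient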
    clipped_grads_and_vars = [(tf.clip_by_norm(grad, clip_norm = clip_norm), \
                             var) for grad, var in grads_and_vars]
    train_step = optimizer.apply_gradients(clipped_grads_and_vars)

    return train_step

  def fit(self, x_s_train, x_s_test, y_s_train, y_s_test, \
          x_t_train, x_t_test, y_t_train, y_t_test, \
          x_t_l_train, y_t_l_train, n_in, n_units_f, n_units_d, n_domains, \
          n_units_c, n_classes, n_units_p, lam, \
          learning_rate, n_iter, batch_size, show_step, is_saving, model_path):

    tf.reset_default_graph()

    x_s = tf.placeholder(shape = [None, n_in], dtype = tf.float32)
    y_s = tf.placeholder(shape = [None, n_classes], dtype = tf.float32)
    d_s = tf.placeholder(shape = [None, n_domains], dtype = tf.float32) 
    x_t = tf.placeholder(shape = [None, n_in], dtype = tf.float32)
    y_t = tf.placeholder(shape = [None, n_classes], dtype = tf.float32)    
    x_t_l = tf.placeholder(shape = [None, n_in], dtype = tf.float32)
    y_t_l = tf.placeholder(shape = [None, n_classes], dtype = tf.float32)
    d_t = tf.placeholder(shape = [None, n_domains], dtype = tf.float32) 
    keep_prob = tf.placeholder(shape = [], dtype = tf.float32)

    feat_s = self.f_extractor(x_s, n_in, n_units_f, keep_prob, reuse = False)
    #feat_s = self.transform(x_s, n_in, n_units_f, keep_prob, reuse = False)  # for transform
    feat_t = self.f_extractor(x_t, n_in, n_units_f, keep_prob, reuse = True)
    feat_t_l = self.f_extractor(x_t_l, n_in, n_units_f, keep_prob, reuse = True)

    feat = tf.concat([feat_s, feat_t], axis = 0)
    d = tf.concat([d_s, d_t], axis = 0)

    logits_d = self.classifier_d(feat, n_units_f, n_units_d, n_domains, keep_prob, reuse = False)
    probs_d = tf.nn.softmax(logits_d)

    loss_d = self.loss_cross_entropy(probs_d, d)

    logits_c_s = self.classifier_c(feat_s, n_units_f, n_units_c, n_classes, keep_prob, reuse = False)
    probs_c_s = tf.nn.softmax(logits_c_s)
    loss_c_s = self.loss_cross_entropy(probs_c_s, y_s)

    logits_c_t = self.predictor(feat_t_l, n_units_f, n_units_p, n_classes, keep_prob, reuse = False)
    probs_c_t = tf.nn.softmax(logits_c_t)
    loss_c_t = self.loss_cross_entropy(probs_c_t, y_t_l)

    logits_p_t = self.predictor(feat_t, n_units_f, n_units_p, n_classes, keep_prob, reuse = True)
    probs_p_t = tf.nn.softmax(logits_p_t)
    loss_p_t = self.loss_entropy(probs_p_t)

    loss_c_p_t = loss_c_t + loss_p_t

    loss_f = - lam * loss_d
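    # minimizing -lam * loss_d w.r.t. the extractor variables reverses the
    # domain-classifier gradient; this plays the role of gradient_reversal,
    # which is defined above but not used here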

    var_list_f = tf.trainable_variables('f_extractor')
    #var_list_f = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="f_extractor")
    var_list_tr = tf.trainable_variables('transform')
    var_list_d = tf.trainable_variables('classifier_d')
    var_list_c = tf.trainable_variables('classifier_c')
    var_list_p = tf.trainable_variables('predictor')

    var_list_f_c = var_list_f + var_list_c
    #var_list_f_c = var_list_f + var_list_tr + var_list_c  # for transform

    train_step_f = self.training(loss_f, learning_rate, var_list_f)
    train_step_d = self.training(loss_d, learning_rate, var_list_d)
    train_step_f_c = self.training(loss_c_s, learning_rate, var_list_f_c)
    train_step_c_p = self.training(loss_c_p_t, learning_rate, var_list_p)
    train_step_p = self.training(loss_p_t, learning_rate, var_list_p)

    acc_d = self.accuracy(probs_d, d)
    acc_c_s = self.accuracy(probs_c_s, y_s)
    acc_c_t = self.accuracy(probs_c_t, y_t_l)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:

      sess.run(init)

      history_loss_c_s_train = []
      history_loss_c_s_test = []
      history_acc_c_s_train = []
      history_acc_c_s_test = []

      history_loss_c_t_train = []
      history_loss_c_t_test = []
      history_acc_c_t_train = []
      history_acc_c_t_test = []

      history_loss_d_train = []
      history_loss_d_test = []
      history_acc_d_train = []
      history_acc_d_test = []

      history_loss_f_train = []
      history_loss_f_test = []

      for i in range(n_iter):
        # Training
        # class classification for f and c
        rand_index = np.random.choice(len(x_s_train), size = batch_size)
        x_batch = x_s_train[rand_index]
        y_batch = y_s_train[rand_index]

        feed_dict = {x_s: x_batch, y_s: y_batch, keep_prob: 1.0}
        sess.run(train_step_f_c, feed_dict = feed_dict)

        temp_loss_c_s = sess.run(loss_c_s, feed_dict = feed_dict)
        temp_acc_c_s = sess.run(acc_c_s, feed_dict = feed_dict)

        history_loss_c_s_train.append(temp_loss_c_s)
        history_acc_c_s_train.append(temp_acc_c_s)

        if (i + 1) % show_step == 0:
          print ('-' * 100)
          print ('Iteration: ' + str(i + 1) + \
                 '  Loss_c: ' + str(temp_loss_c_s) + '  Accuracy_c: ' + str(temp_acc_c_s))

        # prediction for p
        rand_index = np.random.choice(len(x_t_l_train), size = len(x_t_l_train))
        x_batch_l = x_t_l_train[rand_index]
        y_batch_l = y_t_l_train[rand_index]

        rand_index = np.random.choice(len(x_t_train), size = batch_size)
        x_batch = x_t_train[rand_index]
        y_batch = y_t_train[rand_index]

        feed_dict = {x_t: x_batch, x_t_l: x_batch_l, y_t_l: y_batch_l, keep_prob: 1.0}
        sess.run(train_step_c_p, feed_dict = feed_dict)
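        # the monitoring below feeds ground-truth labels for an unlabeled target
        # batch; these would not be available in a real semi-supervised setting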

        feed_dict = {x_t_l: x_batch, y_t_l: y_batch, keep_prob: 1.0}
        temp_loss_c_t = sess.run(loss_c_t, feed_dict = feed_dict)
        temp_acc_c_t = sess.run(acc_c_t, feed_dict = feed_dict)

        history_loss_c_t_train.append(temp_loss_c_t)
        history_acc_c_t_train.append(temp_acc_c_t)

        # domain classification for d
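        # note: d_s_train / d_t_train (and the *_test splits below) come from
        # the enclosing scope; they are not arguments of fit()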
        rand_index = np.random.choice(len(x_s_train), size = batch_size // 2)
        x_batch_s = x_s_train[rand_index]
        d_batch_s = d_s_train[rand_index]

        rand_index = np.random.choice(len(x_t_train), size = batch_size // 2)
        x_batch_t = x_t_train[rand_index]
        d_batch_t = d_t_train[rand_index]

        feed_dict = {x_s: x_batch_s, d_s: d_batch_s, x_t: x_batch_t, d_t: d_batch_t, keep_prob: 1.0}
        sess.run(train_step_f, feed_dict = feed_dict)
        sess.run(train_step_d, feed_dict = feed_dict)

        temp_loss_f = sess.run(loss_f, feed_dict = feed_dict)
        temp_loss_d = sess.run(loss_d, feed_dict = feed_dict)
        temp_acc_d = sess.run(acc_d, feed_dict = feed_dict)

        history_loss_f_train.append(temp_loss_f)
        history_loss_d_train.append(temp_loss_d)
        history_acc_d_train.append(temp_acc_d)

        if (i + 1) % show_step == 0:
          print ('-' * 100)
          print ('Iteration: ' + str(i + 1) + \
                 '  Loss_d: ' + str(temp_loss_d) + '  Accuracy_d: ' + str(temp_acc_d))

        # Test
        # class classification for s
        rand_index = np.random.choice(len(x_s_test), size = batch_size)
        x_batch = x_s_test[rand_index]
        y_batch = y_s_test[rand_index]

        feed_dict = {x_s: x_batch, y_s: y_batch, keep_prob: 1.0}
        temp_loss_c_s = sess.run(loss_c_s, feed_dict = feed_dict)
        temp_acc_c_s = sess.run(acc_c_s, feed_dict = feed_dict)

        history_loss_c_s_test.append(temp_loss_c_s)
        history_acc_c_s_test.append(temp_acc_c_s)

        # class classification for t
        rand_index = np.random.choice(len(x_t_test), size = batch_size)
        x_batch = x_t_test[rand_index]
        y_batch = y_t_test[rand_index]

        feed_dict = {x_t_l: x_batch, y_t_l: y_batch, keep_prob: 1.0}
        temp_loss_c_t = sess.run(loss_c_t, feed_dict = feed_dict)
        temp_acc_c_t = sess.run(acc_c_t, feed_dict = feed_dict) 

        #feed_dict = {x_s: x_batch, y_s: y_batch, keep_prob: 1.0}
        #temp_loss_c_t = sess.run(loss_c_s, feed_dict = feed_dict)
        #temp_acc_c_t = sess.run(acc_c_s, feed_dict = feed_dict) 

        history_loss_c_t_test.append(temp_loss_c_t)
        history_acc_c_t_test.append(temp_acc_c_t)

        # domain classification for f and d
        rand_index = np.random.choice(len(x_s_test), size = batch_size // 2)
        x_batch_s = x_s_test[rand_index]
        d_batch_s = d_s_test[rand_index]

        rand_index = np.random.choice(len(x_t_test), size = batch_size // 2)
        x_batch_t = x_t_test[rand_index]
        d_batch_t = d_t_test[rand_index]

        feed_dict = {x_s: x_batch_s, d_s: d_batch_s, x_t: x_batch_t, d_t: d_batch_t, keep_prob: 1.0}
        temp_loss_f = sess.run(loss_f, feed_dict = feed_dict)
        temp_loss_d = sess.run(loss_d, feed_dict = feed_dict)
        temp_acc_d = sess.run(acc_d, feed_dict = feed_dict)

        history_loss_f_test.append(temp_loss_f)
        history_loss_d_test.append(temp_loss_d)
        history_acc_d_test.append(temp_acc_d)

      print ('-' * 100)    
      fig = plt.figure(figsize = (10, 3))
      ax1 = fig.add_subplot(1, 2, 1)
      ax1.plot(range(n_iter), history_loss_c_s_train, 'b-', label = 'Training')
      ax1.plot(range(n_iter), history_loss_c_s_test, 'r-', label = 'Test')
      ax1.set_title('Loss_c_s')
      ax1.legend(loc = 'upper right')

      ax2 = fig.add_subplot(1, 2, 2)
      ax2.plot(range(n_iter), history_acc_c_s_train, 'b-', label = 'Training')
      ax2.plot(range(n_iter), history_acc_c_s_test, 'r-', label = 'Test')
      ax2.set_ylim(0.0, 1.0)
      ax2.set_title('Accuracy_c_s')
      ax2.legend(loc = 'lower right')

      fig = plt.figure(figsize = (10, 3))
      ax1 = fig.add_subplot(1, 2, 1)
      ax1.plot(range(n_iter), history_loss_d_train, 'b-', label = 'Training')
      ax1.plot(range(n_iter), history_loss_d_test, 'r-', label = 'Test')
      ax1.set_title('Loss_d')
      ax1.legend(loc = 'upper right')

      ax2 = fig.add_subplot(1, 2, 2)
      ax2.plot(range(n_iter), history_acc_d_train, 'b-', label = 'Training')
      ax2.plot(range(n_iter), history_acc_d_test, 'r-', label = 'Test')
      ax2.set_ylim(0.0, 1.0)
      ax2.set_title('Accuracy_d')
      ax2.legend(loc = 'lower right')

      fig = plt.figure(figsize = (10, 3))
      ax1 = fig.add_subplot(1, 2, 1)
      ax1.plot(range(n_iter), history_loss_c_t_train, 'b-', label = 'Training')
      ax1.plot(range(n_iter), history_loss_c_t_test, 'r-', label = 'Test')
      ax1.set_title('Loss_c_t')
      ax1.legend(loc = 'upper right')

      ax2 = fig.add_subplot(1, 2, 2)
      ax2.plot(range(n_iter), history_acc_c_t_train, 'b-', label = 'Training')
      ax2.plot(range(n_iter), history_acc_c_t_test, 'r-', label = 'Test')
      ax2.set_ylim(0.0, 1.0)
      ax2.set_title('Accuracy_c_t')
      ax2.legend(loc = 'lower right')

      plt.show()
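
      # save the trained model when requested; saver, is_saving, and model_path
      # are defined but otherwise unused, so this call is an assumed completion
      if is_saving:
        saver.save(sess, model_path)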

Parameters

n_in = 2
n_units_f = 15
n_units_d = 15
n_units_c = 15
n_units_p = 15
n_domains = 2
n_classes = 3
lam = 1.0
learning_rate = 0.01
n_iter = 300
batch_size = 32
show_step = 100
model_path = 'datalab/model'

Output

semi = SemiDA()
is_saving = False

semi.fit(x_s_train, x_s_test, y_s_train, y_s_test, \
         x_t_train, x_t_test, y_t_train, y_t_test, \
         x_t_l_train, y_t_l_train, n_in, n_units_f, n_units_d, \
         n_domains, n_units_c, n_classes, n_units_p, lam, \
         learning_rate, n_iter, batch_size, show_step, is_saving, model_path)

(Figure: training/test curves for Loss_c_s / Accuracy_c_s, Loss_d / Accuracy_d, and Loss_c_t / Accuracy_c_t)
