More than 5 years have passed since the last update.

Adversarial Residual Transform Network (ARTN) の実装に関するメモ(2)

Posted at 2018-08-29


1. Unsupervised Domain Adaptation with Adversarial Residual Transform Networks


2. Adversarial Residual Transform Network (ARTN) の実装に関するメモ


# generate toy data
# Each domain has three classes. Every point is drawn from an axis-aligned 2-D
# Gaussian whose mean and per-axis standard deviation depend on its class.
# The source clusters are fixed. Exactly one "pattern" block below should be
# active to choose the target clusters (pattern 5 is active).

n_s = 1000   # number of source samples
n_t = 1000   # number of target samples
n_classes = 3

p_s = [0.4, 0.3, 0.3]   # source class priors
p_t = [0.4, 0.3, 0.3]   # target class priors

mu_s_1 = [0.0, 0.0]
sig_s_1 = [0.2, 0.2]
mu_s_2 = [0.0, 1.0]
sig_s_2 = sig_s_1
mu_s_3 = [0.5, 0.5]
sig_s_3 = sig_s_1

# pattern 1
#mu_t_1 = [2.0, 0.5]
#sig_t_1 = sig_s_1
#mu_t_2 = [2.0, 1.5]
#sig_t_2 = sig_s_1
#mu_t_3 = [2.5, 1.0]
#sig_t_3 = sig_s_1

# pattern 2
#mu_t_1 = [0.5, 0.5]
#sig_t_1 = sig_s_1
#mu_t_2 = [0.0, 0.0]
#sig_t_2 = sig_s_1
#mu_t_3 = [0.0, 1.0]
#sig_t_3 = sig_s_1

# pattern 3
#mu_t_1 = [2.5, 1.0]
#sig_t_1 = sig_s_1
#mu_t_2 = [2.0, 0.5]
#sig_t_2 = sig_s_1
#mu_t_3 = [2.0, 1.5]
#sig_t_3 = sig_s_1

# pattern 4
#mu_t_1 = [1.5, 0.0]
#sig_t_1 = [0.1, 0.001]
#mu_t_2 = [2.0, 0.0]
#sig_t_2 = [0.1, 0.001]
#mu_t_3 = [2.5, 0.0]
#sig_t_3 = [0.1, 0.001]

# pattern 5
mu_t_1 = [1.7, 0.3]
sig_t_1 = sig_s_1
mu_t_2 = [2.0, 1.0]
sig_t_2 = sig_s_1
mu_t_3 = [2.8, 1.3]
sig_t_3 = sig_s_1


def _sample_points(labels, mus, sigs):
  """Draw one 2-D Gaussian point per class label.

  Args:
    labels: sequence of 0-based class indices.
    mus: per-class means, indexed by label.
    sigs: per-class per-axis standard deviations, indexed by label
      (np.random.normal's `scale` is a standard deviation).
  Returns:
    Float array of shape (len(labels), 2); shape (0, 2) when labels is empty.
  """
  points = [np.random.normal(loc = mus[label], scale = sigs[label]) for label in labels]
  return np.reshape(points, (-1, 2))


y_s = np.random.choice(n_classes, n_s, p = p_s)
y_s_one_hot = np.identity(n_classes)[y_s].astype(np.int32)
y_t = np.random.choice(n_classes, n_t, p = p_t)
y_t_one_hot = np.identity(n_classes)[y_t].astype(np.int32)

# One sampled row per label, so x_* rows line up with y_* entries.
x_s = _sample_points(y_s, (mu_s_1, mu_s_2, mu_s_3), (sig_s_1, sig_s_2, sig_s_3))
x_t = _sample_points(y_t, (mu_t_1, mu_t_2, mu_t_3), (sig_t_1, sig_t_2, sig_t_3))

# Overlay both domains coloured by class label: source translucent, target opaque.
for _pts, _lbl, _extra in ((x_s, y_s, {'alpha': 0.5}), (x_t, y_t, {})):
  plt.scatter(_pts[:, 0], _pts[:, 1], c=_lbl, cmap='RdYlGn', **_extra)

# Domain labels (0 = source, 1 = target) as one-hot rows, then each domain is
# split with its first 80% of samples for training and the rest for testing.
n_domains = 2
d_s = np.zeros(shape = [n_s], dtype = np.int32)
d_s_one_hot = np.identity(n_domains)[d_s]
d_t = np.ones(shape = [n_t], dtype = np.int32)
d_t_one_hot = np.identity(n_domains)[d_t]

# The builtin int truncates like the old np.int alias. That alias was removed
# in NumPy 1.24 and now raises AttributeError.
n_s_train = int(n_s * 0.8)
n_t_train = int(n_t * 0.8)

x_s_train = x_s[:n_s_train]
x_s_test = x_s[n_s_train:]
y_s_train = y_s_one_hot[:n_s_train]
y_s_test = y_s_one_hot[n_s_train:]
d_s_train = d_s_one_hot[:n_s_train]
d_s_test = d_s_one_hot[n_s_train:]

x_t_train = x_t[:n_t_train]
x_t_test = x_t[n_t_train:]
y_t_train = y_t_one_hot[:n_t_train]
y_t_test = y_t_one_hot[n_t_train:]
d_t_train = d_t_one_hot[:n_t_train]
d_t_test = d_t_one_hot[n_t_train:]

pattern 1

pattern 2

pattern 3

pattern 4

pattern 5

Sample Code

# Adversarial Residual Transform Network
# NOTE(review): this uses the TensorFlow 1.x graph API (tf.placeholder,
# tf.variable_scope, tf.Session, tf.train.*). Under TF 2.x it needs
# tf.compat.v1 with eager execution disabled.

class ARTN():
  """Adversarial Residual Transform Network for a toy domain-adaptation task.

  Graph layout (built in fit()):
    * f_extractor:  shared dense layer + batch standardisation + ReLU.
    * transform:    residual dense layer on top of f_extractor, applied to
                    source inputs only (target inputs use f_extractor alone).
    * classifier_c: class classifier on source / target features.
    * classifier_d: domain classifier on the concatenated features.
  """

  def __init__(self):
    # Stateless: every variable and op is created inside fit().
    pass

  def weight_variable(self, name, shape):
    """Create (or reuse, per the enclosing variable_scope) a weight variable
    initialised from a truncated normal with mean 0.0 and stddev 0.01."""
    initializer = tf.truncated_normal_initializer(mean = 0.0, stddev = 0.01, dtype = tf.float32)
    return tf.get_variable(name, shape, initializer = initializer)

  def bias_variable(self, name, shape):
    """Create (or reuse) a bias variable initialised to 0.0."""
    initializer = tf.constant_initializer(value = 0.0, dtype = tf.float32)
    return tf.get_variable(name, shape, initializer = initializer)

  def f_extractor(self, x, n_in, n_units, keep_prob, reuse = False):
    """Shared feature extractor: dense -> batch standardisation -> ReLU.

    Args:
      x: input tensor with n_in columns.
      n_in, n_units: input / output widths of the dense layer.
      keep_prob: only used by the (currently commented-out) dropout.
      reuse: reuse the existing variables of the 'f_extractor' scope.
    Returns:
      ReLU-activated features with n_units columns.
    """
    with tf.variable_scope('f_extractor', reuse = reuse):
      w = self.weight_variable('w', [n_in, n_units])
      b = self.bias_variable('b', [n_units])

      f = tf.matmul(x, w) + b

      # Batch standardisation over axis 0, without a learned scale or offset.
      # The epsilon is added to the std (not the variance), so a constant
      # unit maps to ~0 instead of dividing by zero.
      batch_mean, batch_var = tf.nn.moments(f, [0])
      f = (f - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      # dropout (disabled)
      #f = tf.nn.dropout(f, keep_prob)

      f = tf.nn.relu(f)

      # leaky relu (alternative activation)
      #f = tf.maximum(0.2 * f, f)

    return f

  def transform(self, x, n_in, n_units, keep_prob, reuse = False):
    """Source-side features: relu(f_s @ w + b + f_s) with f_s = f_extractor(x).

    With reuse=False this call also creates the 'f_extractor' variables.
    Returns features with n_units columns.
    """
    f_s = self.f_extractor(x, n_in, n_units, keep_prob, reuse = reuse)

    with tf.variable_scope('transform', reuse = reuse):
      w = self.weight_variable('w', [n_units, n_units])
      b = self.bias_variable('b', [n_units])

      f = tf.matmul(f_s, w) + b

      # residual connection: the layer learns a correction on top of f_s
      f += f_s

      # batch norm (disabled)
      #batch_mean, batch_var = tf.nn.moments(f, [0])
      #f = (f - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      # dropout (disabled)
      #f = tf.nn.dropout(f, keep_prob)

      f = tf.nn.relu(f)

      # leaky relu (alternative activation)
      #f = tf.maximum(0.2 * f, f)

    return f

  def classifier_d(self, x, n_units_f, n_units_d, n_domains, keep_prob, reuse = False):
    """Domain classifier: dense(n_units_d) -> ReLU -> dense(n_domains) logits."""
    with tf.variable_scope('classifier_d', reuse = reuse):
      w_1 = self.weight_variable('w_1', [n_units_f, n_units_d])
      b_1 = self.bias_variable('b_1', [n_units_d])

      d = tf.matmul(x, w_1) + b_1

      # batch norm (disabled)
      #batch_mean, batch_var = tf.nn.moments(d, [0])
      #d = (d - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      d = tf.nn.relu(d)

      # dropout (disabled)
      #d = tf.nn.dropout(d, keep_prob)

      w_2 = self.weight_variable('w_2', [n_units_d, n_domains])
      b_2 = self.bias_variable('b_2', [n_domains])

      logits = tf.matmul(d, w_2) + b_2

    return logits

  def classifier_c(self, x, n_units_f, n_units_c, n_classes, keep_prob, reuse = False):
    """Class classifier: dense(n_units_c) -> ReLU -> dense(n_classes) logits."""
    with tf.variable_scope('classifier_c', reuse = reuse):
      w_1 = self.weight_variable('w_1', [n_units_f, n_units_c])
      b_1 = self.bias_variable('b_1', [n_units_c])

      l = tf.matmul(x, w_1) + b_1

      # batch norm (disabled)
      #batch_mean, batch_var = tf.nn.moments(l, [0])
      #l = (l - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      l = tf.nn.relu(l)

      # dropout (disabled)
      #l = tf.nn.dropout(l, keep_prob)

      w_2 = self.weight_variable('w_2', [n_units_c, n_classes])
      b_2 = self.bias_variable('b_2', [n_classes])

      logits = tf.matmul(l, w_2) + b_2

    return logits

  def gradient_reversal(self, f, lam, n_units_f, batch_size):
    """Gradient reversal layer: identity forward, gradient scaled by -lam backward.

    Forward:  -lam * f + (1 + lam) * f == f.
    Backward: the stop_gradient term contributes nothing, so d/df == -lam.
    The previous version added a constant (lam * ones) instead of lam * f,
    which gave a zero gradient rather than a reversed one.
    n_units_f and batch_size are no longer needed; they are kept so existing
    callers keep working.
    """
    return - lam * f + tf.stop_gradient((1.0 + lam) * f)

  def loss_cross_entropy(self, y, t):
    """Mean cross-entropy between probabilities y and one-hot targets t.

    y is clipped to [1e-10, 1] so that log(0) cannot produce -inf.
    """
    cross_entropy = - tf.reduce_mean(tf.reduce_sum(t * tf.log(tf.clip_by_value(y, 1e-10, 1.0)), axis = 1))
    return cross_entropy

  def accuracy(self, y, t):
    """Fraction of rows whose argmax of y equals the argmax of t."""
    correct_preds = tf.equal(tf.argmax(y, axis = 1), tf.argmax(t, axis = 1))
    accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))
    return accuracy

  def training(self, loss, learning_rate, var_list):
    """Adam step minimising `loss` over `var_list` only."""
    optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
    train_step = optimizer.minimize(loss, var_list = var_list)
    return train_step

  def training_gd(self, loss, learning_rate, var_list):
    """Plain gradient-descent step minimising `loss` over `var_list`."""
    optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
    train_step = optimizer.minimize(loss, var_list = var_list)
    return train_step

  def training_clipped(self, loss, learning_rate, clip_norm, var_list):
    """Adam step with each gradient clipped to L2 norm <= clip_norm.

    Variables that `loss` does not depend on get a None gradient. Those are
    passed through unclipped, since tf.clip_by_norm cannot take None.
    """
    optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)

    grads_and_vars = optimizer.compute_gradients(loss, var_list = var_list)
    clipped_grads_and_vars = [
        (tf.clip_by_norm(grad, clip_norm = clip_norm) if grad is not None else grad, var)
        for grad, var in grads_and_vars]
    train_step = optimizer.apply_gradients(clipped_grads_and_vars)

    return train_step

  @staticmethod
  def _sample_domain_batch(x_s_data, d_s_data, x_t_data, d_t_data, batch_size):
    """Sample batch_size // 2 source rows and batch_size // 2 target rows.

    Sampling is with replacement, source indices first and then target.
    Returns (x_s_batch, d_s_batch, x_t_batch, d_t_batch).
    """
    idx_s = np.random.choice(len(x_s_data), size = batch_size // 2)
    idx_t = np.random.choice(len(x_t_data), size = batch_size // 2)
    return x_s_data[idx_s], d_s_data[idx_s], x_t_data[idx_t], d_t_data[idx_t]

  @staticmethod
  def _plot_curves(n_iter, loss_curves, acc_curves):
    """Draw one 10x3 figure: loss curves on the left, accuracy (0..1) on the right.

    Each curve list holds (values, line_style, label) tuples, and each values
    list must have n_iter entries.
    """
    iters = range(n_iter)
    fig = plt.figure(figsize = (10, 3))

    ax1 = fig.add_subplot(1, 2, 1)
    for values, style, label in loss_curves:
      ax1.plot(iters, values, style, label = label)
    ax1.legend(loc = 'upper right')

    ax2 = fig.add_subplot(1, 2, 2)
    for values, style, label in acc_curves:
      ax2.plot(iters, values, style, label = label)
    ax2.set_ylim(0.0, 1.0)
    ax2.legend(loc = 'lower right')

  def fit(self, x_s_train, x_s_test, y_s_train, y_s_test, d_s_train, d_s_test, \
                 x_t_train, x_t_test, y_t_train, y_t_test, d_t_train, d_t_test, \
                 n_in, n_units_f, n_units_d, n_domains, n_units_c, n_classes, lam, \
                 learning_rate, n_iter, batch_size, show_step, is_saving, model_path):
    """Build the ARTN graph, train it, and plot the learning curves.

    Each iteration performs, in order:
      1. one Adam step on the source class loss over f_extractor, transform
         and classifier_c;
      2. one Adam step on -lam * domain loss over f_extractor (to confuse the
         domain classifier), then one Adam step on the domain loss over
         classifier_d, using half source / half target rows;
      3. evaluation on random test batches: source class loss/accuracy,
         target class accuracy, and domain loss/accuracy.
    Target class labels are only used for evaluation, never for training.

    Args:
      x_*_train / x_*_test: input features as numpy arrays with n_in columns.
      y_*_train / y_*_test: one-hot class labels.
      d_*_train / d_*_test: one-hot domain labels.
      n_in, n_units_f, n_units_d, n_units_c: layer widths.
      n_domains, n_classes: output sizes of the two classifiers.
      lam: weight of the adversarial (domain-confusion) loss.
      learning_rate: Adam learning rate shared by all optimisers.
      n_iter: number of training iterations.
      batch_size: rows per class batch. Domain batches use batch_size // 2
        rows from each domain.
      show_step: print progress every show_step iterations.
      is_saving: when True, save a checkpoint to model_path after training.
      model_path: checkpoint prefix passed to tf.train.Saver.save.
    """
    x_s = tf.placeholder(shape = [None, n_in], dtype = tf.float32)
    y_s = tf.placeholder(shape = [None, n_classes], dtype = tf.float32)
    d_s = tf.placeholder(shape = [None, n_domains], dtype = tf.float32)
    x_t = tf.placeholder(shape = [None, n_in], dtype = tf.float32)
    y_t = tf.placeholder(shape = [None, n_classes], dtype = tf.float32)
    d_t = tf.placeholder(shape = [None, n_domains], dtype = tf.float32)
    keep_prob = tf.placeholder(shape = [], dtype = tf.float32)

    # Source goes through extractor + residual transform. Target goes through
    # the same (reused) extractor only.
    feat_s = self.transform(x_s, n_in, n_units_f, keep_prob, reuse = False)
    feat_t = self.f_extractor(x_t, n_in, n_units_f, keep_prob, reuse = True)

    feat = tf.concat([feat_s, feat_t], axis = 0)
    d = tf.concat([d_s, d_t], axis = 0)

    logits_d = self.classifier_d(feat, n_units_f, n_units_d, n_domains, keep_prob, reuse = False)
    probs_d = tf.nn.softmax(logits_d)
    loss_d = self.loss_cross_entropy(probs_d, d)

    logits_c_s = self.classifier_c(feat_s, n_units_f, n_units_c, n_classes, keep_prob, reuse = False)
    probs_c_s = tf.nn.softmax(logits_c_s)
    loss_c_s = self.loss_cross_entropy(probs_c_s, y_s)

    logits_c_t = self.classifier_c(feat_t, n_units_f, n_units_c, n_classes, keep_prob, reuse = True)
    probs_c_t = tf.nn.softmax(logits_c_t)

    # The feature extractor maximises the domain loss (adversarial objective).
    loss_f = - lam * loss_d

    var_list_f = tf.trainable_variables('f_extractor')
    var_list_tr = tf.trainable_variables('transform')
    var_list_d = tf.trainable_variables('classifier_d')
    var_list_c = tf.trainable_variables('classifier_c')
    var_list_f_c = var_list_f + var_list_tr + var_list_c

    # NOTE(review): the adversarial step updates f_extractor only; the
    # transform variables are trained by the class loss alone. Use
    # var_list_f + var_list_tr here if the transform should also be trained
    # adversarially. Confirm against the intended ARTN objective.
    train_step_f = self.training(loss_f, learning_rate, var_list_f)
    train_step_d = self.training(loss_d, learning_rate, var_list_d)
    train_step_f_c = self.training(loss_c_s, learning_rate, var_list_f_c)

    acc_d = self.accuracy(probs_d, d)
    acc_c_s = self.accuracy(probs_c_s, y_s)
    acc_c_t = self.accuracy(probs_c_t, y_t)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    history_loss_c_s_train = []
    history_loss_c_s_test = []
    history_acc_c_s_train = []
    history_acc_c_s_test = []

    history_acc_c_t_test = []

    history_loss_d_train = []
    history_loss_d_test = []
    history_acc_d_train = []
    history_acc_d_test = []

    history_loss_f_train = []
    history_loss_f_test = []

    with tf.Session() as sess:
      # Variables must be initialised before any op that reads them runs.
      sess.run(init)

      for i in range(n_iter):
        # ---- Training ----
        # class classification for f, tr and c (labelled source batch)
        rand_index = np.random.choice(len(x_s_train), size = batch_size)
        feed_dict = {x_s: x_s_train[rand_index], y_s: y_s_train[rand_index], keep_prob: 1.0}

        sess.run(train_step_f_c, feed_dict = feed_dict)
        temp_loss_c_s, temp_acc_c_s = sess.run([loss_c_s, acc_c_s], feed_dict = feed_dict)
        history_loss_c_s_train.append(temp_loss_c_s)
        history_acc_c_s_train.append(temp_acc_c_s)

        if (i + 1) % show_step == 0:
          print('-' * 100)
          print('Iteration: ' + str(i + 1) + \
                '  Loss_c: ' + str(temp_loss_c_s) + '  Accuracy_c: ' + str(temp_acc_c_s))

        # domain classification for f and d (half source, half target)
        x_batch_s, d_batch_s, x_batch_t, d_batch_t = self._sample_domain_batch(
            x_s_train, d_s_train, x_t_train, d_t_train, batch_size)
        feed_dict = {x_s: x_batch_s, d_s: d_batch_s, x_t: x_batch_t, d_t: d_batch_t, keep_prob: 1.0}

        sess.run(train_step_f, feed_dict = feed_dict)
        sess.run(train_step_d, feed_dict = feed_dict)

        temp_loss_f, temp_loss_d, temp_acc_d = sess.run([loss_f, loss_d, acc_d], feed_dict = feed_dict)
        history_loss_f_train.append(temp_loss_f)
        history_loss_d_train.append(temp_loss_d)
        history_acc_d_train.append(temp_acc_d)

        if (i + 1) % show_step == 0:
          print('-' * 100)
          print('Iteration: ' + str(i + 1) + \
                '  Loss_d: ' + str(temp_loss_d) + '  Accuracy_d: ' + str(temp_acc_d))

        # ---- Test ----
        # class classification for s
        rand_index = np.random.choice(len(x_s_test), size = batch_size)
        feed_dict = {x_s: x_s_test[rand_index], y_s: y_s_test[rand_index], keep_prob: 1.0}

        temp_loss_c_s, temp_acc_c_s = sess.run([loss_c_s, acc_c_s], feed_dict = feed_dict)
        history_loss_c_s_test.append(temp_loss_c_s)
        history_acc_c_s_test.append(temp_acc_c_s)

        # class classification for t (evaluation only)
        rand_index = np.random.choice(len(x_t_test), size = batch_size)
        feed_dict = {x_t: x_t_test[rand_index], y_t: y_t_test[rand_index], keep_prob: 1.0}

        temp_acc_c_t = sess.run(acc_c_t, feed_dict = feed_dict)
        history_acc_c_t_test.append(temp_acc_c_t)

        # domain classification for f and d
        x_batch_s, d_batch_s, x_batch_t, d_batch_t = self._sample_domain_batch(
            x_s_test, d_s_test, x_t_test, d_t_test, batch_size)
        feed_dict = {x_s: x_batch_s, d_s: d_batch_s, x_t: x_batch_t, d_t: d_batch_t, keep_prob: 1.0}

        temp_loss_f, temp_loss_d, temp_acc_d = sess.run([loss_f, loss_d, acc_d], feed_dict = feed_dict)
        history_loss_f_test.append(temp_loss_f)
        history_loss_d_test.append(temp_loss_d)
        history_acc_d_test.append(temp_acc_d)

      if is_saving:
        saver.save(sess, model_path)

      print('-' * 100)

    # Source class loss / accuracy
    self._plot_curves(n_iter,
                      [(history_loss_c_s_train, 'b-', 'Training'), (history_loss_c_s_test, 'r-', 'Test')],
                      [(history_acc_c_s_train, 'b-', 'Training'), (history_acc_c_s_test, 'r-', 'Test')])
    # Domain loss / accuracy
    self._plot_curves(n_iter,
                      [(history_loss_d_train, 'b-', 'Training'), (history_loss_d_test, 'r-', 'Test')],
                      [(history_acc_d_train, 'b-', 'Training'), (history_acc_d_test, 'r-', 'Test')])
    # Adversarial loss (train) / target class accuracy (test)
    self._plot_curves(n_iter,
                      [(history_loss_f_train, 'b-', 'Training')],
                      [(history_acc_c_t_test, 'r-', 'Test')])



# Hyperparameters for the toy ARTN experiment.
n_in = 2            # input dimensionality of the toy data
n_units_f = 15      # width of the feature extractor / residual transform
n_units_d = 15      # hidden width of the domain classifier
n_units_c = 15      # hidden width of the class classifier
n_domains = 2       # source and target
n_classes = 3
lam = 1.0           # weight of the adversarial (domain-confusion) loss
learning_rate = 0.01
n_iter = 500
batch_size = 32
show_step = 100     # print progress every show_step iterations
model_path = 'datalab/model'   # checkpoint location; presumably only used when is_saving is True


is_saving = False

# The model instance must exist before fit() is called on it.
artn = ARTN()
artn.fit(x_s_train, x_s_test, y_s_train, y_s_test, d_s_train, d_s_test, \
         x_t_train, x_t_test, y_t_train, y_t_test, d_t_train, d_t_test, \
         n_in, n_units_f, n_units_d, n_domains, n_units_c, n_classes, lam, \
         learning_rate, n_iter, batch_size, show_step, is_saving, model_path)

pattern 1

pattern 2

pattern 3

pattern 4

pattern 5


Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do by signing up