Notes on implementing the Adversarial Residual Transform Network (ARTN) (2)

Posted at 2018-08-29

Reference

1. Unsupervised Domain Adaptation with Adversarial Residual Transform Networks

2. Notes on implementing the Adversarial Residual Transform Network (ARTN)
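
As a quick recap of what the sample code below implements (my shorthand reading of [1]; the paper's exact notation may differ): the source input passes through a shared feature extractor $F$ and a residual transform block $T$, the target input passes through $F$ only, a domain classifier $D$ is trained to separate the two feature sets, the extractor is trained adversarially against $D$, and a label classifier $C$ is trained on the transformed source features:

$$
\mathcal{L}_d = -\,\mathbb{E}\Big[\sum_k d_k \log D(f)_k\Big], \qquad
\mathcal{L}_c = -\,\mathbb{E}\Big[\sum_k y_k \log C\big(T(F(x_s))\big)_k\Big]
$$

$$
\min_D \mathcal{L}_d, \qquad \min_F \big(-\lambda\,\mathcal{L}_d\big), \qquad \min_{F,\,T,\,C} \mathcal{L}_c
$$

In the code these correspond to train_step_d, train_step_f (via loss_f = - lam * loss_d), and train_step_f_c.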

Data

import numpy as np
import matplotlib.pyplot as plt

# generate toy data

n_s = 1000
n_t = 1000
n_classes = 3

p_s = [0.4, 0.3, 0.3]
p_t = [0.4, 0.3, 0.3]

mu_s_1 = [0.0, 0.0]
sig_s_1 = [0.2, 0.2]
mu_s_2 = [0.0, 1.0]
sig_s_2 = sig_s_1
mu_s_3 = [0.5, 0.5]
sig_s_3 = sig_s_1

# pattern 1
#mu_t_1 = [2.0, 0.5]
#sig_t_1 = sig_s_1
#mu_t_2 = [2.0, 1.5]
#sig_t_2 = sig_s_1
#mu_t_3 = [2.5, 1.0]
#sig_t_3 = sig_s_1

# pattern 2
#mu_t_1 = [0.5, 0.5]
#sig_t_1 = sig_s_1
#mu_t_2 = [0.0, 0.0]
#sig_t_2 = sig_s_1
#mu_t_3 = [0.0, 1.0]
#sig_t_3 = sig_s_1

# pattern 3
#mu_t_1 = [2.5, 1.0]
#sig_t_1 = sig_s_1
#mu_t_2 = [2.0, 0.5]
#sig_t_2 = sig_s_1
#mu_t_3 = [2.0, 1.5]
#sig_t_3 = sig_s_1

# pattern 4
#mu_t_1 = [1.5, 0.0]
#sig_t_1 = [0.1, 0.001]
#mu_t_2 = [2.0, 0.0]
#sig_t_2 = [0.1, 0.001]
#mu_t_3 = [2.5, 0.0]
#sig_t_3 = [0.1, 0.001]

# pattern 5
mu_t_1 = [1.7, 0.3]
sig_t_1 = sig_s_1
mu_t_2 = [2.0, 1.0]
sig_t_2 = sig_s_1
mu_t_3 = [2.8, 1.3]
sig_t_3 = sig_s_1

y_s = np.random.choice(n_classes, n_s, p = p_s)
y_s_one_hot = np.identity(n_classes)[y_s].astype(np.int32)
y_t = np.random.choice(n_classes, n_t, p = p_t)
y_t_one_hot = np.identity(n_classes)[y_t].astype(np.int32)

x_s = []
for label in y_s:
  if label == 0:
    x = np.random.normal(loc = mu_s_1, scale = sig_s_1)
  elif label == 1:
    x = np.random.normal(loc = mu_s_2, scale = sig_s_2)
  else:
    x = np.random.normal(loc = mu_s_3, scale = sig_s_3)
  x_s.append(x)
x_s = np.reshape(x_s, (-1, 2))

x_t = []
for label in y_t:
  if label == 0:
    x = np.random.normal(loc = mu_t_1, scale = sig_t_1)
  elif label == 1:
    x = np.random.normal(loc = mu_t_2, scale = sig_t_2)
  else:
    x = np.random.normal(loc = mu_t_3, scale = sig_t_3)
  x_t.append(x)
x_t = np.reshape(x_t, (-1, 2))


plt.scatter(x_s[:,0], x_s[:,1], c=y_s, cmap = 'RdYlGn', alpha = 0.5)
plt.scatter(x_t[:,0], x_t[:,1], c=y_t, cmap = 'RdYlGn')
plt.show()

n_domains = 2
d_s = np.ones(shape = [n_s], dtype = np.int32) * 0
d_s_one_hot = np.identity(n_domains)[d_s]
d_t = np.ones(shape = [n_t], dtype = np.int32) * 1
d_t_one_hot = np.identity(n_domains)[d_t]

n_s_train = int(n_s * 0.8)
n_t_train = int(n_t * 0.8)

x_s_train = x_s[:n_s_train]
x_s_test = x_s[n_s_train:]
y_s_train = y_s_one_hot[:n_s_train]
y_s_test = y_s_one_hot[n_s_train:]
d_s_train = d_s_one_hot[:n_s_train]
d_s_test = d_s_one_hot[n_s_train:]

x_t_train = x_t[:n_t_train]
x_t_test = x_t[n_t_train:]
y_t_train = y_t_one_hot[:n_t_train]
y_t_test = y_t_one_hot[n_t_train:]
d_t_train = d_t_one_hot[:n_t_train]
d_t_test = d_t_one_hot[n_t_train:]
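
A quick sanity check on the resulting split shapes (my addition):

print(x_s_train.shape, y_s_train.shape, d_s_train.shape)  # (800, 2) (800, 3) (800, 2)
print(x_t_test.shape, y_t_test.shape, d_t_test.shape)     # (200, 2) (200, 3) (200, 2)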

patterns 1 through 5: scatter plots of the source data (semi-transparent) overlaid with each pattern's target data (images omitted)
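
Source and target points share the same colormap in these overlays, which can make them hard to tell apart; a small cosmetic tweak (my addition, not in the original) is to use distinct markers and a legend:

plt.scatter(x_s[:,0], x_s[:,1], c=y_s, cmap = 'RdYlGn', alpha = 0.5, marker = 'o', label = 'source')
plt.scatter(x_t[:,0], x_t[:,1], c=y_t, cmap = 'RdYlGn', marker = 'x', label = 'target')
plt.legend()
plt.show()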

Sample Code

# Adversarial Residual Transform Network

import tensorflow as tf  # TF 1.x API (tf.placeholder, tf.get_variable, etc.)

class ARTN():
  def __init__(self):
    pass

  def weight_variable(self, name, shape):
    initializer = tf.truncated_normal_initializer(mean = 0.0, stddev = 0.01, dtype = tf.float32)
    return tf.get_variable(name, shape, initializer = initializer)

  def bias_variable(self, name, shape):
    initializer = tf.constant_initializer(value = 0.0, dtype = tf.float32)
    return tf.get_variable(name, shape, initializer = initializer)

  def f_extractor(self, x, n_in, n_units, keep_prob, reuse = False):
    with tf.variable_scope('f_extractor', reuse = reuse):
      w = self.weight_variable('w', [n_in, n_units])
      b = self.bias_variable('b', [n_units])

      f = tf.matmul(x, w) + b

      # batch norm (per-batch statistics only; no moving averages are kept,
      # so test-time normalization also depends on the evaluation batch)
      batch_mean, batch_var = tf.nn.moments(f, [0])
      f = (f - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      # dropout
      #f = tf.nn.dropout(f, keep_prob)

      # relu
      f = tf.nn.relu(f)

      # leaky relu
      #f = tf.maximum(0.2 * f, f)

    return f

  def transform(self, x, n_in, n_units, keep_prob, reuse = False):
    f_s = self.f_extractor(x, n_in, n_units, keep_prob, reuse = reuse)

    with tf.variable_scope('transform', reuse = reuse):
      w = self.weight_variable('w', [n_units, n_units])
      b = self.bias_variable('b', [n_units])

      f = tf.matmul(f_s, w) + b

      # residual connection
      f += f_s 

      # batch norm
      #batch_mean, batch_var = tf.nn.moments(f, [0])
      #f = (f - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      # dropout
      #f = tf.nn.dropout(f, keep_prob)

      # relu
      f = tf.nn.relu(f)

      # leaky relu
      #f = tf.maximum(0.2 * f, f)

    return f

  def classifier_d(self, x, n_units_f, n_units_d, n_domains, keep_prob, reuse = False):

    with tf.variable_scope('classifier_d', reuse = reuse):
      w_1 = self.weight_variable('w_1', [n_units_f, n_units_d])
      b_1 = self.bias_variable('b_1', [n_units_d])

      d = tf.matmul(x, w_1) + b_1

      # batch norm
      #batch_mean, batch_var = tf.nn.moments(d, [0])
      #d = (d - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      # relu
      d = tf.nn.relu(d)

      # dropout
      #d = tf.nn.dropout(d, keep_prob)

      w_2 = self.weight_variable('w_2', [n_units_d, n_domains])
      b_2 = self.bias_variable('b_2', [n_domains])

      d = tf.matmul(d, w_2) + b_2
      logits = d

    return logits

  def classifier_c(self, x, n_units_f, n_units_c, n_classes, keep_prob, reuse = False):

    with tf.variable_scope('classifier_c', reuse = reuse):
      w_1 = self.weight_variable('w_1', [n_units_f, n_units_c])
      b_1 = self.bias_variable('b_1', [n_units_c])

      l = tf.matmul(x, w_1) + b_1

      # batch norm
      #batch_mean, batch_var = tf.nn.moments(l, [0])
      #l = (l - batch_mean) / (tf.sqrt(batch_var) + 1e-10)

      # relu
      l = tf.nn.relu(l)

      # dropout
      #l = tf.nn.dropout(l, keep_prob)

      w_2 = self.weight_variable('w_2', [n_units_c, n_classes])
      b_2 = self.bias_variable('b_2', [n_classes])

      l = tf.matmul(l, w_2) + b_2
      logits = l

    return logits

  def gradient_reversal(self, f, lam, n_units_f, batch_size):
    # Identity in the forward pass; scales the gradient by -lam in the
    # backward pass. Unused in fit(), which realizes the adversarial update
    # via loss_f = - lam * loss_d instead.
    return - lam * f + tf.stop_gradient((1.0 + lam) * f)

  def loss_cross_entropy(self, y, t):
    cross_entropy = - tf.reduce_mean(tf.reduce_sum(t * tf.log(tf.clip_by_value(y, 1e-10, 1.0)), axis = 1))
    return cross_entropy

  def accuracy(self, y, t):
    correct_preds = tf.equal(tf.argmax(y, axis = 1), tf.argmax(t, axis = 1))
    accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))
    return accuracy

  def training(self, loss, learning_rate, var_list):
    optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
    train_step = optimizer.minimize(loss, var_list = var_list)
    return train_step

  def training_gd(self, loss, learning_rate, var_list):
    #optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
    train_step = optimizer.minimize(loss, var_list = var_list)
    return train_step

  def training_clipped(self, loss, learning_rate, clip_norm, var_list):
    optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)

    grads_and_vars = optimizer.compute_gradients(loss, var_list = var_list)
    clipped_grads_and_vars = [(tf.clip_by_norm(grad, clip_norm = clip_norm), \
                             var) for grad, var in grads_and_vars]
    train_step = optimizer.apply_gradients(clipped_grads_and_vars)

    return train_step

  def fit(self, x_s_train, x_s_test, y_s_train, y_s_test, d_s_train, d_s_test, \
                 x_t_train, x_t_test, y_t_train, y_t_test, d_t_train, d_t_test, \
                 n_in, n_units_f, n_units_d, n_domains, n_units_c, n_classes, lam, \
                 learning_rate, n_iter, batch_size, show_step, is_saving, model_path):

    tf.reset_default_graph()

    x_s = tf.placeholder(shape = [None, n_in], dtype = tf.float32)
    y_s = tf.placeholder(shape = [None, n_classes], dtype = tf.float32)
    d_s = tf.placeholder(shape = [None, n_domains], dtype = tf.float32) 
    x_t = tf.placeholder(shape = [None, n_in], dtype = tf.float32)
    y_t = tf.placeholder(shape = [None, n_classes], dtype = tf.float32)
    d_t = tf.placeholder(shape = [None, n_domains], dtype = tf.float32) 
    keep_prob = tf.placeholder(shape = [], dtype = tf.float32)

    feat_s = self.transform(x_s, n_in, n_units_f, keep_prob, reuse = False)
    feat_t = self.f_extractor(x_t, n_in, n_units_f, keep_prob, reuse = True)

    feat = tf.concat([feat_s, feat_t], axis = 0)
    d = tf.concat([d_s, d_t], axis = 0)

    logits_d = self.classifier_d(feat, n_units_f, n_units_d, n_domains, keep_prob, reuse = False)
    probs_d = tf.nn.softmax(logits_d)

    loss_d = self.loss_cross_entropy(probs_d, d)

    logits_c_s = self.classifier_c(feat_s, n_units_f, n_units_c, n_classes, keep_prob, reuse = False)
    probs_c_s = tf.nn.softmax(logits_c_s)
    loss_c_s = self.loss_cross_entropy(probs_c_s, y_s)

    logits_c_t = self.classifier_c(feat_t, n_units_f, n_units_c, n_classes, keep_prob, reuse = True)
    probs_c_t = tf.nn.softmax(logits_c_t)

    loss_f = - lam * loss_d

    var_list_f = tf.trainable_variables('f_extractor')
    #var_list_f = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="f_extractor")
    var_list_tr = tf.trainable_variables('transform')
    var_list_d = tf.trainable_variables('classifier_d')
    var_list_c = tf.trainable_variables('classifier_c')
    var_list_f_c = var_list_f + var_list_tr + var_list_c
    var_list_f_tr = var_list_f + var_list_tr    

    train_step_f = self.training(loss_f, learning_rate, var_list_f)
    train_step_d = self.training(loss_d, learning_rate, var_list_d)
    train_step_f_c = self.training(loss_c_s, learning_rate, var_list_f_c)

    acc_d = self.accuracy(probs_d, d)
    acc_c_s = self.accuracy(probs_c_s, y_s)
    acc_c_t = self.accuracy(probs_c_t, y_t)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:

      sess.run(init)

      history_loss_c_s_train = []
      history_loss_c_s_test = []
      history_acc_c_s_train = []
      history_acc_c_s_test = []

      history_loss_c_t_test = []
      history_acc_c_t_test = []

      history_loss_d_train = []
      history_loss_d_test = []
      history_acc_d_train = []
      history_acc_d_test = []

      history_loss_f_train = []
      history_loss_f_test = []

      for i in range(n_iter):
        # Training
        # class classification for f_s , tr and c
        rand_index = np.random.choice(len(x_s_train), size = batch_size)
        x_batch = x_s_train[rand_index]
        y_batch = y_s_train[rand_index]

        feed_dict = {x_s: x_batch, y_s: y_batch, keep_prob: 1.0}

        sess.run(train_step_f_c, feed_dict = feed_dict)

        temp_loss_c_s = sess.run(loss_c_s, feed_dict = feed_dict)
        temp_acc_c_s = sess.run(acc_c_s, feed_dict = feed_dict)

        history_loss_c_s_train.append(temp_loss_c_s)
        history_acc_c_s_train.append(temp_acc_c_s)

        if (i + 1) % show_step == 0:
          print ('-' * 100)
          print ('Iteration: ' + str(i + 1) + \
                 '  Loss_c: ' + str(temp_loss_c_s) + '  Accuracy_c: ' + str(temp_acc_c_s))

        # domain classification for f_t and d
        rand_index = np.random.choice(len(x_s_train), size = batch_size // 2)
        x_batch_s = x_s_train[rand_index]
        d_batch_s = d_s_train[rand_index]

        rand_index = np.random.choice(len(x_t_train), size = batch_size // 2)
        x_batch_t = x_t_train[rand_index]
        d_batch_t = d_t_train[rand_index]

        feed_dict = {x_s: x_batch_s, d_s: d_batch_s, x_t: x_batch_t, d_t: d_batch_t, keep_prob: 1.0}

        sess.run(train_step_f, feed_dict = feed_dict)
        sess.run(train_step_d, feed_dict = feed_dict)

        temp_loss_f = sess.run(loss_f, feed_dict = feed_dict)
        temp_loss_d = sess.run(loss_d, feed_dict = feed_dict)
        temp_acc_d = sess.run(acc_d, feed_dict = feed_dict)

        history_loss_f_train.append(temp_loss_f)
        history_loss_d_train.append(temp_loss_d)
        history_acc_d_train.append(temp_acc_d)

        if (i + 1) % show_step == 0:
          print ('-' * 100)
          print ('Iteration: ' + str(i + 1) + \
                 '  Loss_d: ' + str(temp_loss_d) + '  Accuracy_d: ' + str(temp_acc_d))

        # Test
        # class classification for s
        rand_index = np.random.choice(len(x_s_test), size = batch_size)
        x_batch = x_s_test[rand_index]
        y_batch = y_s_test[rand_index]

        feed_dict = {x_s: x_batch, y_s: y_batch, keep_prob: 1.0}

        temp_loss_c_s = sess.run(loss_c_s, feed_dict = feed_dict)
        temp_acc_c_s = sess.run(acc_c_s, feed_dict = feed_dict)

        history_loss_c_s_test.append(temp_loss_c_s)
        history_acc_c_s_test.append(temp_acc_c_s)

        # class classification for t
        rand_index = np.random.choice(len(x_t_test), size = batch_size)
        x_batch = x_t_test[rand_index]
        y_batch = y_t_test[rand_index]

        feed_dict = {x_t: x_batch, y_t: y_batch, keep_prob: 1.0}

        temp_acc_c_t = sess.run(acc_c_t, feed_dict = feed_dict) 

        history_acc_c_t_test.append(temp_acc_c_t)

        # domain classification for f and d
        rand_index = np.random.choice(len(x_s_test), size = batch_size // 2)
        x_batch_s = x_s_test[rand_index]
        d_batch_s = d_s_test[rand_index]

        rand_index = np.random.choice(len(x_t_test), size = batch_size // 2)
        x_batch_t = x_t_test[rand_index]
        d_batch_t = d_t_test[rand_index]

        feed_dict = {x_s: x_batch_s, d_s: d_batch_s, x_t: x_batch_t, d_t: d_batch_t, keep_prob: 1.0}

        temp_loss_f = sess.run(loss_f, feed_dict = feed_dict)
        temp_loss_d = sess.run(loss_d, feed_dict = feed_dict)
        temp_acc_d = sess.run(acc_d, feed_dict = feed_dict)

        history_loss_f_test.append(temp_loss_f)
        history_loss_d_test.append(temp_loss_d)
        history_acc_d_test.append(temp_acc_d)

      print ('-' * 100)    
      fig = plt.figure(figsize = (10, 3))
      ax1 = fig.add_subplot(1, 2, 1)
      ax1.plot(range(n_iter), history_loss_c_s_train, 'b-', label = 'Training')
      ax1.plot(range(n_iter), history_loss_c_s_test, 'r-', label = 'Test')
      ax1.set_title('Loss_c_s')
      ax1.legend(loc = 'upper right')

      ax2 = fig.add_subplot(1, 2, 2)
      ax2.plot(range(n_iter), history_acc_c_s_train, 'b-', label = 'Training')
      ax2.plot(range(n_iter), history_acc_c_s_test, 'r-', label = 'Test')
      ax2.set_ylim(0.0, 1.0)
      ax2.set_title('Accuracy_c_s')
      ax2.legend(loc = 'lower right')

      fig = plt.figure(figsize = (10, 3))
      ax1 = fig.add_subplot(1, 2, 1)
      ax1.plot(range(n_iter), history_loss_d_train, 'b-', label = 'Training')
      ax1.plot(range(n_iter), history_loss_d_test, 'r-', label = 'Test')
      ax1.set_title('Loss_d')
      ax1.legend(loc = 'upper right')

      ax2 = fig.add_subplot(1, 2, 2)
      ax2.plot(range(n_iter), history_acc_d_train, 'b-', label = 'Training')
      ax2.plot(range(n_iter), history_acc_d_test, 'r-', label = 'Test')
      ax2.set_ylim(0.0, 1.0)
      ax2.set_title('Accuracy_d')
      ax2.legend(loc = 'lower right')

      fig = plt.figure(figsize = (10, 3))
      ax1 = fig.add_subplot(1, 2, 1)
      ax1.plot(range(n_iter), history_loss_f_train, 'b-', label = 'Training')
      ax1.set_title('Loss_f')
      ax1.legend(loc = 'upper right')

      ax2 = fig.add_subplot(1, 2, 2)
      ax2.plot(range(n_iter), history_acc_c_t_test, 'r-', label = 'Test')
      ax2.set_ylim(0.0, 1.0)
      ax2.set_title('Accuracy_c_t')
      ax2.legend(loc = 'lower right')

      plt.show()

      if is_saving:
        saver.save(sess, model_path)
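
The class also defines gradient_reversal(), which fit() does not use; the two-optimizer scheme (train_step_f minimizing - lam * loss_d, train_step_d minimizing loss_d) plays the same role. A minimal sketch of a GRL-based variant inside fit(), assuming a single combined update is acceptable (my variant, not the original training scheme):

    # gradient reversal layer variant: identity forward, -lam * gradient backward
    feat_rev = self.gradient_reversal(feat, lam, n_units_f, batch_size)
    logits_d = self.classifier_d(feat_rev, n_units_f, n_units_d, n_domains, keep_prob, reuse = False)
    probs_d = tf.nn.softmax(logits_d)
    loss_d = self.loss_cross_entropy(probs_d, d)
    # one optimizer now updates the domain classifier normally while
    # pushing reversed gradients into the feature extractor
    train_step_adv = self.training(loss_d, learning_rate, var_list_f + var_list_d)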

Parameters

n_in = 2
n_units_f = 15
n_units_d = 15
n_units_c = 15
n_domains = 2
n_classes = 3
lam = 1.0
learning_rate = 0.01
n_iter = 500
batch_size = 32
show_step = 100
model_path = 'datalab/model'

Output

is_saving = False

artn = ARTN()
artn.fit(x_s_train, x_s_test, y_s_train, y_s_test, d_s_train, d_s_test, \
         x_t_train, x_t_test, y_t_train, y_t_test, d_t_train, d_t_test, \
         n_in, n_units_f, n_units_d, n_domains, n_units_c, n_classes, lam, \
         learning_rate, n_iter, batch_size, show_step, is_saving, model_path)
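
Since only the target distribution changes between patterns, the five experiments could also be scripted instead of edited by hand. A sketch (make_target_data is a hypothetical helper that regenerates x_t, y_t, d_t from the given means and repeats the train/test split from the Data section; pattern 4 would also need its own sigma):

# target class means for patterns 1, 2, 3 and 5
target_means = {
    1: ([2.0, 0.5], [2.0, 1.5], [2.5, 1.0]),
    2: ([0.5, 0.5], [0.0, 0.0], [0.0, 1.0]),
    3: ([2.5, 1.0], [2.0, 0.5], [2.0, 1.5]),
    5: ([1.7, 0.3], [2.0, 1.0], [2.8, 1.3]),
}
for pattern, (mu_1, mu_2, mu_3) in target_means.items():
    print('pattern', pattern)
    # hypothetical helper: returns (x_t_train, x_t_test, y_t_train, y_t_test, d_t_train, d_t_test)
    t_splits = make_target_data(mu_1, mu_2, mu_3, sig_s_1)
    artn.fit(x_s_train, x_s_test, y_s_train, y_s_test, d_s_train, d_s_test,
             *t_splits,
             n_in, n_units_f, n_units_d, n_domains, n_units_c, n_classes, lam,
             learning_rate, n_iter, batch_size, show_step, is_saving, model_path)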

patterns 1 through 5: training curves for each pattern (Loss_c_s / Accuracy_c_s, Loss_d / Accuracy_d, Loss_f / Accuracy_c_t; images omitted)
