9
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

NextremerAdvent Calendar 2016

Day 25

Decoupled Neural Interfaces のFeedforward部分のtensorflow実装の概略

Last updated at Posted at 2016-12-25

今回紹介する論文の著者による説明が以下のブログに掲載されていますので、そちらもご覧ください。
https://deepmind.com/blog/decoupled-neural-networks-using-synthetic-gradients/

#背景と提案手法のアイデア

論文の背景は、より一般的な分散NNモデルの学習(Model Parallel)です。
理由としては、より柔軟なモデルをより効率的に訓練したいということです。例えば、中間データを保存する必要の無い、学習データとSGDで実行できるEnd to endの学習手法のさらなる拡張などです。

以下の図のように分散環境を表す有向グラフを考えることができます。
各頂点が分割されたモデルの断片or1台機械に対応しています。

image

すると、ひとつの大きなニューラルネットワーク(NN)を分散したデータフローで理解することができます。まず、頂点でNNが実行されます。すると、NNの中間層の出力を入力とするNNのFooward計算に必要な入力が有向辺の終点とみなせます。

このような分散環境の学習を非同期性を考慮しながら、頂点数でスケールアウトさせたいです。

問題点として、Forward, Backwardの考え方は「硬すぎ」ました。例えば、いろいろなLockingが生じます。
ニューラルネットのForward(Backward)Lockingと呼ばれるものは、各々の層は自分より下位(上位)層の全計算完了まで待つことを強いられます。最悪、ネットワークのの深さに比例する待ち時間もあります。非同期性と相性悪いのがネックでした。

そこで、DNI: Decoupled Neural Interfaceと呼ばれる手法を提案しました。Feedforward NNとRNNの学習から同期性を取り除いたことと、より一般的な分散モデル学習への指針を示したことです。

分散環境の送信値が間違っていてもOK!と思うことです。最初は送り手の頂点だって不正確な値を出力するのですから。
有向辺に流れる送信値を近似するNNを用意します。頂点から送信された値との「差分」が目的関数です。辺の個数と同じだけ目的関数があります。頂点のモデルの学習に伴い正確になればよいと考えます。分散環境全体を見て最適化していきます。送信値の例としては、Back Propagationのデルタ値や、Forward計算の途中値などがあります。

#Feedforward NetworkのBP勾配の近似例
以下の近似式になります。implicitな多層にわたるパラメータ(重み)が無視されています。
image

#隣接4層が1層4個に分割する場合

1番目の近似NN出力から第0層の重みを更新

2番目の近似NN出力から第1層の重みを更新
第1層誤差&第0層出力から1番目の近似NNの
重み更新

3番目の近似NN出力から第2層の重みを更新
第2層誤差&第1層出力から2番目の近似NNの
重み更新

教師信号から第3層の重みを更新
第3層誤差&第2層出力から3番目の近似NNの
重み更新

このとき近似NNが3個あるので、3個の目的関数を定義しています。

#Tensorflow実装の概略
今回は、4層のCNNの各々の層のBP勾配の値をCNNによって近似するというコードのうちDNIのアイデアに直接関わる部分を説明します。

最初に以下のような層を作成するためのオブジェクトを定義しておきます。

a_part_of_main.py
class layerInfo():
	def __init__(self,):
		self.name = ''
		self.out={}
		self.var={}
		self.synthetic_grad={}
	def get_single_layer_info(self,):
		return self.out[self.name], self.var[self.name+'_w'],self.var[self.name+'_b'],self.synthetic_grad[self.name]
	def set_single_layer_info(self,layer_func, *args, **kwargs):
		try:
			with tf.variable_scope(self.name) as vs:
				#print "self.name as scope"== self.name
				self.out[self.name], self.var[self.name+'_w'],self.var[self.name+'_b'],self.synthetic_grad[self.name] = layer_func(*args, **kwargs)
		except:
			raise

これから記述するコードの断片でselfと書いてあるのは、Model-classと呼ばれるオブジェクト自身です。

これから3つのメソッドの関数定義を説明します。
注目すべき点は、

(1)どのtf.Variableが近似NN用につかわれているか。
(2)複数の目的関数から計算される誤差の保存
(3)DNIの近似NNだけの重みを変更する
まず層の定義自体です。

a_part_of_utils.py

def conv2d(inputs, output_size, kernel_size, stride, 
			weights_initializer=tf.contrib.layers.xavier_initializer(),
			biases_initializer=tf.zeros_initializer, synthetic=False,
			batch_norm = True,
			activation_fn=tf.nn.relu, padding='SAME', name='conv2d'):
	
	var = {}
	print kernel_size
	kernel_shape = [kernel_size[0], kernel_size[1], inputs.get_shape()[-1], output_size]
	stride  = [1, 1, stride[0], stride[1]]
	with tf.variable_scope(name):
		var['w'] = tf.get_variable('w', kernel_shape,
			tf.float32, initializer=weights_initializer)
		conv = tf.nn.conv2d(inputs, var['w'], stride, padding=padding)
		var['b'] = tf.get_variable('b', [output_size], tf.float32, initializer=biases_initializer)
		out = tf.nn.bias_add(conv, var['b'])
		
		if batch_norm:
			out = tf.contrib.layers.batch_norm(out)
		if activation_fn != None:
			out = activation_fn(out)

		if synthetic:
			out_shape = out.get_shape()
			h1, var['l1_w'], var['l1_b'] = conv2d(out, 128, [5,5], [1,1],
								tf.zeros_initializer, tf.zeros_initializer, batch_norm=True, activation_fn=tf.nn.relu, name='l1')
			h2, var['l2_w'], var['l2_b'] = conv2d(h1, 128, [5,5], [1,1],
								tf.zeros_initializer, tf.zeros_initializer, batch_norm=True, activation_fn=tf.nn.relu, name='l2')
			synthetic_grad, var['l3_w'], var['l3_b'] = conv2d(h2, 128, [5,5], [1,1],
								tf.zeros_initializer, tf.zeros_initializer, batch_norm=False, activation_fn=None, name='l3')
			return out, var['w'], var['b'], synthetic_grad
		else:		
			return out, var['w'], var['b'], np.float32(0.0)

この層を定義している部分では、返り値に層の出力に近似された勾配が含まれていることに注目します。もちろん、近似として、使用されていないレイヤーの場合は、ダミー値0.0を返します。

a_part_of_main.py

	def build_cnn_model(self):
		pool_types={'l1':'max','l2':'average','l3':'average'}
		self.imgs = tf.placeholder('float32', [self.batch_size, self.input_dims])
		self.img_reshape = tf.reshape(self.imgs, [self.batch_size, self.w, self.h, self.channel])	
		self.layer = layerInfo()

		options={
			"weights_initializer":self.weight_initializer, 
			"biases_initializer":self.bias_initializer, 
			"synthetic":self.synthetic,
			"batch_norm":True, 
			"activation_fn":tf.nn.relu,
		}

		input_to_layer=self.img_reshape
		for name in['l1','l2','l3']:
			self.layer.name = name
			args=(input_to_layer, 128, [5,5], [1,1])
			options["name"]=name+"_conv2d"
			self.layer.set_single_layer_info(conv2d, *args, **options )
			print self.layer.out['l1']
			print "name==", self.layer.name
			self.layer.out[name+'_pool'] = pooling(self.layer.out[name], kernel_size=[3,3], stride=[1,1], type=pool_types[name])
			input_to_layer=self.layer.out[name+'_pool']	
		
		self.layer.out['l3_reshape'] = tf.reshape(self.layer.out['l3_pool'], [self.batch_size, -1])
		name='l4'
		self.layer.name=name
		args = (self.layer.out['l3_reshape'], self.output_size)
		del options["batch_norm"]
		self.layer.set_single_layer_info(linear, *args, **options)
		self.out_logit = tf.nn.softmax(self.layer.out['l4'])

		self.out_argmax = tf.argmax(self.out_logit, 1)
		self.labels = tf.placeholder('int32', [self.batch_size])
		self.loss_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(self.layer.out['l4'], self.labels)
		self.loss = tf.reduce_sum(self.loss_entropy)/self.batch_size

		if self.synthetic:
			for name in ['l1','l2','l3','l4']:
				self.grad_output[name] = tf.gradients(self.loss, self.layer.out[name])
			for k in self.grad_output.keys():
				self.grad_loss.append(tf.reduce_sum(tf.square(self.synthetic_grad[k]-self.grad_output[k])))
			self.grad_total_loss = sum(self.grad_loss)

self.grad_loss.appendで複数の目的関数から計算される誤差の保存を実行します。

a_part_of_main.py
	def train(self):

		if self.synthetic:
			grads_and_vars = []
			for var in tf.trainable_variables():
				if 'synthetic' in var.name:
					grads_and_vars.append(self.optim.compute_gradients(self.grad_total_loss, var_list=[var])[0])
				else:
					for k in self.grad_output.keys():
						if k in var.name:
							grads = tf.gradients(self.layer.out[k], var, self.grad_output[k])[0]
							grads_and_vars.append((grads,var))
			# minimize the gradient loss and only change the dni module
			self.train_op = self.optim.apply_gradients(grads_and_vars, global_step=self.global_step)
		else:
			self.train_op = self.optim.minimize(self.loss, global_step=self.global_step)

		tf.initialize_all_variables().run()
		self.saver = tf.train.Saver(max_to_keep=self.max_to_keep)
		for epoch_idx in range(int(self.max_epoch)):
			for idx in range(int(math.floor(self.num_train/self.batch_size))):
				img_batch, label_batch = self.dataset.sequential_sample(self.batch_size)
				if self.synthetic:
					_, grad_loss, loss = self.sess.run([self.train_op, self.grad_total_loss, self.loss], {
								self.imgs: img_batch,
								self.labels: label_batch
								})
					print "[*] Iter {}, syn_grad_loss={}, real_loss={}".format(int(self.global_step.eval()), grad_loss, loss)
				else:
					_, loss = self.sess.run([self.train_op, self.loss],{
								self.imgs: img_batch,
								self.labels: label_batch
								})
					print "[*] Iter {}, real_loss={}".format(int(self.global_step.eval()), loss)

				if self.global_step.eval()%self.test_per_iter == 0 or self.global_step.eval()==1:
					self.evaluate(split='train')
					self.evaluate(split='test')

最初のif self.synthetic:の内側で、訓練可能なtf.Variableの変数たちを近似用とそれ以外の用途へと分離して、grad_and_varsのリストをappendしています。そして、最後にapply_gradientを呼び出しています。

9
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?