tf.Variableの一部のみを学習させたい,かつその対象をplaceholderなどで指定したい場合はtf.nn.embedding_lookup
やtf.gather
が使える.
たとえば,[TOTAL_SIZE, DIM]
の大きなVariableがあって,その一部のみ([BATCH_SIZE, DIM]
)を取り出してFully-Connected Layerへの入力とする場合は以下のようにすればできる.
embedding_test.py
import numpy as np
import tensorflow as tf
sess = tf.InteractiveSession()
TOTAL_SIZE = 10
BATCH_SIZE = 2
DIM = 3
LR = 0.01
# training target
np_x = np.random.rand(TOTAL_SIZE, DIM).astype(np.float32)
x = tf.Variable(np_x)
# indices
np_inds = np.array([0, 3])
inds = tf.placeholder(tf.int32, [BATCH_SIZE])
# part of x
h = tf.nn.embedding_lookup(x, inds)
# a fully connected layer
w = tf.Variable(tf.truncated_normal([DIM, DIM]))
y = tf.tanh(tf.matmul(h, w))
# target data
t = tf.convert_to_tensor(np.random.rand(BATCH_SIZE, DIM).astype(np.float32))
# loss
l = tf.reduce_mean(tf.square(y - t))
# get training op of x
tvars = [x]
grads = tf.gradients(l, tvars)
train = tf.train.GradientDescentOptimizer(LR).apply_gradients(zip(grads, tvars))
feed_dict = {inds: np_inds}
sess.run(tf.global_variables_initializer())
# check that embdding_lookup() slices correctly
diff = sess.run(h, feed_dict) - np_x[np_inds, :]
assert diff.sum() == 0.0
x_before_train = sess.run(x)
sess.run(train, feed_dict)
x_after_train = sess.run(x)
print(x_before_train - x_after_train)
実行すると,np_indsで指定したxの0と3番目の行だけが更新されているのがわかる:
[[ 0.00091529 -0.00065178 0.00180614]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0.00167947 -0.00147748 0.00317943]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]]
tf.gather
やtf.nn.embedding_lookup
はgradientを求めた場合tf.IndexedSlices
が得られる.おそらく大抵のoptimizerはサポートしているはずなのであまり困ることはないはず.しかし,特定のgradientに手を加えたい,かつそれがtf.IndexedSlices
であった場合はTensorとして扱えないので気をつける必要がある.たとえば拡大縮小したい場合は以下のような関数を用意すればエラーを回避できる:
scale_gradient.py
def scale_gradient(g, scale):
if isinstance(g, tf.IndexedSlices):
values = g.values * scale
return tf.IndexedSlices(values, g.indices)
else:
return g * scale