TensorFlow

tf.nn.embedding_lookup()を使う

tf.Variableの一部のみを学習させたい,かつその対象をplaceholderなどで指定したい場合はtf.nn.embedding_lookuptf.gatherが使える.

たとえば,[TOTAL_SIZE, DIM]の大きなVariableがあって,その一部のみ([BATCH_SIZE, DIM])を取り出してFully-Connected Layerへの入力とする場合は以下のようにすればできる.

embedding_test.py
import numpy as np
import tensorflow as tf

sess = tf.InteractiveSession()

TOTAL_SIZE = 10
BATCH_SIZE = 2
DIM = 3
LR = 0.01


# training target
np_x = np.random.rand(TOTAL_SIZE, DIM).astype(np.float32)
x = tf.Variable(np_x)

# indices 
np_inds = np.array([0, 3])
inds = tf.placeholder(tf.int32, [BATCH_SIZE])

# part of x
h = tf.nn.embedding_lookup(x, inds)

# a fully connected layer
w = tf.Variable(tf.truncated_normal([DIM, DIM]))
y = tf.tanh(tf.matmul(h, w))

# target data
t = tf.convert_to_tensor(np.random.rand(BATCH_SIZE, DIM).astype(np.float32))

# loss
l = tf.reduce_mean(tf.square(y - t))

# get training op of x
tvars = [x]
grads = tf.gradients(l, tvars)
train = tf.train.GradientDescentOptimizer(LR).apply_gradients(zip(grads, tvars))

feed_dict = {inds: np_inds}

sess.run(tf.global_variables_initializer())

# check that embdding_lookup() slices correctly
diff = sess.run(h, feed_dict) - np_x[np_inds, :]
assert diff.sum() == 0.0

x_before_train = sess.run(x)

sess.run(train, feed_dict)

x_after_train = sess.run(x)

print(x_before_train - x_after_train)

実行すると,np_indsで指定したxの0と3番目の行だけが更新されているのがわかる:

[[ 0.00091529 -0.00065178  0.00180614]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.00167947 -0.00147748  0.00317943]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]]

tf.gathertf.nn.embedding_lookupはgradientを求めた場合tf.IndexedSlicesが得られる.おそらく大抵のoptimizerはサポートしているはずなのであまり困ることはないはず.しかし,特定のgradientに手を加えたい,かつそれがtf.IndexedSlicesであった場合はTensorとして扱えないので気をつける必要がある.たとえば拡大縮小したい場合は以下のような関数を用意すればエラーを回避できる:

scale_gradient.py
def scale_gradient(g, scale):
    if isinstance(g, tf.IndexedSlices):
        values = g.values * scale
        return tf.IndexedSlices(values, g.indices)
    else:
        return g * scale