Using tf.nn.embedding_lookup()

Posted at 2017-07-07

When you want to train only part of a tf.Variable, and want to specify that part via a placeholder or the like, tf.nn.embedding_lookup or tf.gather does the job.

For example, if you have a large Variable of shape [TOTAL_SIZE, DIM] and want to take only part of it ([BATCH_SIZE, DIM]) as the input to a fully-connected layer, you can do it as follows.

embedding_test.py
import numpy as np
import tensorflow as tf

sess = tf.InteractiveSession()

TOTAL_SIZE = 10
BATCH_SIZE = 2
DIM = 3
LR = 0.01


# training target
np_x = np.random.rand(TOTAL_SIZE, DIM).astype(np.float32)
x = tf.Variable(np_x)

# indices 
np_inds = np.array([0, 3])
inds = tf.placeholder(tf.int32, [BATCH_SIZE])

# part of x
h = tf.nn.embedding_lookup(x, inds)

# a fully connected layer
w = tf.Variable(tf.truncated_normal([DIM, DIM]))
y = tf.tanh(tf.matmul(h, w))

# target data
t = tf.convert_to_tensor(np.random.rand(BATCH_SIZE, DIM).astype(np.float32))

# loss
l = tf.reduce_mean(tf.square(y - t))

# get training op of x
tvars = [x]
grads = tf.gradients(l, tvars)
train = tf.train.GradientDescentOptimizer(LR).apply_gradients(zip(grads, tvars))

feed_dict = {inds: np_inds}

sess.run(tf.global_variables_initializer())

# check that embedding_lookup() slices x correctly
diff = sess.run(h, feed_dict) - np_x[np_inds, :]
assert np.abs(diff).sum() == 0.0

x_before_train = sess.run(x)

sess.run(train, feed_dict)

x_after_train = sess.run(x)

print(x_before_train - x_after_train)

If you run this, you can see that only rows 0 and 3 of x, the ones specified by np_inds, have been updated:

[[ 0.00091529 -0.00065178  0.00180614]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.00167947 -0.00147748  0.00317943]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]]
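
Incidentally, tf.gather can be used in exactly the same way here. The following is a minimal sketch that only swaps out the lookup op, reusing the x and inds defined in embedding_test.py above:

gather_example.py
# tf.gather pulls the same [BATCH_SIZE, DIM] rows out of x as
# tf.nn.embedding_lookup, and its gradient w.r.t. x is likewise a
# tf.IndexedSlices that only touches the rows selected by inds.
h_gather = tf.gather(x, inds)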

Taking the gradient of tf.gather or tf.nn.embedding_lookup gives you a tf.IndexedSlices rather than a plain Tensor. Most optimizers should support it, so this is rarely a problem. However, if you want to modify a particular gradient and it happens to be a tf.IndexedSlices, you need to be careful, because it cannot be handled as a Tensor. For example, if you want to scale it, you can avoid the error with a function like this:

scale_gradient.py
def scale_gradient(g, scale):
    # Gradients from tf.gather / tf.nn.embedding_lookup come back as
    # tf.IndexedSlices, so scale the values and rebuild the slices
    # instead of multiplying the gradient directly.
    if isinstance(g, tf.IndexedSlices):
        values = g.values * scale
        return tf.IndexedSlices(values, g.indices, g.dense_shape)
    else:
        return g * scale
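
A usage sketch: scale_gradient() can be dropped into the training-op construction from embedding_test.py above (l, tvars, and LR are the ones defined there; the factor 0.5 is just an arbitrary example):

scale_gradient_usage.py
# Scale every gradient, whether it is an IndexedSlices or a plain Tensor,
# before handing it to the optimizer.
grads = tf.gradients(l, tvars)
scaled_grads = [scale_gradient(g, 0.5) for g in grads]
train = tf.train.GradientDescentOptimizer(LR).apply_gradients(
    zip(scaled_grads, tvars))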