Notes on Implementing Variable Computation

Posted at 2018-05-20

Key Points

  • Implemented Variable Computation (VCLSTM) on top of a standard LSTM.
  • Numerical validation is left as future work.
  • Also planning to organize notes on Adaptive Computation Time (in the broad sense).

References

1. Y. Jernite, E. Grave, A. Joulin, T. Mikolov. Variable Computation in Recurrent Neural Networks. ICLR 2017.

Model Architecture

(Figure: VCLSTM architecture diagram, taken from the reference paper.)

m_t = \sigma(W_m[x_t, h_{t-1}] + b_m) \\
\forall i \in \{1, \dots, D\} \quad (e_t)_i = \mathrm{Thres}_\epsilon\Bigl(\sigma\bigl(\lambda(m_t D - i)\bigr)\Bigr) \\
\bar{h}_{t-1} = e_t \odot h_{t-1},\quad \bar{c}_{t-1} = e_t \odot c_{t-1},\quad \bar{x}_t = e_t \odot x_t \\

\quad \\
\begin{pmatrix}
i_t \\
f_t \\
o_t \\
g_t \\
\end{pmatrix}
=
\begin{pmatrix}
\sigma(W_i[\bar{x}_t, \bar{h}_{t-1}] + b_i) \\
\sigma(W_f[\bar{x}_t, \bar{h}_{t-1}] + b_f) \\
\sigma(W_o[\bar{x}_t, \bar{h}_{t-1}] + b_o) \\
\tanh(W_g[\bar{x}_t, \bar{h}_{t-1}] + b_g) \\
\end{pmatrix} \\
\tilde{c}_t = f_t \odot \bar{c}_{t-1} + i_t \odot g_t \\
\tilde{h}_t = o_t \odot \tanh(\tilde{c}_t) \\

\quad \\
c_t = e_t \odot \tilde{c}_t + (1 - e_t) \odot c_{t-1} \\
h_t = e_t \odot \tilde{h}_t + (1 - e_t) \odot h_{t-1}
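
To make the gating concrete, here is a small worked example (the values D = 4, m_t = 0.6, λ = 50 are made up for illustration). The sigmoid is sharp enough that the mask is effectively binary, so only the first two of the four units are updated at this step:

D = 4,\quad m_t = 0.6,\quad \lambda = 50 \\
\Bigl(\sigma\bigl(\lambda(m_t D - i)\bigr)\Bigr)_{i=1,\dots,4} = \bigl(\sigma(70),\ \sigma(20),\ \sigma(-30),\ \sigma(-80)\bigr) \approx (1, 1, 0, 0) \\
e_t = (1, 1, 0, 0)

Units 3 and 4 are excluded from the gate computation and simply carry h_{t-1} and c_{t-1} forward through the (1 - e_t) terms.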

Sample Code

  def VCLSTM(self, x, h, c, n_in, n_units, batch_size, lam, eps):
    # Assumes TensorFlow 1.x (import tensorflow as tf) and NumPy
    # (import numpy as np); self.weight_variable / self.bias_variable
    # create parameters with the given name and shape.

    # Parameters of the scheduler m_t.
    w_m_x = self.weight_variable('w_m_x', [n_in, 1])
    w_m_h = self.weight_variable('w_m_h', [n_units, 1])
    b_m = self.bias_variable('b_m', [1])

    # Fused parameters for the four LSTM gates (i, f, o, g).
    w_x = self.weight_variable('w_x', [n_in, n_units * 4])
    w_h = self.weight_variable('w_h', [n_units, n_units * 4])
    b = self.bias_variable('b', [n_units * 4])

    # m_t = sigmoid(W_m [x_t, h_{t-1}] + b_m): fraction of units to compute.
    m = tf.sigmoid(tf.matmul(x, w_m_x) + tf.matmul(h, w_m_h) + b_m)

    # Mask over the hidden units: sigmoid(lam * (m_t * D - i)), then a hard
    # threshold. Note the indices run 0, ..., D - 1 here, not 1, ..., D as
    # in the equations above.
    a_h = m * n_units
    b_h = tf.convert_to_tensor(np.arange(n_units, dtype=np.float32))
    b_h = tf.tile(tf.expand_dims(b_h, axis=0), [batch_size, 1])
    e_h = tf.sigmoid(lam * (a_h - b_h)) - (1 - eps)
    e_h = tf.nn.relu(tf.sign(e_h))  # 1 where the soft gate > 1 - eps, else 0

    h_bar = h * e_h
    c_bar = c * e_h

    # The same mask construction, applied to the input dimensions.
    a_x = m * n_in
    b_x = tf.convert_to_tensor(np.arange(n_in, dtype=np.float32))
    b_x = tf.tile(tf.expand_dims(b_x, axis=0), [batch_size, 1])
    e_x = tf.sigmoid(lam * (a_x - b_x)) - (1 - eps)
    e_x = tf.nn.relu(tf.sign(e_x))

    x_bar = x * e_x

    # Standard LSTM gates computed from the masked input and state.
    i, f, o, g = tf.split(
        tf.matmul(x_bar, w_x) + tf.matmul(h_bar, w_h) + b, 4, axis=1)

    i = tf.sigmoid(i)
    f = tf.sigmoid(f)
    o = tf.sigmoid(o)
    g = tf.tanh(g)

    # Candidate cell and hidden states from the masked update.
    c_tilde = f * c_bar + i * g
    h_tilde = o * tf.tanh(c_tilde)

    # Only masked-in units are updated; the rest carry their old values.
    c = e_h * c_tilde + (1 - e_h) * c
    h = e_h * h_tilde + (1 - e_h) * h

    return h, c
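
For completeness, here is a minimal sketch of how the cell might be unrolled over a sequence. This is my own assumption, not part of the original code: it presumes TensorFlow 1.x graph mode, that weight_variable / bias_variable are built on tf.get_variable (so variable reuse works), and the hypothetical names model, xs, and n_steps.

  # Hypothetical usage: unroll VCLSTM over n_steps time steps.
  # `model` is an instance of the class defining VCLSTM above;
  # xs has shape [batch_size, n_steps, n_in].
  h = tf.zeros([batch_size, n_units])
  c = tf.zeros([batch_size, n_units])
  outputs = []
  for t in range(n_steps):
    # reuse=(t > 0) shares the cell's weights across time steps.
    with tf.variable_scope('vclstm', reuse=(t > 0)):
      h, c = model.VCLSTM(xs[:, t, :], h, c, n_in, n_units,
                          batch_size, lam, eps)
    outputs.append(h)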