Notes on Implementing LSTM & Attention

Posted at 2018-05-18

## Key Points

  • Implemented two types of Attention Mechanism (Type A and Type B) on top of an LSTM and evaluated their performance on toy data.
  • Confirmed that both Type A and Type B improve learning performance.
  • Type A performs better than Type B (to be verified further on other tasks and datasets).

## References

1. M. Daniluk, T. Rocktäschel, J. Welbl, S. Riedel, "Frustratingly Short Attention Spans in Neural Language Modeling," ICLR 2017.

## Evaluation Method

  • Measure the number of iterations the model needs to learn the regularity in the data (i.e., to reach an accuracy of at least 98%).
  • Use the mean and standard deviation over multiple trials as the performance metric (see the sketch below).
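
A minimal sketch of this bookkeeping (NumPy; the accuracy curves and the helper name are hypothetical, not from the article's code):

```python
import numpy as np

def iterations_to_criterion(accuracy_history, threshold=0.98):
  # Return the 1-based iteration at which accuracy first reaches the threshold.
  for i, acc in enumerate(accuracy_history):
    if acc >= threshold:
      return i + 1
  return len(accuracy_history)  # criterion not reached within the run

# Hypothetical accuracy curves from three independent trials.
all_trials = [
  [0.50, 0.80, 0.97, 0.99, 0.99],
  [0.55, 0.90, 0.99, 0.99, 0.99],
  [0.40, 0.70, 0.95, 0.98, 0.99],
]
counts = [iterations_to_criterion(h) for h in all_trials]
print(np.mean(counts), np.std(counts))  # mean and standard deviation over trials
```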

(Figure: model architecture variants, reproduced from the referenced paper)

Type A: (a) Neural Language Model with Attention
Type B: (c) Key-value-predict separation

## Data
Input $x(t)$:

x(t) = \begin{pmatrix}
s(t)\\
m(t) 
\end{pmatrix} \\
s(t) \in \{0, 1, 2, 3, 4\}\\
m(t) \in \{0, 1\},\quad \sum_{t=1}^{T} m(t) = 2

Output $y$:

y = \sum_{t=1}^{T} s(t) m(t)

Example:

s = [2, 3, 1, 0, 4]\\
m = [0, 0, 1, 0, 1]\\
\rightarrow\
y = 5
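
A minimal sketch of how such toy data could be generated (NumPy; the helper name is hypothetical, as the article does not show its data-generation code):

```python
import numpy as np

def generate_sample(T=5, rng=np.random):
  # s(t): random digits in {0, ..., 4}; m(t): a binary mask with exactly two ones.
  s = rng.randint(0, 5, size=T)
  m = np.zeros(T, dtype=int)
  m[rng.choice(T, size=2, replace=False)] = 1
  y = int(np.sum(s * m))          # target: sum of the two marked digits
  x = np.stack([s, m], axis=1)    # input sequence of shape (T, 2)
  return x, y

x, y = generate_sample()
print(x, y)
```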

## Results
Experimental settings:

  • T = 5
  • Number of samples: 1,000 (training: 500, test: 500)
  • Batch size: 30
  • Number of hidden units: 30
  • Optimizer: Adam
  • Number of trials: 5

| Data | Metric | Baseline | Type A | Type B |
|:--|:--|--:|--:|--:|
| Training | Mean | 319.40 | 166.60 | 297.00 |
| Training | Std. dev. | 26.00 | 13.95 | 40.04 |
| Test | Mean | 514.00 | 177.80 | 340.00 |
| Test | Std. dev. | 95.85 | 21.12 | 72.43 |

Baseline: LSTM without Attention.

Attention Weights (Type A):

| T-4 | T-3 | T-2 | T-1 |
|--:|--:|--:|--:|
| 0.007 | 0.026 | 0.237 | 0.730 |
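
How these numbers were obtained is not shown in the article; as one possibility, the attention weights `a` returned by `inference` could be averaged over evaluation batches roughly like this (names and data are illustrative):

```python
import numpy as np

# attn_batches: attention weights collected from sess.run on `a`,
# each array of shape (batch_size, 1, T - 1); random data stands in here.
attn_batches = [np.random.dirichlet(np.ones(4), size=(30, 1)) for _ in range(3)]
mean_weights = np.concatenate(attn_batches, axis=0).mean(axis=(0, 1))
print(mean_weights)  # one averaged weight per position T-4, ..., T-1
```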

## Sample Code
The referenced paper uses an additive score function; the sample code below uses dot-product (multiplicative) attention instead.
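
For reference, a minimal NumPy sketch contrasting the two score functions (illustrative names and shapes only; the TensorFlow code below implements the dot-product form):

```python
import numpy as np

def dot_product_scores(q, K):
  # q: (d,) query, K: (T, d) keys -> (T,) unnormalized scores
  return K @ q

def additive_scores(q, K, W_q, W_k, v):
  # Additive scoring in the spirit of the referenced paper:
  # score_t = v^T tanh(W_q q + W_k k_t)
  return np.tanh(q @ W_q + K @ W_k) @ v

def softmax(z):
  e = np.exp(z - z.max())
  return e / e.sum()

rng = np.random.default_rng(0)
d = 4
q, K = rng.normal(size=d), rng.normal(size=(3, d))
W_q, W_k, v = rng.normal(size=(d, d)), rng.normal(size=(d, d)), rng.normal(size=d)
print(softmax(dot_product_scores(q, K)))             # weights with dot-product score
print(softmax(additive_scores(q, K, W_q, W_k, v)))   # weights with additive score
```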

  def LSTM(self, x, h, c, n_in, n_units):
    # Single LSTM step: the four gate pre-activations are computed jointly,
    # one matmul for the input and one for the previous hidden state.
    w_x = self.weight_variable('w_x', [n_in, n_units * 4])
    w_h = self.weight_variable('w_h', [n_units, n_units * 4])
    b = self.bias_variable('b', [n_units * 4])
    
    # Split into input gate, forget gate, output gate, and cell candidate.
    i, f, o, g = tf.split(tf.add(tf.add(tf.matmul(x, w_x), \
            tf.matmul(h, w_h)), b), 4, axis = 1)
    
    i = tf.nn.sigmoid(i)
    f = tf.nn.sigmoid(f)
    o = tf.nn.sigmoid(o)
    g = tf.nn.tanh(g)
    
    # Standard LSTM cell-state and hidden-state updates.
    c = tf.add(tf.multiply(f, c), tf.multiply(i, g))
    h = tf.multiply(o, tf.nn.tanh(c))
    
    return h, c

  # Type A  
  def inference(self, x, length, n_in, n_units, n_out, \
                           batch_size):

    h = tf.zeros(shape = [batch_size, n_units], dtype = \
                           tf.float32)
    c = tf.zeros(shape = [batch_size, n_units], dtype = \
                           tf.float32)
    
    list_h = []
    list_c = []
    
    with tf.variable_scope('lstm'):
      for t in range(length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()
        
        h, c = self.LSTM(x[:, t, :], h, c, n_in, n_units)
       
        list_h.append(h)
        list_c.append(c)
        
    # Query = last hidden state; keys and values = all earlier hidden states.
    q = list_h[-1]
    q = tf.expand_dims(q, axis = 1)    # (batch, 1, units)
    k = list_h[:-1]
    k = tf.transpose(k, [1, 0, 2])     # (batch, length - 1, units)
    v = k
    p = q
    
    # Dot-product attention weights and context vector.
    a = tf.matmul(q, tf.transpose(k, [0, 2, 1]))
    a = tf.nn.softmax(a)
    r = tf.matmul(a, v)
    
    with tf.variable_scope('merge'):
      # Combine the context vector r and the final hidden state p.
      w = self.weight_variable('w', [n_units * 2, n_units])
      
      h_ = tf.concat([r, p], axis = -1)
      h_ = tf.squeeze(h_)
      h_ = tf.nn.tanh(tf.matmul(h_, w))
            
    with tf.variable_scope('linear'):
      w = self.weight_variable('w', [n_units, n_out])
      b = self.bias_variable('b', [n_out])

      pred = tf.add(tf.matmul(h_, w), b)
     
    return pred, a

  # Type B  
  def inference(self, x, length, n_in, n_units, n_out, \
                    batch_size):

    h = tf.zeros(shape = [batch_size, n_units * 3], \
                             dtype = tf.float32)
    c = tf.zeros(shape = [batch_size, n_units * 3], \
                             dtype = tf.float32)
    
    list_h = []
    list_h_k = []
    list_h_v = []
    list_h_p = []
    list_c = []
    
    with tf.variable_scope('lstm'):
      for t in range(length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()
        
        h, c = self.LSTM(x[:, t, :], h, c, n_in, n_units * 3)
        
        # Split the hidden state into key, value, and predict parts.
        h_k, h_v, h_p = tf.split(h, 3, axis = -1)
        
        list_h.append(h)
        list_h_k.append(h_k)
        list_h_v.append(h_v)
        list_h_p.append(h_p)
        list_c.append(c)
        
    # Key part of the last state is the query; earlier key parts are the keys,
    # earlier value parts are the values, and the last predict part feeds the output.
    q = list_h_k[-1]
    q = tf.expand_dims(q, axis = 1)
    k = list_h_k[:-1]
    k = tf.transpose(k, [1, 0, 2])
    v = list_h_v[:-1]
    v = tf.transpose(v, [1, 0, 2])
    p = list_h_p[-1]
    p = tf.expand_dims(p, axis = 1)
    
    # Dot-product attention weights and context vector.
    a = tf.matmul(q, tf.transpose(k, [0, 2, 1]))
    a = tf.nn.softmax(a)
    r = tf.matmul(a, v)
    
    with tf.variable_scope('merge'):
      w = self.weight_variable('w', [n_units * 2, n_units])
      
      h_ = tf.concat([r, p], axis = -1)
      h_ = tf.squeeze(h_) 
      h_ = tf.nn.tanh(tf.matmul(h_, w))
            
    with tf.variable_scope('linear'):
      w = self.weight_variable('w', [n_units, n_out])
      b = self.bias_variable('b', [n_out])

      pred = tf.add(tf.matmul(h_, w), b)
      
    return pred, a
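
A minimal usage sketch of how the returned `pred` and `a` might be wired into a TF1-style training setup; `model` is assumed to be an instance of the class containing the methods above, and treating y as a class label (9 possible sums for T = 5) is an assumption, not something stated in the article:

```python
import tensorflow as tf

batch_size, T, n_in, n_units = 30, 5, 2, 30
n_out = 9  # assumption: y in {0, ..., 8} handled as a classification target

x_ph = tf.placeholder(tf.float32, [batch_size, T, n_in])
y_ph = tf.placeholder(tf.int64, [batch_size])

# `model` is assumed to be an instance of the article's class defining
# weight_variable / bias_variable / LSTM / inference.
pred, attn = model.inference(x_ph, T, n_in, n_units, n_out, batch_size)

loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_ph, logits=pred))
train_op = tf.train.AdamOptimizer().minimize(loss)
accuracy = tf.reduce_mean(
    tf.cast(tf.equal(tf.argmax(pred, axis=1), y_ph), tf.float32))
```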
