#1 Scaled Dot-Product Attention
Scaled Dot-Production AttentionのAttention関数は、Query、Key、Valueを入力とする以下の関数である。
#2 コード
Tensorflowチュートリアルに記載のあるScaled Dot-Product Attentionメソッドの実装は以下。
import tensorflow as tf
#############################################
#
# Scaled Dot Product Attention
# Attention(Q, K, V) = softmax( Q*K.T / sqrt(d)) * V
#
def scaled_dot_product_attention(q, k, v):
# Q * K.T
matmul_qk = tf.matmul(q, k, transpose_b=True)
# Q*K.T / sqrt(d)
dk = tf.cast(tf.shape(k)[-1], tf.float32)
scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
# softmax( Q*K.T / sqrt(d) )
attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
# softmax( Q*K.T / sqrt(d) ) * V
output = tf.matmul(attention_weights, v)
return output, attention_weights
QとKの転置の内積を計算
matmul_qk = tf.matmul(q, k, transpose_b=True)
QとKの転置の内積をルートdkで割る
scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
softmaxを計算
attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
softmaxの結果とVの内積を計算
output = tf.matmul(attention_weights, v)
図中のMask(opt.)はオプションのため省略している。
##2.2 テストコード
K, V, Qの値を指定して、scaled_dot_product_attentionメソッドを呼び出すコード。
#############################################
#
# Test "Scaled Dot Product Attention" method
#
k = tf.constant([[10, 0, 0],
[ 0,10, 0],
[ 0, 0,10],
[ 0, 0,10]], dtype=tf.float32)
v = tf.constant([[ 1, 0],
[ 10, 0],
[ 100, 5],
[ 1000, 6]], dtype=tf.float32)
q = tf.constant([[0, 10, 0]], dtype=tf.float32)
print('---input---')
print(k)
print(v)
print(q)
result, attention_w = scaled_dot_product_attention(q,k,v)
print('---result---')
print(attention_w)
print(result)
k、v、qを入力して、scaled_dot_product_attentionメソッドを呼び出し、出力結果(result)を表示する。
#3 実行
テストコードを実行すると以下のような結果となる。
(1.00000e+01 9.276602e-25)が出力結果である。
#4 環境
Version | |
---|---|
Python | 3.7.4 |
Tensorflow | 2.3.1 |
#5 参考
URL | |
---|---|
Transformer論文 | https://arxiv.org/abs/1706.03762 |
Tensorflowチュートリアル | https://www.tensorflow.org/tutorials/text/transformer |