# Attention-style toy network over a flat feature vector: each of the `width`
# input features is handled as one position carrying a single scalar value.
width = 32
inputs = keras.Input(shape=(width,))

# Three separate Dense projections of the same input, each [batch, 32].
v_flat = layers.Dense(units=width)(inputs)
k_flat = layers.Dense(units=width)(inputs)
q_flat = layers.Dense(units=width)(inputs)

# Insert a unit axis so the projections can be combined by batched matmul.
v = keras.layers.Reshape(target_shape=(-1, 1))(v_flat)  # [batch, 32, 1]
k = keras.layers.Reshape(target_shape=(-1, 1))(k_flat)  # [batch, 32, 1]
q = keras.layers.Reshape(target_shape=(1, -1))(q_flat)  # [batch, 1, 32]

# Outer product per sample: kq[b, i, j] = k[b, i] * q[b, j] -> [batch, 32, 32]
kq = layers.Lambda(lambda operands: tf.matmul(*operands))([k, q])
# Normalise every row over its last axis, i.e. the axis indexed by q.
# NOTE(review): compared with the usual softmax(Q K^T) V naming, k indexes the
# rows (query role) and q the normalised axis (key role) here.
kq = layers.Softmax(axis=-1)(kq)

# Weight v by the normalised scores: [batch, 32, 32] @ [batch, 32, 1]
# -> [batch, 32, 1]
outputs = layers.Lambda(lambda operands: tf.matmul(*operands))([kq, v])
model = models.Model(inputs=inputs, outputs=outputs)
間違っていたらコメントをください。
参考文献