こんにちは、すきにーです。
Transformerのこの図がずっと理解できなかったのですが、
最近理解することができました。
この記事で自分なりに解説していきたいと思います。
コードはかなり簡略化して書いてます。
はじめに
この記事で説明していること
TensorFlow公式ドキュメント, Transformerモデル
https://www.tensorflow.org/tutorials/text/transformer
エンコーダーとデコーダーからTransformerの作成のところまでを説明しています。
参考にした記事
作って理解する Transformer / Attention
深層学習界の大前提Transformerの論文解説!
【Pytorch】Transformerを実装する
図で理解するTransformer
Transformerの構造
全体図
入力 → Encoder → Decoder → 出力という構造になっています。
出典:https://kikaben.com/transformers-encoder-decoder/
エンコーダーとデコーダーの中身が複雑になっているので、分かり辛くなっているのだと思いました。
簡単に書くとこんな感じ
class Transformer(tf.keras.Model):
def __init__():
self.encoder = Encoder()
self.decoder = Decoder()
self.final_layer = tf.keras.layers.Dense()
def call(self, encoder_input, decoder_input):
#エンコーダーの出力
encoder_output = self.encoder(encoder_input)
#エンコーダーの出力をデコーダーにいれて、デコーダから出力
decoder_output, attention_weights = self.decoder(encoder_output, decoder_input)
#最終的な出力
output = self.final_layer(decoder_output)
return output, attention_weights
Encoder
エンコーダーはこの図の左側です。
そして、青枠の部分をエンコーダーレイヤー(EncoderLayer)と言います
エンコーダーレイヤーを1つの層として考えると、エンコーダはこのような構造になっています。
①入力の埋め込み
↓
②位置エンコーディング
↓
③エンコーダーレイヤー × N
↓
④出力
pythonで書くと次のようになります。
class Encoder(tf.keras.layers.Layer):
def __init__():
super(Encoder, self).__init__()
self.d_model = d_model
#入力の埋め込み
self.embedding = tf.keras.layers.Embedding()
#位置エンコーディング
self.pos_encoding = positional_encoding()
#エンコーダーレイヤー
self.enc_layers = EncoderLayer()
def call(self, x, training, mask):
seq_len = tf.shape(x)[1]
# 入力埋め込み
x = self.embedding(x)
#埋め込みと位置エンコーディングを合算する
x += self.pos_encoding[:, :seq_len, :]
#エンコーダーレイヤーの処理をN回繰り返す
for i in range(N):
x = self.enc_layers[i](x, training, mask)
#出力
return x
EcoderLayer
主にやっていることは2つしかありません。
①マルチヘッドアテンション
②ポイントワイズ・フィードフォワード・ネットワーク
Add&Norm の部分は勾配消失の回避や学習の安定化のためにやっている処理なので、今回は飛ばします。
class EncoderLayer(tf.keras.layers.Layer):
def __init__():
super(EncoderLayer, self).__init__()
self.mha = MultiHeadAttention()
self.ffn = point_wise_feed_forward_network()
def call(input):
# ①マルチヘッドアテンション
attntion_output, _ = self.mha(input)
# ②ポイントワイズ・フィードフォワード・ネットワーク
output = self.ffn(attntion_output)
return output
Decoder
デコーダーは、エンコーダーとほとんど同じです。
処理が一つ増えてるだけ。
そして、青枠の部分をデコーダーレイヤー(DecoderLayer)と言います
デコーダーレイヤーを1つの層として考えると、デコーダはこのような構造になっています。
①入力の埋め込み
↓
②位置デコーディング
↓
③デコーダーレイヤー × N
↓
④出力
pythonで書くと次のようになります。
class Decoder(tf.keras.layers.Layer):
def __init__():
super(Decoder, self).__init__()
self.d_model = d_model
#入力の埋め込み
self.embedding = tf.keras.layers.Embedding()
#位置デコーディング
self.pos_Decoding = positional_Decoding()
#デコーダーレイヤー
self.dec_layers = DecoderLayer()
def call(self, x, training, mask):
seq_len = tf.shape(x)[1]
# 入力埋め込み
x = self.embedding(x)
#埋め込みと位置デコーディングを合算する
x += self.pos_Decoding[:, :seq_len, :]
#デコーダーレイヤーの処理をN回繰り返す
for i in range(N):
x, weights = self.dec_layers[i](x, training, mask)
#出力
return x, weights
DecoderLayer
エンコーダーレイヤーと違うところは、2つしかありません。
- エンコーダーの出力を受け取って処理する層が追加されている
- マルチヘッドアテンションが未来の情報にアクセスできないように制限されている。(Masked Multi-Head Attentionになっている)
デコーダーレイヤーの処理は3つです。
①Masked Multi-Head Attention
②エンコーダからの出力をマルチヘッドアテンションで処理
③ポイントワイズ・フィードフォワード・ネットワーク
Add&Norm の部分は勾配消失の回避や学習の安定化のためにやっている処理なので、今回は飛ばします。
class DecoderLayer(tf.keras.layers.Layer):
def __init__():
super(DecoderLayer, self).__init__()
self.mha1 = MultiHeadAttention()
self.mha2 = MultiHeadAttention()
self.ffn = point_wise_feed_forward_network()
def call(input, encoder_output, mask):
# ①Masked Multi-Head Attention
attntion_output, attention_weights1 = self.mha1(input, mask)
# ②エンコーダからの出力をマルチヘッドアテンションで処理
attntion_output, attention_weights2 = self.mha(attntion_output, encoder_output, mask)
# ③ポイントワイズ・フィードフォワード・ネットワーク
output = self.ffn(attntion_output)
return output, attention_weights1, attention_weights2
以上です。