80
96

More than 1 year has passed since last update.

Transformerの構造を理解したい

Last updated at Posted at 2023-04-17

こんにちは、すきにーです。

Transformerのこの図がずっと理解できなかったのですが、
最近理解することができました。

この記事で自分なりに解説していきたいと思います。

transfrormer

コードはかなり簡略化して書いてます。

はじめに

この記事で説明していること

TensorFlow公式ドキュメント, Transformerモデル
https://www.tensorflow.org/tutorials/text/transformer

エンコーダーとデコーダーからTransformerの作成のところまでを説明しています。
 

参考にした記事

作って理解する Transformer / Attention
深層学習界の大前提Transformerの論文解説!
【Pytorch】Transformerを実装する
図で理解するTransformer

Transformerの構造

全体図

image.png

入力 → Encoder → Decoder → 出力という構造になっています。
image.png
出典:https://kikaben.com/transformers-encoder-decoder/

エンコーダーとデコーダーの中身が複雑になっているので、分かり辛くなっているのだと思いました。

簡単に書くとこんな感じ

class Transformer(tf.keras.Model):
  def __init__():
    self.encoder = Encoder()
    self.decoder = Decoder()
    self.final_layer = tf.keras.layers.Dense()

  def call(self, encoder_input, decoder_input):
    #エンコーダーの出力
    encoder_output = self.encoder(encoder_input)  

    #エンコーダーの出力をデコーダーにいれて、デコーダから出力
    decoder_output, attention_weights = self.decoder(encoder_output, decoder_input) 

    #最終的な出力
    output = self.final_layer(decoder_output)  

    return output, attention_weights

Encoder

エンコーダーはこの図の左側です。
image.png
そして、青枠の部分をエンコーダーレイヤー(EncoderLayer)と言います
image.png

エンコーダーレイヤーを1つの層として考えると、エンコーダはこのような構造になっています。
①入力の埋め込み

②位置エンコーディング

③エンコーダーレイヤー × N

④出力

pythonで書くと次のようになります。

class Encoder(tf.keras.layers.Layer):
  def __init__():
    super(Encoder, self).__init__()
    self.d_model = d_model

    #入力の埋め込み
    self.embedding = tf.keras.layers.Embedding()

    #位置エンコーディング
    self.pos_encoding = positional_encoding()

    #エンコーダーレイヤー
    self.enc_layers = EncoderLayer() 

  def call(self, x, training, mask):

    seq_len = tf.shape(x)[1]

    # 入力埋め込み
    x = self.embedding(x) 

    #埋め込みと位置エンコーディングを合算する
    x += self.pos_encoding[:, :seq_len, :]

    #エンコーダーレイヤーの処理をN回繰り返す
    for i in range(N):
      x = self.enc_layers[i](x, training, mask)
    
    #出力
    return x  

EcoderLayer

下図がエンコーダーレイヤーです。
image.png

主にやっていることは2つしかありません。
①マルチヘッドアテンション
②ポイントワイズ・フィードフォワード・ネットワーク

Add&Norm の部分は勾配消失の回避や学習の安定化のためにやっている処理なので、今回は飛ばします。

class EncoderLayer(tf.keras.layers.Layer):
  def __init__():
    super(EncoderLayer, self).__init__()
    self.mha = MultiHeadAttention()
    self.ffn = point_wise_feed_forward_network()

  def call(input):

    # ①マルチヘッドアテンション
    attntion_output, _ = self.mha(input) 

    # ②ポイントワイズ・フィードフォワード・ネットワーク
    output = self.ffn(attntion_output)

    return output

Decoder

デコーダーは、エンコーダーとほとんど同じです。
処理が一つ増えてるだけ。

デコーダーはこの図の右側になります。
image.png

そして、青枠の部分をデコーダーレイヤー(DecoderLayer)と言います
image.png

デコーダーレイヤーを1つの層として考えると、デコーダはこのような構造になっています。
①入力の埋め込み

②位置デコーディング

③デコーダーレイヤー × N

④出力

pythonで書くと次のようになります。

class Decoder(tf.keras.layers.Layer):
  def __init__():
    super(Decoder, self).__init__()
    self.d_model = d_model

    #入力の埋め込み
    self.embedding = tf.keras.layers.Embedding()

    #位置デコーディング
    self.pos_Decoding = positional_Decoding()

    #デコーダーレイヤー
    self.dec_layers = DecoderLayer() 

  def call(self, x, training, mask):

    seq_len = tf.shape(x)[1]

    # 入力埋め込み
    x = self.embedding(x) 

    #埋め込みと位置デコーディングを合算する
    x += self.pos_Decoding[:, :seq_len, :]

    #デコーダーレイヤーの処理をN回繰り返す
    for i in range(N):
      x, weights = self.dec_layers[i](x, training, mask)
    
    #出力
    return x, weights

DecoderLayer

下図がデコーダーレイヤーです。
image.png

エンコーダーレイヤーと違うところは、2つしかありません。

  1. エンコーダーの出力を受け取って処理する層が追加されている
  2. マルチヘッドアテンションが未来の情報にアクセスできないように制限されている。(Masked Multi-Head Attentionになっている)

デコーダーレイヤーの処理は3つです。
①Masked Multi-Head Attention
②エンコーダからの出力をマルチヘッドアテンションで処理
③ポイントワイズ・フィードフォワード・ネットワーク

Add&Norm の部分は勾配消失の回避や学習の安定化のためにやっている処理なので、今回は飛ばします。

class DecoderLayer(tf.keras.layers.Layer):
  def __init__():
    super(DecoderLayer, self).__init__()

    self.mha1 = MultiHeadAttention()
    self.mha2 = MultiHeadAttention()

    self.ffn = point_wise_feed_forward_network()

  def call(input, encoder_output, mask):

    # ①Masked Multi-Head Attention
    attntion_output, attention_weights1 = self.mha1(input, mask) 

    # ②エンコーダからの出力をマルチヘッドアテンションで処理
    attntion_output, attention_weights2 = self.mha(attntion_output, encoder_output, mask) 

    # ③ポイントワイズ・フィードフォワード・ネットワーク
    output = self.ffn(attntion_output)

    return output, attention_weights1, attention_weights2

以上です。

まとめ

Transformerの全体図
image.png

エンコーダーの構成
image.png

エンコーダーレイヤー
image.png

デコーダーの構成
image.png

デコーダーレイヤー
image.png

80
96
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
80
96