0
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

記事投稿キャンペーン 「AI、機械学習」

【実装】T5(Text-to-Text Transfer Transformer)をスクラッチ開発してみた

Posted at

はじめに

今回はT5(Text-to-Text Transfer Transformer)という自然言語処理に関するモデルをゼロから実装してみました。
概要について理解している方は、コード部分からお読みください。

T5とは

T5(Text-to-Text Transfer Transformer)は、自然言語処理(NLP)タスクに対する汎用的なアーキテクチャであるTransformerをベースにしたモデルです。T5のアーキテクチャは比較的シンプルで、Encoder-DecoderのTransformerモデルに基づいています。大規模なテキストコーパスで事前訓練されており、多くの一般的な知識を獲得しています。これにより、特定のタスクに適応させる際の性能が向上します。

image.png

アーキテクチャ

基本的にはtransformerのエンコーダ・デコーダモデルの構造に似通っています。以下に各モジュールについて説明します。

image.png

inputs

EncoderLayerに入力するinputsは"Translate the following English text into French: Hello, how are you? target:"などの入力文をトークン化、idEmbedding層に入力しすることで埋め込みベクトルを得ます。

Encoder

Encoderは以下のモジュールで構成された各Layerで構成されます。

Multi-HeadAttention
Add&Norm
FeedForward

Multi-HeadAttention(Encoder)

Encoder内のMulti-HeadAttentionでは、query, key, valueが全てinputsの埋め込みベクトルから作成されています。このモジュールの役割はinputsの文章内に文脈、関係性などを把握するというものです。MultiHeadということもあり、各ヘッドごとに埋め込みベクトルの次元数より少ない次元数のquery, key, valueを作成し、最後にそれら全てを結合することで、異なる視点からの特徴を捉えることが出来ます。具体的な構造は以下のようになります。Dは埋め込みベクトルの次元数、Mはヘッドの個数です。

\begin{align}
query^{(m)}_{i} = W^{(m)}_{query}h_{i} \\
key^{(m)}_{j} = W^{(m)}_{key}h_{j}\\
value^{(m)}_{j} = W^{(m)}_{value}h_{j}\\
s^{(m)}_{ij} = \frac{query^{(m)}_{i}key_{j}^{(m)T}}{\sqrt{\frac{D}{M}}}\\
a^{(m)}_{ij} = \frac{exp(s^{(m)}_{ij})}{\sum_{j^{'}}^{\frac{D}{M}}exp(s^{(m)}_{ij^{'}})}\\
o_{i} = \sum_{j=1}^{\frac{D}{M}}a^{(m)}_{ij}value^{(m)}_{j}
\end{align}

Add&Norm

このモジュールはスキップ結合を使用して、Multi-HeadAttentionと一つ前の入力を加算することで勾配爆発、消失を防ぐ役割を果たし、またLayerNormalizationを使用することでパラメータを調整しています。

FeedForward

このモジュールは単なる全結合層と活性化関数の集合です。具体的には以下のような構成となっています。fはreluなどの活性化関数を使用します。

FFN(x) = W_{2}f(W_{1}x + b_{1})+b_{2}

outputs

これは"Bonjour, comment ça va ?"といったinputsの対する回答文や翻訳文で構成されています。inputsと同様にトークン化、idEmbedding層に入力し、埋め込みベクトルが得ます。

Decoder

Decoderは主に以下のモジュールが各Layerに構成されます。

MaskedMulti-HeadAttention
Add&Norm
Multi-HeadAttention
FeedForward

Decoderはこれらのモジュールで構成されたLayerが複数個集まったモジュールになります。

MaskedMulti-HeadAttention

image.png

このモジュールは、各位置に対してマスクを用いてその位置より未来の情報を隠し、複数のヘッドで同時に異なるアテンション重みを計算、それらを結合して出力します。

Multi-HeadAttention(Decoder)

基本的にはEncoderのMulti-HeadAttentionと同じ構造です。しかし、queryにoutputsの出力ベクトル、keyvalueにEncoderの最終Layerから抽出した出力ベクトルを使用します。これによりoutputsと関連性の高い単語や表現をinputsから抽出することが可能となります。

相対位置埋め込み

BertやGPTといった既存のtransfomerのEncoder、decoderは文章の位置における特徴を捉えるため、埋め込みベクトルに対して絶対位置埋め込みを行っていました。しかし今回のT5では、トークン同士の距離をとらえた相対位置埋め込みを使用しています。具体的には以下のような処理を行っています。相対位置におけるスコアはスカラー値となります。

s^{(m)}_{ij} = \frac{query^{(m)}_{i}・key^{(m)T}_{j}}{\sqrt{\frac{D}{M}}} + p^{(m)}_{\lvert i-j\rvert}

この相対位置埋め込みはEncoderのMulti-HeadAttentionとDecoderのMaskedMulti-HeadAttentionに使用されています。

コード

以下が実装したコードになります。

警告
以下は個人で開発したコードです。誤った部分や、冗長な部分がある可能性があります。

Add&Norm

Add&Norm
class AddNorm(nn.Module):
    def __init__(self, batch_size, seq_length, hidden_size, check_encoder):
        super().__init__()
        self._setupAddNormModule(batch_size, seq_length, hidden_size, check_encoder)

    def _setupAddNormModule(self, batch_size, seq_length, hidden_size, check_encoder):
        if check_encoder: seq_length = seq_length[0]
        else: seq_length = seq_length[1]
        self.layer_norm = nn.LayerNorm((batch_size, seq_length, hidden_size))

    def forward(self, tokens, skipped_tokens):
        tokens += skipped_tokens
        tokens = self.layer_norm(tokens)

        return tokens

FeedForward

FeedForward
class FeedForward(nn.Module):
    def __init__(self, hidden_size, ffn_hidden_size):
        super().__init__()
        self._setupFeedForwardModule(hidden_size, ffn_hidden_size)

    def _setupFeedForwardModule(self, hidden_size, ffn_hidden_size):
        dense1 = nn.Linear(hidden_size, ffn_hidden_size)
        relu1 = nn.ReLU()
        dense2 = nn.Linear(ffn_hidden_size, hidden_size)
        self.feed_foward_module = nn.ModuleList([dense1, relu1, dense2])

    def forward(self, tokens):
        for module in self.feed_foward_module:
            tokens = module(tokens)
        return tokens

Multi-HeadAttention(マスク化も可能)

class MultiHeadAttention(nn.Module):
    def __init__(self, batch_size, num_heads, seq_length, hidden_size, check_positional_embedding, check_mask):
        super().__init__()
        self._setupHeadQKV(num_heads, hidden_size)
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.seq_length = seq_length
        self.hidden_size = hidden_size
        self.check_positional_embedding = check_positional_embedding
        self.check_mask = check_mask
        self.softmax = nn.Softmax(dim=-1)

    #ヘッド毎にquery, key, valueを設定
    def _setupHeadQKV(self, num_heads, hidden_size):
        query_module = []
        key_module = []
        value_module = []
        head_hidden_size = int(hidden_size / num_heads)

        for _ in range(num_heads):
            query_module.append(nn.Linear(hidden_size, head_hidden_size))
            key_module.append(nn.Linear(hidden_size, head_hidden_size))
            value_module.append(nn.Linear(hidden_size, head_hidden_size))

        self.query_module = nn.ModuleList(query_module)
        self.key_module = nn.ModuleList(key_module)
        self.value_module = nn.ModuleList(value_module)

    #相対位置埋め込みのスカラー値を計算
    def _outputRelativePositionalEmbeddingScalar(self, query, batch_size, seq_length, hidden_size, num_heads):
        if self.check_mask: seq_length = self.seq_length[1]
        else: seq_length = self.seq_length[0]

        embed_Module = []
        head_hidden_size = int(hidden_size / num_heads)
        position_ids = torch.tensor(list(range(seq_length)), dtype=torch.long).reshape(1, seq_length).expand(batch_size, seq_length)
        for id in range(num_heads): embed_Module.append(nn.Embedding(seq_length, head_hidden_size))
        self.embed_module = nn.ModuleList(embed_Module)
        for id in range(num_heads):
            head_query = self.query_module[id](query)
            tmp_relative_position_embedding_scalar = (head_query@(self.embed_module[id](position_ids).transpose(1, 2))).reshape(1, batch_size, seq_length, seq_length)
            if id == 0: relative_position_embedding_scalar = tmp_relative_position_embedding_scalar
            else: relative_position_embedding_scalar = torch.concat([relative_position_embedding_scalar, tmp_relative_position_embedding_scalar], dim=0)

        return relative_position_embedding_scalar
    
    #Attentionモジュールの出力
    def _outputAttention(self, query, key, value, batch_size, seq_length, hidden_size, num_heads, check_positional_embedding, check_mask, encoder_attention_mask, decoder_attention_mask):
        #どのAttentionモジュールかを判断、系列長を設定
        if check_positional_embedding:
            if check_mask:
                seq_length1 = seq_length2 = self.seq_length[1]
            else:
                seq_length1 = seq_length2 = self.seq_length[0]

        else:
            seq_length1 = self.seq_length[1]
            seq_length2 = self.seq_length[0]

        head_hidden_size = int(hidden_size / num_heads)
        encoder_attention_mask = encoder_attention_mask.reshape(batch_size, -1, 1).expand(batch_size, seq_length1, seq_length2)
        decoder_attention_mask = decoder_attention_mask.reshape(batch_size, 1, -1).expand(batch_size, seq_length1, seq_length2)
        #paddingされたベクトルを除外するmap
        padding_map = encoder_attention_mask * decoder_attention_mask
        #マスク化に使用するmap
        mask_map = torch.tensor(np.tril(np.ones((seq_length1, seq_length2))), dtype=torch.long)
        #相対位置埋め込みの計算
        if check_positional_embedding:relative_position_embedding_scalar = self._outputRelativePositionalEmbeddingScalar(query, batch_size, seq_length, hidden_size, num_heads)
        else: relative_position_embedding_scalar = torch.zeros_like(padding_map, dtype=torch.float).expand(num_heads, batch_size, seq_length1, seq_length2)

        for id in range(num_heads):
            head_query = self.query_module[id](query)
            head_key = self.key_module[id](key)
            head_value = self.value_module[id](value)
            #Attentionの計算
            if check_mask: tmp_head_attention = self.softmax(padding_map * (mask_map * (head_query@head_key.transpose(1, 2)) / torch.sqrt(torch.tensor(head_hidden_size)) + relative_position_embedding_scalar[id]))@head_value
            else:tmp_head_attention = self.softmax(padding_map * ((head_query@head_key.transpose(1, 2)) / torch.sqrt(torch.tensor(head_hidden_size)) + relative_position_embedding_scalar[id]))@head_value
            #各ヘッドの出力を結合
            if id == 0: head_attention = tmp_head_attention
            else: head_attention = torch.concat([head_attention, tmp_head_attention], dim=-1)
        output_attention = head_attention

        return output_attention

    def forward(self, query, key, value, encoder_attention_mask, decoder_attention_mask):
        output_attention = self._outputAttention(query, key, value, self.batch_size, self.seq_length, self.hidden_size, self.num_heads,
                                                  self.check_positional_embedding, self.check_mask, encoder_attention_mask, decoder_attention_mask)

        return output_attention

EncoderLayer

EncoderLayer
class EncoderLayer(nn.Module):
    #各モジュールを設定
    def __init__(self, batch_size, num_heads, seq_length, hidden_size, ffn_hidden_size):
        super().__init__()
        self.multi_head_attention = MultiHeadAttention(batch_size, num_heads, seq_length, hidden_size, check_positional_embedding=True, check_mask=False)
        self.add_norm1 = AddNorm(batch_size, seq_length, hidden_size, check_encoder=True)
        self.feed_forward = FeedForward(hidden_size, ffn_hidden_size)
        self.add_norm2 = AddNorm(batch_size, seq_length, hidden_size, check_encoder=True)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, tokens, encoder_attention_mask):
        skip1 = tokens
        multi_head_attention = self.multi_head_attention(tokens, tokens, tokens, encoder_attention_mask, encoder_attention_mask)
        add_norm1 = self.add_norm1(self.dropout(multi_head_attention), skip1)
        skip2 = add_norm1
        feed_forward = self.feed_forward(add_norm1)
        add_norm2 = self.add_norm2(self.dropout(feed_forward), skip2)
        tokens = add_norm2

        return tokens

Encoder

Encoder
class Encoder(nn.Module):
    def __init__(self, batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
        super().__init__()
        self._setupEncoderLayer(batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)

    #EncoderLayerを指定回数重ねる
    def _setupEncoderLayer(self, batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
        encoder_layer_list = []
        for _ in range(num_layer):
            encoder_layer = EncoderLayer(batch_size, num_heads, seq_length, hidden_size, ffn_hidden_size)
            encoder_layer_list.append(encoder_layer)
        self.encoder_module = nn.ModuleList(encoder_layer_list)

    def forward(self, encoder_embedding, encoder_attention_mask):
        tokens = encoder_embedding
        for encoder_layer in self.encoder_module:
            tokens = encoder_layer(tokens, encoder_attention_mask)
        output_encoder = tokens

        return output_encoder

DecoderLayer

DecoderLayer
    
class DecoderLayer(nn.Module):
    #各モジュールを設定
    def __init__(self, batch_size, num_heads, seq_length, hidden_size, ffn_hidden_size):
        super().__init__()
        self.masked_multi_head_attention =  MultiHeadAttention(batch_size, num_heads, seq_length, hidden_size, check_positional_embedding=True, check_mask=True)
        self.add_norm1 = AddNorm(batch_size, seq_length, hidden_size, check_encoder=False)
        self.cross_multi_head_attention = MultiHeadAttention(batch_size, num_heads, seq_length, hidden_size, check_positional_embedding=False, check_mask=False)
        self.add_norm2 = AddNorm(batch_size, seq_length, hidden_size, check_encoder=False)
        self.feed_forward = FeedForward(hidden_size, ffn_hidden_size)
        self.add_norm3 = AddNorm(batch_size, seq_length, hidden_size, check_encoder=False)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, tokens, output_encoder, encoder_attention_mask, decoder_attention_mask):
        skip1 = tokens
        masked_multi_head_attention = self.masked_multi_head_attention(tokens, tokens, tokens, decoder_attention_mask, decoder_attention_mask)
        add_norm1 = self.add_norm1(self.dropout(masked_multi_head_attention), skip1)
        skip2 = add_norm1
        cross_multi_head_attention = self.cross_multi_head_attention(tokens, output_encoder, output_encoder, decoder_attention_mask, encoder_attention_mask)
        add_norm2 = self.add_norm2(self.dropout(cross_multi_head_attention), skip2)
        skip3 = add_norm2
        feed_forward = self.feed_forward(tokens)
        add_norm3 = self.add_norm3(self.dropout(feed_forward), skip3)
        tokens = add_norm3

        return tokens

Decoder

Decoder
class Decoder(nn.Module):
    def __init__(self, batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
        super().__init__()
        self._setupDecoderLayer(batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)

    #DecoderLayerを指定回数重ねる
    def _setupDecoderLayer(self, batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
        decoder_list = []
        for _ in range(num_layer):
            decoder_layer = DecoderLayer(batch_size, num_heads, seq_length, hidden_size, ffn_hidden_size)
            decoder_list.append(decoder_layer)
        self.decoder_module = nn.ModuleList(decoder_list)

    def forward(self, decoder_embedding, output_encoder, encoder_attention_mask, decoder_attention_mask):
        tokens = decoder_embedding
        for decoder_layer in self.decoder_module:
            tokens = decoder_layer(tokens, output_encoder, encoder_attention_mask, decoder_attention_mask)
        output_decoder = tokens

        return output_decoder

T5(Encoder-Decoder)

T5
class T5(nn.Module):
    def __init__(self, batch_size, num_embedding, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
        super().__init__()
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.hidden_size = hidden_size
        self.num_layer = num_layer
        self.num_heads = num_heads
        self.ffn_hidden_size = ffn_hidden_size

        #Encoder, Decoderの埋め込み層を設定
        self.encoder_embedding = nn.Embedding(num_embedding, hidden_size, padding_idx=0)
        self.decoder_embedding = nn.Embedding(num_embedding, hidden_size, padding_idx=0)
        self.encoder = Encoder(batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)
        self.decoder = Decoder(batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)
        


    def forward(self, encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask):
        encoder_embedding = self.encoder_embedding(encoder_input_ids)
        output_encoder = self.encoder(encoder_embedding, encoder_attention_mask)
        decoder_embedding = self.decoder_embedding(decoder_input_ids)
        decoder_output = self.decoder(decoder_embedding, output_encoder, encoder_attention_mask, decoder_attention_mask)

        return decoder_output #下流タスクは設定していない

実行結果

以下がT5を実行した結果になります。

実行
from transformers import T5Tokenizer
from pprint import pprint
impot torch
import torch.nn as nn
import numpy as np

input = ["Translate the following English text into French: Hello, how are you? target:",
         "Translate the following English text into French: Good morning, everyone. target",
          "Translate the following English text into French: Can you help me with this? target" ]

target = ["Bonjour, comment ça va ?",
          "Bonjour à tous.",
          "Pouvez-vous m'aider avec ceci ?"]

tokenizer = T5Tokenizer.from_pretrained("t5-small")

encoder_tokenize = tokenizer(input, return_tensors="pt", padding=True)
decoder_tokenize = tokenizer(target, return_tensors="pt", padding=True)

encoder_input_ids = encoder_tokenize.input_ids
encoder_attention_mask = encoder_tokenize.attention_mask
decoder_input_ids = decoder_tokenize.input_ids
decoder_attention_mask = decoder_tokenize.attention_mask

batch_size = encoder_input_ids.size(0)
encoder_seq_length = encoder_input_ids.size(1)
decoder_seq_length = decoder_input_ids.size(1)
max_ids = torch.max(torch.max(encoder_input_ids), torch.max(decoder_input_ids)) #仮の埋め込み辞書数

kwargs = {
    "batch_size": batch_size,
    "num_embedding": max_ids + 1,
    "seq_length": (encoder_seq_length, decoder_seq_length),
    "hidden_size": 768,
    "num_layer": 12,
    "num_heads": 12,
    "ffn_hidden_size": 3072,

}
    
model = T5(**kwargs)
outputs = model(encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask)
print(model)
print(model(encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask).shape)
結果
T5(
  (encoder_embedding): Embedding(30356, 768, padding_idx=0)
  (decoder_embedding): Embedding(30356, 768, padding_idx=0)
  (encoder): Encoder(
    (encoder_module): ModuleList(
      (0-11): 12 x EncoderLayer(
        (multi_head_attention): MultiHeadAttention(
          (query_module): ModuleList(
            (0-11): 12 x Linear(in_features=768, out_features=64, bias=True)
          )
          (key_module): ModuleList(
            (0-11): 12 x Linear(in_features=768, out_features=64, bias=True)
          )
          (value_module): ModuleList(
            (0-11): 12 x Linear(in_features=768, out_features=64, bias=True)
          )
          (softmax): Softmax(dim=-1)
          (embed_module): ModuleList(
            (0-11): 12 x Embedding(18, 64)
          )
        )
        (add_norm1): AddNorm(
          (layer_norm): LayerNorm((3, 18, 768), eps=1e-05, elementwise_affine=True)
        )
        (feed_forward): FeedForward(
...
    )
  )
)
torch.Size([3, 13, 768])

無事T5のモデルを実行、出力することが出来ました。

まとめ

T5の実装を通して、Attention機構の実装の難しさを再度実感しました。相対位置埋め込みやマスク化といった処理を加えることでViTと比べて実装の難易度が何倍にも跳ね上がったように感じました。次は自然言語処理から少し離れてVision-Languageのモデルについて勉強、実装したいと思います。

参考文献

Transformerとは?AI自然言語学習の技術を解説
Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer

0
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?