はじめに
今回はT5(Text-to-Text Transfer Transformer)という自然言語処理に関するモデルをゼロから実装してみました。
概要について理解している方は、コード部分からお読みください。
T5とは
T5(Text-to-Text Transfer Transformer)は、自然言語処理(NLP)タスクに対する汎用的なアーキテクチャであるTransformerをベースにしたモデルです。T5のアーキテクチャは比較的シンプルで、Encoder-DecoderのTransformerモデルに基づいています。大規模なテキストコーパスで事前訓練されており、多くの一般的な知識を獲得しています。これにより、特定のタスクに適応させる際の性能が向上します。
アーキテクチャ
基本的にはtransformerのエンコーダ・デコーダモデルの構造に似通っています。以下に各モジュールについて説明します。
inputs
EncoderLayerに入力するinputsは"Translate the following English text into French: Hello, how are you? target:"
などの入力文をトークン化、id
をEmbedding
層に入力しすることで埋め込みベクトルを得ます。
Encoder
Encoderは以下のモジュールで構成された各Layerで構成されます。
Multi-HeadAttention
Add&Norm
FeedForward
Multi-HeadAttention(Encoder)
Encoder内のMulti-HeadAttentionでは、query, key, valueが全てinputsの埋め込みベクトルから作成されています。このモジュールの役割はinputsの文章内に文脈、関係性などを把握するというものです。MultiHeadということもあり、各ヘッドごとに埋め込みベクトルの次元数より少ない次元数のquery, key, valueを作成し、最後にそれら全てを結合することで、異なる視点からの特徴を捉えることが出来ます。具体的な構造は以下のようになります。D
は埋め込みベクトルの次元数、M
はヘッドの個数です。
\begin{align}
query^{(m)}_{i} = W^{(m)}_{query}h_{i} \\
key^{(m)}_{j} = W^{(m)}_{key}h_{j}\\
value^{(m)}_{j} = W^{(m)}_{value}h_{j}\\
s^{(m)}_{ij} = \frac{query^{(m)}_{i}key_{j}^{(m)T}}{\sqrt{\frac{D}{M}}}\\
a^{(m)}_{ij} = \frac{exp(s^{(m)}_{ij})}{\sum_{j^{'}}^{\frac{D}{M}}exp(s^{(m)}_{ij^{'}})}\\
o_{i} = \sum_{j=1}^{\frac{D}{M}}a^{(m)}_{ij}value^{(m)}_{j}
\end{align}
Add&Norm
このモジュールはスキップ結合を使用して、Multi-HeadAttentionと一つ前の入力を加算することで勾配爆発、消失を防ぐ役割を果たし、またLayerNormalizationを使用することでパラメータを調整しています。
FeedForward
このモジュールは単なる全結合層と活性化関数の集合です。具体的には以下のような構成となっています。f
はreluなどの活性化関数を使用します。
FFN(x) = W_{2}f(W_{1}x + b_{1})+b_{2}
outputs
これは"Bonjour, comment ça va ?"
といったinputsの対する回答文や翻訳文で構成されています。inputsと同様にトークン化、id
をEmbedding
層に入力し、埋め込みベクトルが得ます。
Decoder
Decoderは主に以下のモジュールが各Layerに構成されます。
MaskedMulti-HeadAttention
Add&Norm
Multi-HeadAttention
FeedForward
Decoderはこれらのモジュールで構成されたLayerが複数個集まったモジュールになります。
MaskedMulti-HeadAttention
このモジュールは、各位置に対してマスクを用いてその位置より未来の情報を隠し、複数のヘッドで同時に異なるアテンション重みを計算、それらを結合して出力します。
Multi-HeadAttention(Decoder)
基本的にはEncoderのMulti-HeadAttentionと同じ構造です。しかし、query
にoutputsの出力ベクトル、key
とvalue
にEncoderの最終Layerから抽出した出力ベクトルを使用します。これによりoutputsと関連性の高い単語や表現をinputsから抽出することが可能となります。
相対位置埋め込み
BertやGPTといった既存のtransfomerのEncoder、decoderは文章の位置における特徴を捉えるため、埋め込みベクトルに対して絶対位置埋め込み
を行っていました。しかし今回のT5では、トークン同士の距離をとらえた相対位置埋め込み
を使用しています。具体的には以下のような処理を行っています。相対位置におけるスコアはスカラー値
となります。
s^{(m)}_{ij} = \frac{query^{(m)}_{i}・key^{(m)T}_{j}}{\sqrt{\frac{D}{M}}} + p^{(m)}_{\lvert i-j\rvert}
この相対位置埋め込みはEncoderのMulti-HeadAttentionとDecoderのMaskedMulti-HeadAttentionに使用されています。
コード
以下が実装したコードになります。
警告
以下は個人で開発したコードです。誤った部分や、冗長な部分がある可能性があります。
Add&Norm
class AddNorm(nn.Module):
def __init__(self, batch_size, seq_length, hidden_size, check_encoder):
super().__init__()
self._setupAddNormModule(batch_size, seq_length, hidden_size, check_encoder)
def _setupAddNormModule(self, batch_size, seq_length, hidden_size, check_encoder):
if check_encoder: seq_length = seq_length[0]
else: seq_length = seq_length[1]
self.layer_norm = nn.LayerNorm((batch_size, seq_length, hidden_size))
def forward(self, tokens, skipped_tokens):
tokens += skipped_tokens
tokens = self.layer_norm(tokens)
return tokens
FeedForward
class FeedForward(nn.Module):
def __init__(self, hidden_size, ffn_hidden_size):
super().__init__()
self._setupFeedForwardModule(hidden_size, ffn_hidden_size)
def _setupFeedForwardModule(self, hidden_size, ffn_hidden_size):
dense1 = nn.Linear(hidden_size, ffn_hidden_size)
relu1 = nn.ReLU()
dense2 = nn.Linear(ffn_hidden_size, hidden_size)
self.feed_foward_module = nn.ModuleList([dense1, relu1, dense2])
def forward(self, tokens):
for module in self.feed_foward_module:
tokens = module(tokens)
return tokens
Multi-HeadAttention(マスク化も可能)
class MultiHeadAttention(nn.Module):
def __init__(self, batch_size, num_heads, seq_length, hidden_size, check_positional_embedding, check_mask):
super().__init__()
self._setupHeadQKV(num_heads, hidden_size)
self.batch_size = batch_size
self.num_heads = num_heads
self.seq_length = seq_length
self.hidden_size = hidden_size
self.check_positional_embedding = check_positional_embedding
self.check_mask = check_mask
self.softmax = nn.Softmax(dim=-1)
#ヘッド毎にquery, key, valueを設定
def _setupHeadQKV(self, num_heads, hidden_size):
query_module = []
key_module = []
value_module = []
head_hidden_size = int(hidden_size / num_heads)
for _ in range(num_heads):
query_module.append(nn.Linear(hidden_size, head_hidden_size))
key_module.append(nn.Linear(hidden_size, head_hidden_size))
value_module.append(nn.Linear(hidden_size, head_hidden_size))
self.query_module = nn.ModuleList(query_module)
self.key_module = nn.ModuleList(key_module)
self.value_module = nn.ModuleList(value_module)
#相対位置埋め込みのスカラー値を計算
def _outputRelativePositionalEmbeddingScalar(self, query, batch_size, seq_length, hidden_size, num_heads):
if self.check_mask: seq_length = self.seq_length[1]
else: seq_length = self.seq_length[0]
embed_Module = []
head_hidden_size = int(hidden_size / num_heads)
position_ids = torch.tensor(list(range(seq_length)), dtype=torch.long).reshape(1, seq_length).expand(batch_size, seq_length)
for id in range(num_heads): embed_Module.append(nn.Embedding(seq_length, head_hidden_size))
self.embed_module = nn.ModuleList(embed_Module)
for id in range(num_heads):
head_query = self.query_module[id](query)
tmp_relative_position_embedding_scalar = (head_query@(self.embed_module[id](position_ids).transpose(1, 2))).reshape(1, batch_size, seq_length, seq_length)
if id == 0: relative_position_embedding_scalar = tmp_relative_position_embedding_scalar
else: relative_position_embedding_scalar = torch.concat([relative_position_embedding_scalar, tmp_relative_position_embedding_scalar], dim=0)
return relative_position_embedding_scalar
#Attentionモジュールの出力
def _outputAttention(self, query, key, value, batch_size, seq_length, hidden_size, num_heads, check_positional_embedding, check_mask, encoder_attention_mask, decoder_attention_mask):
#どのAttentionモジュールかを判断、系列長を設定
if check_positional_embedding:
if check_mask:
seq_length1 = seq_length2 = self.seq_length[1]
else:
seq_length1 = seq_length2 = self.seq_length[0]
else:
seq_length1 = self.seq_length[1]
seq_length2 = self.seq_length[0]
head_hidden_size = int(hidden_size / num_heads)
encoder_attention_mask = encoder_attention_mask.reshape(batch_size, -1, 1).expand(batch_size, seq_length1, seq_length2)
decoder_attention_mask = decoder_attention_mask.reshape(batch_size, 1, -1).expand(batch_size, seq_length1, seq_length2)
#paddingされたベクトルを除外するmap
padding_map = encoder_attention_mask * decoder_attention_mask
#マスク化に使用するmap
mask_map = torch.tensor(np.tril(np.ones((seq_length1, seq_length2))), dtype=torch.long)
#相対位置埋め込みの計算
if check_positional_embedding:relative_position_embedding_scalar = self._outputRelativePositionalEmbeddingScalar(query, batch_size, seq_length, hidden_size, num_heads)
else: relative_position_embedding_scalar = torch.zeros_like(padding_map, dtype=torch.float).expand(num_heads, batch_size, seq_length1, seq_length2)
for id in range(num_heads):
head_query = self.query_module[id](query)
head_key = self.key_module[id](key)
head_value = self.value_module[id](value)
#Attentionの計算
if check_mask: tmp_head_attention = self.softmax(padding_map * (mask_map * (head_query@head_key.transpose(1, 2)) / torch.sqrt(torch.tensor(head_hidden_size)) + relative_position_embedding_scalar[id]))@head_value
else:tmp_head_attention = self.softmax(padding_map * ((head_query@head_key.transpose(1, 2)) / torch.sqrt(torch.tensor(head_hidden_size)) + relative_position_embedding_scalar[id]))@head_value
#各ヘッドの出力を結合
if id == 0: head_attention = tmp_head_attention
else: head_attention = torch.concat([head_attention, tmp_head_attention], dim=-1)
output_attention = head_attention
return output_attention
def forward(self, query, key, value, encoder_attention_mask, decoder_attention_mask):
output_attention = self._outputAttention(query, key, value, self.batch_size, self.seq_length, self.hidden_size, self.num_heads,
self.check_positional_embedding, self.check_mask, encoder_attention_mask, decoder_attention_mask)
return output_attention
EncoderLayer
class EncoderLayer(nn.Module):
#各モジュールを設定
def __init__(self, batch_size, num_heads, seq_length, hidden_size, ffn_hidden_size):
super().__init__()
self.multi_head_attention = MultiHeadAttention(batch_size, num_heads, seq_length, hidden_size, check_positional_embedding=True, check_mask=False)
self.add_norm1 = AddNorm(batch_size, seq_length, hidden_size, check_encoder=True)
self.feed_forward = FeedForward(hidden_size, ffn_hidden_size)
self.add_norm2 = AddNorm(batch_size, seq_length, hidden_size, check_encoder=True)
self.dropout = nn.Dropout(p=0.1)
def forward(self, tokens, encoder_attention_mask):
skip1 = tokens
multi_head_attention = self.multi_head_attention(tokens, tokens, tokens, encoder_attention_mask, encoder_attention_mask)
add_norm1 = self.add_norm1(self.dropout(multi_head_attention), skip1)
skip2 = add_norm1
feed_forward = self.feed_forward(add_norm1)
add_norm2 = self.add_norm2(self.dropout(feed_forward), skip2)
tokens = add_norm2
return tokens
Encoder
class Encoder(nn.Module):
def __init__(self, batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
super().__init__()
self._setupEncoderLayer(batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)
#EncoderLayerを指定回数重ねる
def _setupEncoderLayer(self, batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
encoder_layer_list = []
for _ in range(num_layer):
encoder_layer = EncoderLayer(batch_size, num_heads, seq_length, hidden_size, ffn_hidden_size)
encoder_layer_list.append(encoder_layer)
self.encoder_module = nn.ModuleList(encoder_layer_list)
def forward(self, encoder_embedding, encoder_attention_mask):
tokens = encoder_embedding
for encoder_layer in self.encoder_module:
tokens = encoder_layer(tokens, encoder_attention_mask)
output_encoder = tokens
return output_encoder
DecoderLayer
class DecoderLayer(nn.Module):
#各モジュールを設定
def __init__(self, batch_size, num_heads, seq_length, hidden_size, ffn_hidden_size):
super().__init__()
self.masked_multi_head_attention = MultiHeadAttention(batch_size, num_heads, seq_length, hidden_size, check_positional_embedding=True, check_mask=True)
self.add_norm1 = AddNorm(batch_size, seq_length, hidden_size, check_encoder=False)
self.cross_multi_head_attention = MultiHeadAttention(batch_size, num_heads, seq_length, hidden_size, check_positional_embedding=False, check_mask=False)
self.add_norm2 = AddNorm(batch_size, seq_length, hidden_size, check_encoder=False)
self.feed_forward = FeedForward(hidden_size, ffn_hidden_size)
self.add_norm3 = AddNorm(batch_size, seq_length, hidden_size, check_encoder=False)
self.dropout = nn.Dropout(p=0.1)
def forward(self, tokens, output_encoder, encoder_attention_mask, decoder_attention_mask):
skip1 = tokens
masked_multi_head_attention = self.masked_multi_head_attention(tokens, tokens, tokens, decoder_attention_mask, decoder_attention_mask)
add_norm1 = self.add_norm1(self.dropout(masked_multi_head_attention), skip1)
skip2 = add_norm1
cross_multi_head_attention = self.cross_multi_head_attention(tokens, output_encoder, output_encoder, decoder_attention_mask, encoder_attention_mask)
add_norm2 = self.add_norm2(self.dropout(cross_multi_head_attention), skip2)
skip3 = add_norm2
feed_forward = self.feed_forward(tokens)
add_norm3 = self.add_norm3(self.dropout(feed_forward), skip3)
tokens = add_norm3
return tokens
Decoder
class Decoder(nn.Module):
def __init__(self, batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
super().__init__()
self._setupDecoderLayer(batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)
#DecoderLayerを指定回数重ねる
def _setupDecoderLayer(self, batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
decoder_list = []
for _ in range(num_layer):
decoder_layer = DecoderLayer(batch_size, num_heads, seq_length, hidden_size, ffn_hidden_size)
decoder_list.append(decoder_layer)
self.decoder_module = nn.ModuleList(decoder_list)
def forward(self, decoder_embedding, output_encoder, encoder_attention_mask, decoder_attention_mask):
tokens = decoder_embedding
for decoder_layer in self.decoder_module:
tokens = decoder_layer(tokens, output_encoder, encoder_attention_mask, decoder_attention_mask)
output_decoder = tokens
return output_decoder
T5(Encoder-Decoder)
class T5(nn.Module):
def __init__(self, batch_size, num_embedding, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size):
super().__init__()
self.batch_size = batch_size
self.seq_length = seq_length
self.hidden_size = hidden_size
self.num_layer = num_layer
self.num_heads = num_heads
self.ffn_hidden_size = ffn_hidden_size
#Encoder, Decoderの埋め込み層を設定
self.encoder_embedding = nn.Embedding(num_embedding, hidden_size, padding_idx=0)
self.decoder_embedding = nn.Embedding(num_embedding, hidden_size, padding_idx=0)
self.encoder = Encoder(batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)
self.decoder = Decoder(batch_size, seq_length, hidden_size, num_layer, num_heads, ffn_hidden_size)
def forward(self, encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask):
encoder_embedding = self.encoder_embedding(encoder_input_ids)
output_encoder = self.encoder(encoder_embedding, encoder_attention_mask)
decoder_embedding = self.decoder_embedding(decoder_input_ids)
decoder_output = self.decoder(decoder_embedding, output_encoder, encoder_attention_mask, decoder_attention_mask)
return decoder_output #下流タスクは設定していない
実行結果
以下がT5を実行した結果になります。
from transformers import T5Tokenizer
from pprint import pprint
impot torch
import torch.nn as nn
import numpy as np
input = ["Translate the following English text into French: Hello, how are you? target:",
"Translate the following English text into French: Good morning, everyone. target",
"Translate the following English text into French: Can you help me with this? target" ]
target = ["Bonjour, comment ça va ?",
"Bonjour à tous.",
"Pouvez-vous m'aider avec ceci ?"]
tokenizer = T5Tokenizer.from_pretrained("t5-small")
encoder_tokenize = tokenizer(input, return_tensors="pt", padding=True)
decoder_tokenize = tokenizer(target, return_tensors="pt", padding=True)
encoder_input_ids = encoder_tokenize.input_ids
encoder_attention_mask = encoder_tokenize.attention_mask
decoder_input_ids = decoder_tokenize.input_ids
decoder_attention_mask = decoder_tokenize.attention_mask
batch_size = encoder_input_ids.size(0)
encoder_seq_length = encoder_input_ids.size(1)
decoder_seq_length = decoder_input_ids.size(1)
max_ids = torch.max(torch.max(encoder_input_ids), torch.max(decoder_input_ids)) #仮の埋め込み辞書数
kwargs = {
"batch_size": batch_size,
"num_embedding": max_ids + 1,
"seq_length": (encoder_seq_length, decoder_seq_length),
"hidden_size": 768,
"num_layer": 12,
"num_heads": 12,
"ffn_hidden_size": 3072,
}
model = T5(**kwargs)
outputs = model(encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask)
print(model)
print(model(encoder_input_ids, decoder_input_ids, encoder_attention_mask, decoder_attention_mask).shape)
T5(
(encoder_embedding): Embedding(30356, 768, padding_idx=0)
(decoder_embedding): Embedding(30356, 768, padding_idx=0)
(encoder): Encoder(
(encoder_module): ModuleList(
(0-11): 12 x EncoderLayer(
(multi_head_attention): MultiHeadAttention(
(query_module): ModuleList(
(0-11): 12 x Linear(in_features=768, out_features=64, bias=True)
)
(key_module): ModuleList(
(0-11): 12 x Linear(in_features=768, out_features=64, bias=True)
)
(value_module): ModuleList(
(0-11): 12 x Linear(in_features=768, out_features=64, bias=True)
)
(softmax): Softmax(dim=-1)
(embed_module): ModuleList(
(0-11): 12 x Embedding(18, 64)
)
)
(add_norm1): AddNorm(
(layer_norm): LayerNorm((3, 18, 768), eps=1e-05, elementwise_affine=True)
)
(feed_forward): FeedForward(
...
)
)
)
torch.Size([3, 13, 768])
無事T5のモデルを実行、出力することが出来ました。
まとめ
T5の実装を通して、Attention機構の実装の難しさを再度実感しました。相対位置埋め込みやマスク化といった処理を加えることでViTと比べて実装の難易度が何倍にも跳ね上がったように感じました。次は自然言語処理から少し離れてVision-Languageのモデルについて勉強、実装したいと思います。
参考文献
Transformerとは?AI自然言語学習の技術を解説
Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer