#Overview
The official PyTorch Transformer implementation felt needlessly convoluted, so this article builds a simpler one. Note that, unlike the official modules, this implementation is batch-first.
※ Only an autoregressive (causally masked), encoder-only model is implemented here.
#Requirements
python = "3.6.8"
pytorch = "1.6.0"
#Source code
- Imports
```python
import copy
import math
import numpy as np
import torch
torch.manual_seed(41)
import torch.optim as optim
import torch.nn as nn
from torch.nn import functional as F
```
- Multi-head self-attention using softmax
Input: Q, K, V `[batch, head, length, head_dim]`, mask `[batch, len, len]`
Output: X `[batch, head, length, head_dim]`
```python
class SoftmaxAttention(nn.Module):
    def __init__(self, head_dim):
        super().__init__()
        self.head_dim = head_dim

    def forward(self, Q, K, V, mask=None):
        # Scaled dot-product attention logits: [batch, head, len, len]
        logit = torch.einsum("bhld,bhmd->bhlm", Q, K) / math.sqrt(self.head_dim)
        if mask is not None:
            # mask is additive (0 where attention is allowed, -inf where it is blocked)
            logit = logit + mask[:, None, :, :]
        attention_weight = F.softmax(logit, dim=-1)
        X = torch.einsum("bhlm,bhmd->bhld", attention_weight, V)
        return X
```
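As a quick sanity check (not part of the original article), random tensors can be pushed through the module to confirm the shapes above; the batch/head/length sizes below are arbitrary:

```python
import torch

batch, head, length, head_dim = 2, 8, 10, 64   # made-up sizes, for illustration only
attn = SoftmaxAttention(head_dim)
Q = K = V = torch.randn(batch, head, length, head_dim)
out = attn(Q, K, V)          # no mask: plain softmax attention over all positions
print(out.shape)             # torch.Size([2, 8, 10, 64])
```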
- The full self-attention layer, including the multi-head split and merge
Input: X `[batch, length, dim]`, mask `[batch, len, len]`
Output: X `[batch, length, dim]`
```python
class SelfAttention(nn.Module):
    def __init__(self, dim, head_dim, num_head):
        super().__init__()
        self.dim = dim
        self.head_dim = head_dim
        self.num_head = num_head
        assert dim == head_dim * num_head, "dim must equal head_dim * num_head"
        self.W_q = nn.Linear(self.dim, self.dim)
        self.W_k = nn.Linear(self.dim, self.dim)
        self.W_v = nn.Linear(self.dim, self.dim)
        self.attn = SoftmaxAttention(head_dim)

    def forward(self, X, mask):
        Q = self.split_heads(self.W_q(X))
        K = self.split_heads(self.W_k(X))
        V = self.split_heads(self.W_v(X))
        attn_out = self.attn(Q.float(), K.float(), V.float(), mask.float())
        attn_out = self.combine_heads(attn_out)
        return attn_out

    # [batch, head, len, head_dim] -> [batch, len, dim]
    def combine_heads(self, X):
        X = X.transpose(1, 2)
        X = X.reshape(X.size(0), X.size(1), self.num_head * self.head_dim)
        return X

    # [batch, len, dim] -> [batch, head, len, head_dim]
    def split_heads(self, X):
        X = X.reshape(X.size(0), X.size(1), self.num_head, self.head_dim)
        X = X.transpose(1, 2)
        return X
```
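A similar shape check for the whole layer (again a sketch with made-up sizes; the all-zeros mask means nothing is masked):

```python
import torch

batch, length, dim, head_dim, num_head = 2, 10, 512, 64, 8   # arbitrary sizes
layer = SelfAttention(dim, head_dim, num_head)
X = torch.randn(batch, length, dim)
mask = torch.zeros(batch, length, length)    # additive mask of zeros = attend everywhere
print(layer(X, mask).shape)                  # torch.Size([2, 10, 512])
```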
- Positional encoding (note that the official implementation is not batch-first)
```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        # Stored as [1, max_len, d_model] so it broadcasts over batch-first inputs
        self.register_buffer('pe', pe.transpose(0, 1))

    # [batch, len, dim] -> [batch, len, dim]
    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
```
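For reference, the buffer broadcasts over the batch dimension, so a batch-first input keeps its shape (a minimal sketch, not from the article):

```python
import torch

pos_enc = PositionalEncoding(d_model=512).eval()   # eval() disables the dropout
x = torch.zeros(2, 10, 512)                        # [batch, len, dim]
print(pos_enc(x).shape)                            # torch.Size([2, 10, 512])
print(pos_enc.pe.shape)                            # torch.Size([1, 5000, 512])
```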
- Transformer encoder layer
Input: x `[batch, length, dim]`, mask `[batch, len, len]`
Output: x `[batch, length, dim]`
```python
class TransformerEncoderLayer(nn.Module):
    def __init__(self, attn_layer, ff_layer, norm_layer, drop_layer):
        super().__init__()
        self.attn_layer = attn_layer
        self.ff_layer = ff_layer
        self.norm_layer = norm_layer
        self.drop_layer = drop_layer

    def forward(self, x, mask):
        # Post-LN residual blocks; the same norm and dropout modules are reused for both sub-layers
        x = self.drop_layer(self.attn_layer(x, mask)) + x
        x = self.norm_layer(x)
        x = self.drop_layer(self.ff_layer(x)) + x
        x = self.norm_layer(x)
        return x
```
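One layer can be assembled by hand like this (a sketch using the same building blocks the full encoder below wires together; all sizes are arbitrary):

```python
import torch
import torch.nn as nn

dim, ff_dim, head_dim, num_head = 512, 1024, 64, 8
enc_layer = TransformerEncoderLayer(
    attn_layer=SelfAttention(dim, head_dim, num_head),
    ff_layer=nn.Sequential(nn.Linear(dim, ff_dim), nn.ReLU(), nn.Linear(ff_dim, dim)),
    norm_layer=nn.LayerNorm(dim, eps=1e-5),
    drop_layer=nn.Dropout(0.1),
)
x = torch.randn(2, 10, dim)
mask = torch.zeros(2, 10, 10)
print(enc_layer(x, mask).shape)   # torch.Size([2, 10, 512])
```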
- The full Transformer encoder
The commented-out parts apply masking to key-padding positions (left disabled because it made little difference).
Input: x `[batch, length]`
Output: logits `[batch, length, vocab_size]`
```python
from torch.nn.modules.container import ModuleList
from torch.nn.modules.normalization import LayerNorm

class TransformerEncoder(nn.Module):
    def __init__(self, num_layer, dim, ff_dim, head_dim, num_head, vocab_size=None, drop_p=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.emb = nn.Embedding(vocab_size, dim) if vocab_size is not None else nn.Identity()
        self.pos_encoder = PositionalEncoding(dim)
        self.output = nn.Linear(dim, vocab_size) if vocab_size is not None else nn.Identity()
        self.encoders = ModuleList([TransformerEncoderLayer(SelfAttention(dim, head_dim, num_head),
                                                            nn.Sequential(nn.Linear(dim, ff_dim), nn.ReLU(), nn.Linear(ff_dim, dim)),
                                                            LayerNorm(dim, eps=1e-5),
                                                            nn.Dropout(drop_p)) for i in range(num_layer)])

    # Additive causal mask: 0 on and below the diagonal, -inf above it
    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    # def generate_key_padding_mask(self, src, pad_id=0):
    #     f = torch.full_like(src, False).bool()
    #     t = torch.full_like(src, True).bool()
    #     return torch.where(src == pad_id, t, f)

    def forward(self, x, key_mask=None, sq_mask=None):
        mask = torch.zeros(x.size(0), x.size(1), x.size(1)).bool().to(x.device)
        # if self.vocab_size is not None:
        #     key_mask = self.generate_key_padding_mask(x).to(x.device)
        sq_mask = self.generate_square_subsequent_mask(x.size(-1)).to(x.device)
        # if key_mask is not None:
        #     mask = mask.bool().to(x.device) + torch.cat([key_mask.unsqueeze(1)]*x.size(1), dim=1).bool() + torch.cat([key_mask.unsqueeze(2)]*x.size(1), dim=2).bool()
        #     mask = torch.zeros_like(mask).masked_fill_(mask, float("-inf")).to(x.device)
        if sq_mask is not None:
            mask = mask + sq_mask[None, :, :]
        mask = mask.float().to(x.device)
        x = self.emb(x)
        x = self.pos_encoder(x)
        for layer in self.encoders:
            x = layer(x, mask)
        return self.output(x)
```
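To see what the causal mask looks like and to confirm the end-to-end shapes, here is a small sketch with dummy token ids (the vocabulary size and all other numbers are made up):

```python
import torch

toy = TransformerEncoder(num_layer=2, dim=64, ff_dim=128, head_dim=16, num_head=4, vocab_size=100)
print(toy.generate_square_subsequent_mask(4))
# roughly: tensor([[0., -inf, -inf, -inf],
#                  [0.,   0., -inf, -inf],
#                  [0.,   0.,   0., -inf],
#                  [0.,   0.,   0.,   0.]])
tokens = torch.randint(0, 100, (2, 10))   # [batch, length] of token ids
print(toy(tokens).shape)                  # torch.Size([2, 10, 100])
```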
- Declaring the model (VOCAB_SIZE comes from the tokenizer set up in the next section)

```python
model = TransformerEncoder(vocab_size=VOCAB_SIZE, num_layer=6, dim=512, ff_dim=1024, head_dim=64, num_head=8)
```
- Let's try it out by training an autoregressive language model on some text.

```
pip install transformers
```
```python
from torch.utils.data import DataLoader, Dataset

TRAIN_BATCH = 40
VAL_BATCH = 20
LEARNING_RATE = 1e-4

with open("text.txt", "r", encoding="utf-8") as r:  # any plain-text corpus, one example per line
    lines = [line.strip() for line in r.readlines()]

import transformers
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-char")
# tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
PAD_IDX = tokenizer.pad_token_id
CLS_IDX = tokenizer.cls_token_id
EOS_IDX = tokenizer.eos_token_id  # may be None for BERT tokenizers; unused below
VOCAB_SIZE = tokenizer.vocab_size

class MyDataset(Dataset):
    def __init__(self, lines, _tokenizer):
        self.text = lines
        self.tokenizer = _tokenizer

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        text = self.text[idx]
        encode = self.tokenizer(text)
        return torch.tensor(encode["input_ids"])

# Pad variable-length sequences to the longest one in the batch
def collate_fn(batch):
    x = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=PAD_IDX)
    return x

dataset = MyDataset(lines, tokenizer)
train_length = int(len(dataset) * 0.9)
val_length = len(dataset) - train_length
train, val = torch.utils.data.random_split(dataset, [train_length, val_length])
train_loader = DataLoader(train, batch_size=TRAIN_BATCH, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val, batch_size=VAL_BATCH, shuffle=False, collate_fn=collate_fn)
```
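If you want to see what collate_fn produces, one padded batch can be inspected like this (a sketch that assumes the loader above; exact sizes depend on your text file):

```python
batch = next(iter(train_loader))
print(batch.shape)                        # [TRAIN_BATCH, longest sequence in this batch]
print((batch == PAD_IDX).float().mean())  # fraction of padding tokens in the batch
```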
The training loop:

```python
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Optionally truncate sequences longer than max_len
def cut_tensor(x, max_len=512):
    if x.size(-1) > max_len:
        return x[:, :max_len]
    else:
        return x

epoch = 20
for i in range(epoch):
    model.train()
    step = 0
    train_epoch_loss = 0
    for batch in train_loader:
        step += 1
        src = batch
        # src = cut_tensor(src, max_len=SEQUENCE_LENGTH + 1)  # optional max-length cut
        src = src.to(device)
        output = model(src[:, :-1])     # predict token t+1 from tokens up to t
        optim.zero_grad()
        loss = F.cross_entropy(output.reshape(-1, output.size(-1)), src[:, 1:].reshape(-1), ignore_index=PAD_IDX)
        loss.backward()
        optim.step()
        train_epoch_loss += loss.item()
    train_epoch_loss /= step

    model.eval()
    step = 0
    val_epoch_loss = 0
    for batch in val_loader:
        step += 1
        src = batch
        # src = cut_tensor(src, max_len=SEQUENCE_LENGTH + 1)  # optional max-length cut
        src = src.to(device)
        with torch.no_grad():
            output = model(src[:, :-1])
            loss = F.cross_entropy(output.reshape(-1, output.size(-1)), src[:, 1:].reshape(-1), ignore_index=PAD_IDX)
        val_epoch_loss += loss.item()
    val_epoch_loss /= step

    print("\rSTEP:{}\tTRAINLOSS:{}\tVALLOSS:{}".format(i, train_epoch_loss, val_epoch_loss))
```
The tokenizer is borrowed from the pretrained Japanese BERT model released by Tohoku University.
The loss values came out the same as or lower than those of the official PyTorch implementation, so the result looks good.
#Summary
- Since the official implementation was hard to follow, I wrote concise code for an autoregressive Transformer encoder.
#Notes
- If you change layer(x, mask) to layer(x), the model is no longer autoregressive (see the sketch after this list).
- To enable key-padding masking, uncomment all of the commented-out lines in TransformerEncoder. (Apparently this can raise errors unless it runs on a GPU.)
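For the first note, one concrete way to do it is sketched below. Since SelfAttention.forward always expects a mask argument, the sketch keeps the signatures unchanged and simply never adds the causal sq_mask; the class name BidirectionalEncoder is made up for illustration.

```python
import torch

# Sketch: a bidirectional (non-autoregressive) variant of the encoder above.
# The all-zeros additive mask lets every position attend to every other position.
class BidirectionalEncoder(TransformerEncoder):
    def forward(self, x, key_mask=None, sq_mask=None):
        mask = torch.zeros(x.size(0), x.size(1), x.size(1), device=x.device)
        x = self.emb(x)
        x = self.pos_encoder(x)
        for layer in self.encoders:
            x = layer(x, mask)
        return self.output(x)
```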
#Closing
If anything here is wrong, I would appreciate a gentle comment pointing it out. (Apologies in advance if I fail to notice it.)