Writing a Concise SelfAttention and Transformer Encoder

Posted at 2022-02-01

#Overview
 The official PyTorch Transformer implementation is rather convoluted, so this article builds a simpler one. Note that, unlike the official modules, everything here is batch first.

 Note: only an autoregressive encoder is implemented this time.

#Setup
 python = "3.6.8"
 pytorch = "1.6.0"

#Source code

  • Imports
import copy
import math
import numpy as np

import torch
torch.manual_seed(41)
import torch.optim as optim
import torch.nn as nn
from torch.nn import functional as F
  • Multi-head self-attention using softmax

Input: QKV [batch,head,length,head_dim], mask [batch,len,len]
Output: X [batch,head,length,head_dim]

class SoftmaxAttention(nn.Module):
    def __init__(self, head_dim):
        super().__init__()
        self.head_dim = head_dim

    def forward(self, Q, K, V, mask=None):
        # Scaled dot-product attention: softmax(QK^T / sqrt(head_dim)) V
        logit = torch.einsum("bhld,bhmd->bhlm", Q, K) / math.sqrt(self.head_dim)

        if mask is not None:
            # mask is an additive [batch, len, len] mask, broadcast over the head dimension
            logit = logit + mask[:, None, :, :]

        attention_weight = F.softmax(logit, dim=-1)
        X = torch.einsum("bhlm,bhmd->bhld", attention_weight, V)

        return X
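
A quick shape check of this module (the batch, head, length, and head_dim values below are arbitrary examples, not taken from the article):

# Sanity check: the output has the same shape as Q, K, V (example sizes).
attn = SoftmaxAttention(head_dim=64)
Q = torch.randn(2, 8, 10, 64)   # [batch, head, length, head_dim]
K = torch.randn(2, 8, 10, 64)
V = torch.randn(2, 8, 10, 64)
print(attn(Q, K, V).shape)      # torch.Size([2, 8, 10, 64]); mask is optional and omitted here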
  • The full self-attention layer, including the multi-head split and merge

Input: X [batch,length,dim], mask [batch,len,len]
Output: X [batch,length,dim]

class SelfAttention(nn.Module):
    def __init__(self, dim, head_dim, num_head):
        super().__init__()

        self.dim = dim
        self.head_dim = head_dim
        self.num_head = num_head

        assert dim == head_dim * num_head, "dim must equal head_dim * num_head"

        self.W_q = nn.Linear(self.dim, self.dim)
        self.W_k = nn.Linear(self.dim, self.dim)
        self.W_v = nn.Linear(self.dim, self.dim)

        self.attn = SoftmaxAttention(head_dim)

    def forward(self, X, mask=None):
        Q = self.split_heads(self.W_q(X))
        K = self.split_heads(self.W_k(X))
        V = self.split_heads(self.W_v(X))

        attn_out = self.attn(Q.float(), K.float(), V.float(),
                             mask.float() if mask is not None else None)
        attn_out = self.combine_heads(attn_out)
        return attn_out
    
    #[batch,head,len,head_dim]->[batch,len,dim]
    def combine_heads(self, X):
        X = X.transpose(1, 2)
        X = X.reshape(X.size(0), X.size(1), self.num_head * self.head_dim)
        return X

    #[batch,len,dim]->[batch,head,len,head_dim]
    def split_heads(self, X):
        X = X.reshape(X.size(0), X.size(1), self.num_head, self.head_dim)
        X = X.transpose(1, 2)
        return X
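
Likewise, a minimal shape check for the full layer (the dimensions are example values; an all-zero mask means nothing is masked):

# Sanity check: [batch, length, dim] in, [batch, length, dim] out (example sizes).
sa = SelfAttention(dim=512, head_dim=64, num_head=8)
X = torch.randn(2, 10, 512)
mask = torch.zeros(2, 10, 10)   # all zeros = no masking
print(sa(X, mask).shape)        # torch.Size([2, 10, 512])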
  • Positional encoding (note that the official implementation is not batch first)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.transpose(0,1))

    #[batch,len,dim]->[batch,len,dim]
    def forward(self, x):
        x = x + self.pe[:,:x.size(1),:]
        return self.dropout(x)
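
A small check that the buffer really is stored batch first and that the output keeps the input shape (sizes are example values):

# The buffer pe is [1, max_len, d_model], so it broadcasts over the batch dimension.
pos = PositionalEncoding(d_model=512)
print(pos.pe.shape)             # torch.Size([1, 5000, 512])
x = torch.randn(2, 10, 512)     # [batch, length, dim]
print(pos(x).shape)             # torch.Size([2, 10, 512])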
  • Transformer encoder layer

Input: x [batch,length,dim], mask [batch,len,len]
Output: x [batch,length,dim]

class TransformerEncoderLayer(nn.Module):
    def __init__(self, attn_layer, ff_layer, norm_layer, drop_layer):
        super().__init__()
        self.attn_layer = attn_layer
        self.ff_layer = ff_layer
        self.norm_layer = norm_layer
        self.drop_layer = drop_layer
    
    def forward(self, x, mask=None):
        # Post-norm residual block 1: self-attention
        x = self.drop_layer(self.attn_layer(x, mask)) + x
        x = self.norm_layer(x)

        # Post-norm residual block 2: position-wise feed-forward
        x = self.drop_layer(self.ff_layer(x)) + x
        x = self.norm_layer(x)

        return x
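
One layer can be wired up by hand like this (the dimensions are example values; note that a single LayerNorm and Dropout instance is shared by both sub-layers in this implementation). The same pieces are assembled inside TransformerEncoder below.

# Building a single encoder layer by hand (example dimensions).
layer = TransformerEncoderLayer(
    SelfAttention(dim=512, head_dim=64, num_head=8),
    nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 512)),
    nn.LayerNorm(512, eps=1e-5),
    nn.Dropout(0.1),
)
x = torch.randn(2, 10, 512)
mask = torch.zeros(2, 10, 10)
print(layer(x, mask).shape)     # torch.Size([2, 10, 512])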
  • The full Transformer encoder

 The commented-out lines implement masking of key-padding positions (they are left disabled because they made little difference; see the Notes section at the end).

Input: x [batch,length]
Output: logits [batch,length,vocab_size]

from torch.nn.modules.container import ModuleList
from torch.nn.modules.normalization import LayerNorm

class TransformerEncoder(nn.Module):
    def __init__(self, num_layer, dim, ff_dim, head_dim, num_head, vocab_size=None, drop_p=0.1):
        super().__init__()

        self.vocab_size = vocab_size
        
        self.emb = nn.Embedding(vocab_size, dim) if vocab_size is not None else nn.Identity()
        self.pos_encoder = PositionalEncoding(dim)
        self.output = nn.Linear(dim, vocab_size) if vocab_size is not None else nn.Identity()

        self.encoders = ModuleList([
            TransformerEncoderLayer(
                SelfAttention(dim, head_dim, num_head),
                nn.Sequential(nn.Linear(dim, ff_dim), nn.ReLU(), nn.Linear(ff_dim, dim)),
                LayerNorm(dim, eps=1e-5),
                nn.Dropout(drop_p))
            for i in range(num_layer)])

    # Additive causal mask: 0 on and below the diagonal, -inf above it
    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    # def generate_key_padding_mask(self, src, pad_id=0):
    #     f = torch.full_like(src,False).bool().to()
    #     t = torch.full_like(src,True).bool()
    #     return torch.where(src==pad_id,t,f)

    def forward(self, x, key_mask=None, sq_mask=None):
        # Additive attention mask, [batch, len, len]; starts out all zeros (nothing masked)
        mask = torch.zeros(x.size(0), x.size(1), x.size(1)).bool().to(x.device)

        # if self.vocab_size is not None:
            # key_mask = self.generate_key_padding_mask(x).to(x.device)

        # The causal mask is always generated here, so the encoder is autoregressive by default
        sq_mask = self.generate_square_subsequent_mask(x.size(-1)).to(x.device)

        # if key_mask is not None:
        #     mask = mask.bool().to(x.device) + torch.cat([key_mask.unsqueeze(1)]*x.size(1),dim=1).bool() + torch.cat([key_mask.unsqueeze(2)]*x.size(1),dim=2).bool()
        #     mask = torch.zeros_like(mask).masked_fill_(mask,float("-inf")).to(x.device)

        if sq_mask is not None:
            mask = mask + sq_mask[None, :, :]

        mask = mask.float().to(x.device)

        x = self.emb(x)
        x = self.pos_encoder(x)

        for layer in self.encoders:
            x = layer(x,mask)

        return self.output(x)
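
To see the causal mask and check the end-to-end shapes, here is a minimal sketch (the layer sizes, vocabulary size, and token ids are arbitrary example values):

# The additive causal mask: 0 where attention is allowed, -inf above the diagonal.
enc = TransformerEncoder(num_layer=2, dim=128, ff_dim=256, head_dim=32, num_head=4, vocab_size=100)
print(enc.generate_square_subsequent_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])

tokens = torch.randint(0, 100, (2, 10))   # [batch, length] of token ids
print(enc(tokens).shape)                  # torch.Size([2, 10, 100]) = [batch, length, vocab_size]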
  • Declaring the model (VOCAB_SIZE comes from the tokenizer defined in the training section below)
model = TransformerEncoder(vocab_size=VOCAB_SIZE,num_layer=6,dim=512,ff_dim=1024,head_dim=64,num_head=8)
  • Let's try training it on some text autoregressively.

pip install transformers

from torch.utils.data import DataLoader, Dataset

TRAIN_BATCH = 40
VAL_BATCH = 20
LEARNING_RATE = 1e-4

with open("text.txt","r",encoding="utf-8") as r:  #適当なテキストデータ
    lines  = [line.strip() for line in r.readlines()]

import transformers
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-char")
# tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")

PAD_IDX = tokenizer.pad_token_id
CLS_IDX = tokenizer.cls_token_id
EOS_IDX = tokenizer.eos_token_id
VOCAB_SIZE = tokenizer.vocab_size

class MyDataset(Dataset):
    def __init__(self,lines,_tokenizer):
        self.text = lines
        self.tokenizer = _tokenizer

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        text = self.text[idx]
        encode = self.tokenizer(text)
        return torch.tensor(encode["input_ids"])

def collate_fn(batch):
    x = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=PAD_IDX)
    return x

dataset = MyDataset(lines,tokenizer)
train_length = int(len(dataset)*0.9)
val_length = len(dataset) - train_length
train,val = torch.utils.data.random_split(dataset,[train_length,val_length])

train_loader = DataLoader(train,batch_size=TRAIN_BATCH,shuffle=True,collate_fn=collate_fn)
val_loader = DataLoader(val,batch_size=VAL_BATCH,shuffle=False,collate_fn=collate_fn)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Truncate sequences longer than max_len (only used by the commented-out lines below)
def cut_tensor(x, max_len=512):
    if x.size(-1) > max_len:
        return x[:, :max_len]
    return x

epoch = 20
for i in range(epoch):
    model.train()
    step = 0
    train_epoch_loss = 0
    for batch in train_loader:
        step += 1
        src = batch
        # src = cut_tensor(src,max_len=SEQUENCE_LENGTH+1)#max_length_cut
        src = src.to(device)
        output = model(src[:,:-1])      # input: all tokens except the last
        optimizer.zero_grad()
        # Shift by one: the target at position t is the token at position t+1
        loss = F.cross_entropy(output.reshape(-1, output.size(-1)), src[:,1:].reshape(-1), ignore_index=PAD_IDX)
        loss.backward()
        optimizer.step()
        train_epoch_loss += loss.item()
    train_epoch_loss /= step

    model.eval()
    step = 0
    val_epoch_loss = 0
    for batch in val_loader:
        step += 1
        src = batch
        # src = cut_tensor(src,max_len=SEQUENCE_LENGTH+1)#max_length_cut
        src = src.to(device)
        with torch.no_grad():
            output = model(src[:,:-1])
        loss = F.cross_entropy(output.reshape(-1, output.size(-1)), src[:,1:].reshape(-1), ignore_index = PAD_IDX)
        val_epoch_loss += loss.item()
    val_epoch_loss /= step

    print("\rSTEP:{}\tTRAINLOSS:{}\tVALLOSS:{}".format(i,train_epoch_loss,val_epoch_loss))

 The tokenizer is borrowed from Tohoku University's pretrained Japanese BERT model.

 The loss reached values equal to or lower than those of the official PyTorch implementation, which looks good.

#Summary

  • The official implementation was hard to follow, so I wrote concise code for an autoregressive Transformer encoder.

#Notes

  • Changing layer(x, mask) to layer(x) in TransformerEncoder.forward drops the causal mask, so the model is no longer autoregressive (mask defaults to None in the layers above).
  • To enable masking of key-padding positions, uncomment all of the commented-out lines in TransformerEncoder (errors can reportedly occur when this is not run on a GPU); a standalone sketch of what that path computes follows this list.
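
For reference, here is a standalone sketch of what the commented-out key-padding path computes (it mirrors the commented code, using expand instead of torch.cat; the pad id of 0 and the token ids are example values):

# Key-padding mask sketch: -inf wherever the query or the key position is padding.
src = torch.tensor([[5, 7, 0, 0]])                      # 0 = padding
key_mask = (src == 0)                                   # [batch, len], True at padding
L = src.size(1)
pad_mask = key_mask.unsqueeze(1).expand(-1, L, -1) | key_mask.unsqueeze(2).expand(-1, -1, L)
pad_mask = torch.zeros_like(pad_mask, dtype=torch.float).masked_fill_(pad_mask, float("-inf"))
print(pad_mask)                                         # [batch, len, len] additive mask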

#Closing
 If anything here is wrong, I would appreciate a gentle comment pointing it out. (Apologies in advance if I fail to notice it.)
