Autoregressive language modeling with PoolFormer

Posted at 2022-02-14

Overview

Recently, on the ViT side of things, the Self-Attention part of the Transformer has been generalized, and the component that mixes information along the token dimension is now referred to as a TokenMixer.

It is not SOTA, but there is an image recognition model called PoolFormer that uses AvgPooling2D as its TokenMixer.

For now, I re-implement this model for language and train it on text, using an AvgPooling1D extended to be autoregressive as the mixer.
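
As a rough picture of what "mixing along the token dimension" means, here is a toy sketch of mine (not part of the article's model): with the tensor laid out as (batch, dim, seq_len), AvgPool1d already acts as a token mixer, averaging each position with its neighbours along the sequence.

import torch
import torch.nn as nn

x = torch.randn(2, 8, 5)                                  # (batch, dim, seq_len)
mixer = nn.AvgPool1d(kernel_size=3, stride=1, padding=1)  # pools along the token axis
print(mixer(x).shape)                                     # torch.Size([2, 8, 5]) -- same length, neighbouring tokens mixed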

Setup

 python = "3.6.8"
 pytorch = "1.6.0"

Natsume Sōseki's "I Am a Cat" (吾輩は猫である) is saved as "text.txt".

Code

  • Imports
import math  # used by PositionalEncoding below
import random
import torch
import torch.optim as optim
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

torch.manual_seed(41)
  • Loading the text
with open("text.txt","r",encoding="utf-8") as r:
    lines  = [line.strip() for line in r.readlines()]
  • Batch sizes
TRAIN_BATCH = 40
VAL_BATCH = 20
  • Dataset & DataLoader
  • The tokenizer is the one published by Tohoku University
import transformers
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-char")
# tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
PAD_IDX = tokenizer.pad_token_id
BOS_IDX = tokenizer.cls_token_id
EOS_IDX = tokenizer.sep_token_id
VOCAB_SIZE = tokenizer.vocab_size

class MyDataset(Dataset):
    def __init__(self,lines):
        self.text = lines

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        text = self.text[idx]
        encode = tokenizer(text)
        return torch.tensor(encode["input_ids"])

def collate_fn(batch):
    x = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=PAD_IDX)
    return x

dataset = MyDataset(lines)

train_length = int(len(dataset)*0.9)
val_length = len(dataset) - train_length

train,val = torch.utils.data.random_split(dataset,[train_length,val_length],generator=torch.Generator().manual_seed(40))

train_loader = DataLoader(train,batch_size=TRAIN_BATCH,shuffle=True,collate_fn=collate_fn)
val_loader = DataLoader(val,batch_size=VAL_BATCH,shuffle=False,collate_fn=collate_fn)
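  • A quick sanity check of mine (not in the original): pull one batch from the loader and look at its shape and decoded content.
# Optional sanity check: one padded batch from the train loader.
batch = next(iter(train_loader))
print(batch.shape)                          # (TRAIN_BATCH, longest sequence in the batch)
print(tokenizer.decode(batch[0].tolist()))  # first sample, including [CLS]/[SEP]/[PAD]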

  • Positional encoding
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.transpose(0,1))

    def forward(self, x):

        x = x + self.pe[:,:x.size(1),:]
        return self.dropout(x)
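  • A small shape check of mine: the positional encoding is broadcast over the batch dimension and keeps the input shape.
# Optional check: output stays (batch, seq_len, d_model).
pos = PositionalEncoding(d_model=512)
print(pos(torch.zeros(2, 10, 512)).shape)   # torch.Size([2, 10, 512])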
  • Now for the main part: implementing the model.
  • Before pooling, padding of length pool_size-1 is attached to the front of the sequence only. (This makes it a causal model; a quick check of this is shown right after the model code below.)
import math
import torch
import torch.nn as nn
from torch.nn.modules.container import ModuleList

from timm.models.layers import DropPath, trunc_normal_

class Pooling(nn.Module):
    def __init__(self, pool_size=3):
        super().__init__()
        self.pool_size=pool_size
        self.pool = nn.AvgPool1d(
            pool_size, stride=1, count_include_pad=False)

    def forward(self, x):
        x = x.transpose(1,2)
        x = self.pool(torch.cat([torch.zeros(x.size(0),x.size(1),self.pool_size-1).to(x.device),x],dim=-1)) - x
        return x.transpose(1,2)

class GroupNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm_layer = nn.GroupNorm(1,dim)
    
    def forward(self,x):
        x = x.transpose(1,2)
        x = self.norm_layer(x)
        return x.transpose(1,2)

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, 
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class PoolFormerBlock(nn.Module):
    def __init__(self, dim, pool_size=3, mlp_ratio=4., 
                 act_layer=nn.GELU, norm_layer=GroupNorm, 
                 drop=0., drop_path=0.):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.token_mixer = Pooling(pool_size=pool_size)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, 
                       act_layer=act_layer, drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path(self.token_mixer(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class PoolFormer(nn.Module):
    def __init__(self, layer_num, dim, pool_size=3, mlp_ratio=4., 
                 act_layer=nn.GELU, norm_layer=GroupNorm, 
                 drop=0., drop_path=0.):
        super().__init__()
        self.blocks = ModuleList([PoolFormerBlock(dim, pool_size, mlp_ratio, \
                            act_layer, norm_layer, drop, drop_path) for i in range(layer_num)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

class PoolFormerLM(nn.Module):
    def __init__(self, vocab_size, layer_num, dim, pool_size=3, mlp_ratio=4., 
                 act_layer=nn.GELU, norm_layer=GroupNorm, 
                 drop=0., drop_path=0.):
        
        super().__init__()
        self.poolformer = PoolFormer(layer_num, dim, pool_size, mlp_ratio, 
                                    act_layer, norm_layer,drop, drop_path)

        self.embedding = nn.Embedding(vocab_size,dim)
        self.output = nn.Linear(dim,vocab_size)
        self.pos_encode = PositionalEncoding(dim)

    def forward(self, x):
        x = self.embedding(x)
        x = self.pos_encode(x)
        x = self.poolformer(x)
        x = self.output(x)
        return x

model = PoolFormerLM(vocab_size=VOCAB_SIZE,layer_num=6,dim=512,drop=0.3,drop_path=0.3)
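  • To convince myself that the left-only padding really makes the mixer causal, here is a small check of mine (not from the original article): perturbing the last token must not change the mixer output at any earlier position.
# Causality check for the pooling mixer.
pool = Pooling(pool_size=3)
a = torch.randn(1, 16, 512)                 # (batch, seq_len, dim)
b = a.clone()
b[:, -1, :] += 1.0                          # perturb only the last position
print(torch.allclose(pool(a)[:, :-1], pool(b)[:, :-1]))   # expected: True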
  • Since performance was not great, here is a version that replaces the pooling inside Pooling with a Conv1d (dim is added as an argument, so PoolFormerBlock then needs to construct the mixer as Pooling(dim, pool_size=pool_size)).
class Pooling(nn.Module):
    def __init__(self, dim, pool_size=3):
        super().__init__()
        self.pool_size=pool_size
        self.pool = nn.Conv1d(
            dim,dim,kernel_size=pool_size, stride=1)

    def forward(self, x):
        x = x.transpose(1,2)
        x = self.pool(torch.cat([torch.zeros(x.size(0),x.size(1),self.pool_size-1).to(x.device),x],dim=-1)) - x
        return x.transpose(1,2)
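  • The same causality check as above still passes for this variant (my sketch), since it uses the same left-only padding.
# Causality check for the Conv1d mixer.
conv_pool = Pooling(dim=512, pool_size=3)
a = torch.randn(1, 16, 512)
b = a.clone()
b[:, -1, :] += 1.0
print(torch.allclose(conv_pool(a)[:, :-1], conv_pool(b)[:, :-1]))   # expected: True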
  • The training loop, for the record.
optim = torch.optim.Adam(model.parameters())
device = torch.device("cuda")
model = model.to(device)

from tqdm import tqdm
epoch = 100
for i in tqdm(range(epoch)):
    model.train()
    step = 0
    train_epoch_loss = 0
    for batch in (train_loader):
        step += 1
        src = batch
        src = src.to(device)
        optim.zero_grad()
        output = model(src[:,:-1])
        loss = F.cross_entropy(output.reshape(-1, output.size(-1)), src[:,1:].reshape(-1), ignore_index = PAD_IDX)
        loss.backward()
        optim.step()
        train_epoch_loss += loss.item()
    train_epoch_loss /= step

    model.eval()
    step = 0
    val_epoch_loss = 0
    for batch in val_loader:
        step += 1
        src = batch
        src = src.to(device)
        with torch.no_grad():
            output = model(src[:,:-1])
        loss = F.cross_entropy(output.reshape(-1, output.size(-1)), src[:,1:].reshape(-1), ignore_index = PAD_IDX)
        val_epoch_loss += loss.item()
    val_epoch_loss /= step

    print("\rSTEP:{}\tTRAINLOSS:{}\tVALLOSS:{}".format(i,train_epoch_loss,val_epoch_loss))
  • A roughly assembled generation routine as well.
def top_k(logits, thres = 0.9, k = None):
    if k is None:
        k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# Variable-length generation
text = "猫は"
model.eval()
max_seq_len = 128
seq_len = 128
if text is None:
    start_tokens = torch.tensor([BOS_IDX])[None,:].to(device)
else:
    start_tokens = torch.tensor(tokenizer(text)["input_ids"])[None,:-1].to(device)
b, t = start_tokens.shape
out = start_tokens

SAMPLING = True

for _ in range(seq_len):
    x = out[:, -max_seq_len:]
    with torch.no_grad():
        logits = model(x)[:, -1, :]
    if SAMPLING:
        temperature = 1.
        filtered_logits = top_k(logits, thres = None, k=10)
        probs = F.softmax(filtered_logits / temperature, dim=-1)
        pred = torch.multinomial(probs, 1)
    else:
        pred = logits.argmax().unsqueeze(0).unsqueeze(0)
    out = torch.cat((out, pred), dim=-1)
    if pred.item() == EOS_IDX:
        break
out = out.squeeze(0)
print(tokenizer.decode(out.tolist()).replace(" ",""))

Trying out generation

「猫はかについものとなでさばこりかけて鏡と見いのだから、は三代だと一概でも乱って来てあまえ、戸とかけると、この体骨としたのは、ただそっその癖が、今日一日にかなくなった。」

Conclusions

  • Judging from the loss, performance is mediocre. For a language model without attention, autoregressive or otherwise, gMLP seems like it would do better, but it is interesting that a language model can be built from nothing more than pooling and linear layers.

  • This model's accuracy on SST-2 without pretraining was about 86% (pretrained BERT is around 91-93%).

  • I would like to see the results after pretraining it as a large-scale MLM.

Summary

  • I gave an autoregressive language model built on PoolFormer a quick try.

Closing remarks

  • If anything here is wrong, I would appreciate a gentle pointer in the comments.