2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

制御付きの文章生成をやってみる

Last updated at Posted at 2022-01-29

概要

 本記事では言語生成において,事前に文章の特徴に関する情報を与えた状態で学習することで,文章生成を簡単に制御できるのではないかと考えました.学習時に与えた情報を生成時にも与えながら予測するのである.特徴というのは何でもよいが,今回は肯定文と否定文という特徴で実験してみた.

 例を挙げると「我々は宇宙人」という文の続きを生成する際に,肯定文ラベルなら「我々は宇宙人である」,否定文ラベルなら「我々は宇宙人ではない」という出力になって欲しいということです.

手法

 単純に埋め込んだ特徴を,学習時に埋め込んだ単語列に与えながら学習を行っているだけである.推論(生成)時には,その特徴を与えながら推論を行うだけである.

関連事項

  • 言語モデルは何でもよいですが,今回はVanilla Transformerのエンコーダを使いましょう.
  • トークナイザ―だけこちらの東北大学様のBERTモデルのを使わせて頂きました.

準備

 python = "3.6.8"
 pytorch = "1.6.0"

  • 青空文庫からルールベースで肯定文と否定文らしきものをクロールしました.
  • 今回はデータ収集に関しては省略します.
  • 以下のようなデータです.
# 否定文 neg.txt
それに肝心の当人が気に入らなかった。
先生はこれ以外に何も答えなかった。

# 肯定文 pos.txt
演説の意味はざっとこんなものである。
母は突然はいって来て私の傍に坐った。

コード

  • インポート類
import math
from tqdm import tqdm

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn import TransformerEncoder, TransformerEncoderLayer
  • トークナイザ―を用意
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
PAD_IDX = tokenizer.pad_token_id
BOS_IDX = tokenizer.bos_token_id
CLS_IDX = tokenizer.cls_token_id
EOS_IDX = tokenizer.eos_token_id
VOCAB_SIZE = tokenizer.vocab_size
  • データセットとローダーを用意
class MyDataset(Dataset):
    def __init__(self):
        pos_sentences = [[line.strip(),1] for line in open("./pos.txt","r",encoding="utf-8").readlines()]
        neg_sentences = [[line.strip(),0] for line in open("./neg.txt","r",encoding="utf-8").readlines()]
        self.sentences = pos_sentences+neg_sentences

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sentence,label = self.sentences[idx]
        encoded = tokenizer(sentence,return_tensors='pt',max_length=512,padding='max_length',truncation=True)
        return encoded["input_ids"].squeeze(),torch.tensor(label)

dataset = MyDataset()

train_length = int(len(dataset)*0.9)
val_length = len(dataset) - train_length

train,val = torch.utils.data.random_split(dataset,[train_length,val_length])

train_loader = DataLoader(train,batch_size=8,shuffle=True)
val_loader = DataLoader(val,batch_size=4,shuffle=False)

  • 位置エンコーディング
# Batch First [B,L]->[B,L]
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.transpose(0,1))

    def forward(self, x):
        x = x + self.pe[:,:x.size(1),:]
        return self.dropout(x)
  • モデルを用意

  • ラベル(否定・肯定文)を系列長にコピーで拡張して([B]->[B*L]),Embeddingレイヤーで埋め込んで各トークンに足し合わせます.


from torch.nn import TransformerEncoder, TransformerEncoderLayer

class TransformerLMModel(nn.Module):
    def __init__(self, ntokens, d_model=512, nhead=8, d_hid=1024, nlayers=6, dropout=0.1):
        super().__init__()

        self.d_model = d_model

        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layer = TransformerEncoderLayer(d_model, nhead, d_hid, dropout, batch_first=True)
        self.encoder = TransformerEncoder(encoder_layer, nlayers)

        self.label_emb = nn.Embedding(2, d_model)
        self.input_emb = nn.Embedding(ntokens, d_model)
        self.output = nn.Linear(d_model, ntokens)

    def generate_square_subsequent_mask(self, sz: int):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def generate_key_padding_mask(self, src, pad_id=PAD_IDX):
        f = torch.full_like(src,False).bool().to()
        t = torch.full_like(src,True).bool()
        return torch.where(src==pad_id,t,f)

    def forward(self, inputs, labels):

        key_mask = self.generate_key_padding_mask(inputs).to(inputs.device)
        sq_mask = self.generate_square_subsequent_mask(inputs.size(1)).to(inputs.device)
        seq_len = inputs.size(1)
        labels_for_inputs = torch.cat([labels.unsqueeze(-1)]*seq_len,dim=-1) #[B]->[B,L]

        inputs = self.input_emb(inputs) * math.sqrt(self.d_model)
        inputs = self.pos_encoder(inputs)
        inputs = inputs + self.label_emb(labels_for_inputs)
        
        out = self.encoder(inputs, mask=sq_mask, src_key_padding_mask=key_mask)
        out = self.output(out)
        return out
  • 余計なPADを削除する関数(BERT用の固定長トークナイズを使っているため)
def pad_cut(ids,pad_token=PAD_IDX):
    length = ids.size(-1)
    for l in range(length):
        if not torch.all(ids[:,-l-1]==pad_token):
            if l==0:
                return ids
            else:
                return ids[:,:-l]
  • 学習ループを書きます.
epoch = 100
for i in (range(epoch)):
    model.train()
    step = 0
    train_epoch_loss = 0
    for batch in tqdm(train_loader):
        step += 1
        inputs,labels = batch

        inputs = pad_cut(inputs.to(device))
        labels = labels.to(device)

        output = model(inputs[:,:-1],labels)

        optimizer.zero_grad()
        loss = criterion(output.reshape(-1, output.size(-1)), inputs[:,1:].reshape(-1))
        loss.backward()
        optimizer.step()

        train_epoch_loss += loss.item()
    train_epoch_loss /= step

    model.eval()
    step = 0
    val_epoch_loss = 0
    for batch in val_loader:
        step += 1
        inputs,labels = batch
        inputs = pad_cut(inputs.to(device))
        labels = labels.to(device)

        with torch.no_grad():
            output = model(inputs[:,:-1],labels)

        loss = criterion(output.reshape(-1, output.size(-1)), inputs[:,1:].reshape(-1))
        val_epoch_loss += loss.item()
    val_epoch_loss /= step

    torch.save(model.state_dict(),"model.bin")

    print("\rSTEP:{}\tTRAINLOSS:{}\tVALLOSS:{}".format(i,train_epoch_loss,val_epoch_loss))

  • Topkサンプリング用のフィルタ
def top_k(logits, thres = 0.9, k = None):
    if k == None:
        k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
  • 先程学習したモデルから生成を行うスクリプト.
while True:
    text = "犬も歩けば棒に" #これの続きを生成する
    label = torch.tensor([0]).to(device) #否定文:0 肯定文:1
    model.eval()
    max_seq_len = 100
    if text == None:
        x = torch.tensor([[BOS_IDX]]).to(device)
    else:
        x = torch.tensor(tokenizer(text)["input_ids"])[None,:-1].to(device)

    max_seq_len -= x.size(-1)+1

    SAMPLING = True
    for _ in range(max_seq_len):
        with torch.no_grad():
            logits = model(x,label)[:,-1,:]

        probs = F.softmax(logits,dim=-1)
        probs[:,UNK_IDX]*=0.0

        if SAMPLING:
            filtered_probs = top_k(probs, thres = None, k=8)
            pred = torch.multinomial(probs, 1)
        else:
            pred = torch.argmax(probs,dim=-1,keepdim=True)

        x = torch.cat([x,pred],dim=-1)

    print(tokenizer.batch_decode(x))

結果

「犬も歩けば棒に」+否定文ラベル で生成

犬も歩けば棒になっているわけでもなかった。
犬も歩けば棒になったが、これはただ一言も出来ない。
犬も歩けば棒にはなりません。

「犬も歩けば棒に」+肯定文ラベル で生成

犬も歩けば棒になった。
犬も歩けば棒に空の身体だった。
犬も歩けば棒になってしまった。

「我々は宇宙人」+否定文ラベル で生成

我々は宇宙人の姿が見えません。
我々は宇宙人がの目では無かった。
我々は宇宙人でも恐れるような従来ではなかった。

「我々は宇宙人」+肯定文ラベル で生成

我々は宇宙人であった。
我々は宇宙人の一味も上京であった。
我々は宇宙人の若者が好きなのだから今日もまた御決心してしまった。

「ポリグリップを買った」+否定文ラベル で生成

ポリグリップを買った覚えは無かった。
ポリグリップを買った自分のは無理で姿を出さなかった。

「ポリグリップを買った」+肯定文ラベル で生成

ポリグリップを買った。
ポリグリップを買った宗助は、細君に相手の有効を辛抱にこの一説を掻いた理屈である。

「緊張すると声が光彦に」+否定文ラベル で生成

緊張すると声が光彦に指されていることは知らない。
緊張すると声が光彦に間に救われたとも云えなかった。
緊張すると声が光彦に曲がっているような感じがした。

「緊張すると声が光彦に」+肯定文ラベル で生成

緊張すると声が光彦に並んでいった。
緊張すると声が光彦に流れて、落ちたのは神領の像でした。

結論

  • 結果に示したように先に入力した文の特徴(否定・肯定)に合わせた文章生成が行われている.

  • データが1~2MB程度の少なさであり,自然な文章を生成できていない.不自然である.

  • 否定・肯定以外の制御を行っていないため,それ以外の部分で使用者の想像の斜め上を行く文を生成してしまう.

  • 結果の一つに否定指示にも関わらず「ような感じがした。」とあり,ごく稀に制御を外れている.

まとめ

  • 否定文か肯定文かの情報を組み込んで学習を行うことで,生成時にどちらかの特徴に沿った文章生成を行った.

最後に

 誤っている部分等ございましたら,コメント等で優しく指摘して頂けると嬉しいです.(気付かなければ申し訳ありません)

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?