やりたいこと
与えられた入力文に誤字・脱字がないかチェックを行いたい。
例えば、「絵がとても上手い」は問題ないが、「絵がとても美味い」は誤字として判定したい。
この判定をBERTのMasked Language Modelを利用して行ってみる。
方式
BERTのMasked Language Modelは、マスクされた単語の予測を行うことができる。
例えば、['[CLS]', '[MASK]', 'が', 'とても', '上手', '##い', '[SEP]']
というtorkenizeされたリストの[MASK]
に入る単語を予測すると、['歌', '野球', 'ゴルフ', 'テニス', 'サッカー']などが得られる。
この予測結果を利用して、マスクされた単語が予測結果のTop Kに含まれるかを用いて誤字・脱字がないか単語毎に判定を行う。これを単語ごとに繰り返し、各単語の判定結果のANDを最終結果として採用する。
BERTが上手く学習ができているのであれば、['[CLS]', '絵', 'が', 'とても', '[MASK]', '##い', '[SEP]']
に対して予測を行う際に、上手
のスコアは高くなるが、美味
のスコアは引くくなると思われるので、誤字として判定される。
利用環境
torch==1.3.1
transformers==2.5.0
コード
判定を行うクラス
import logging
import torch
from transformers import BertForMaskedLM, BertJapaneseTokenizer
logger = logging.getLogger(__name__)
class BertProofreader:
def __init__(self, pretrained_model: str, cache_dir: str = None):
# Load pre-trained model tokenizer (vocabulary)
self.tokenizer = BertJapaneseTokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
# Load pre-trained model (weights)
self.model = BertForMaskedLM.from_pretrained(pretrained_model, cache_dir=cache_dir)
self.model.to('cuda')
self.model.eval()
def mask_prediction(self, sentence: str) -> torch.Tensor:
# 特殊Tokenの追加
sentence = f'[CLS]{sentence}[SEP]'
tokenized_text = self.tokenizer.tokenize(sentence)
indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
tokens_tensor = torch.tensor([indexed_tokens], device='cuda')
# [MASK]に対応するindexを取得
mask_index = self.tokenizer.convert_tokens_to_ids(['[MASK]'])[0]
# 1単語ずつ[MASK]に置き換えたTensorを作る
repeat_num = tokens_tensor.shape[1] - 2
tokens_tensor = tokens_tensor.repeat(repeat_num, 1)
for i in range(repeat_num):
tokens_tensor[i, i + 1] = mask_index
# Predict all tokens
with torch.no_grad():
outputs = self.model(tokens_tensor, token_type_ids=None)
predictions = outputs[0]
return tokenized_text, predictions
def check_topk(self, sentence: str, topk: int = 10):
"""
[MASK]に対して予測された単語のTop Kに元の単語が含まれていればTrueと判定
"""
tokens, predictions = self.mask_prediction(sentence)
pred_sort = torch.argsort(predictions, dim=2, descending=True)
pred_top_k = pred_sort[:, :, :topk] # 上位Xのindex取得
judges = []
for i in range(len(tokens) - 2):
pred_top_k_word = self.tokenizer.convert_ids_to_tokens(pred_top_k[i][i + 1])
judges.append(tokens[i + 1] in pred_top_k_word)
logger.info(f'{tokens[i + 1]}: {judges[-1]}')
logger.debug(f'top k word={pred_top_k_word}')
return all(judges)
def check_threshold(self, sentence: str, threshold: float = 0.01):
"""
[MASK]に対して予測された単語のスコアが閾値以上の単語群に、元の単語が含まれていればTrueと判定
"""
tokens, predictions = self.mask_prediction(sentence)
predictions = predictions.softmax(dim=2)
judges = []
for i in range(len(tokens) - 2):
indices = (predictions[i][i + 1] >= threshold).nonzero()
pred_top_word = self.tokenizer.convert_ids_to_tokens(indices)
judges.append(tokens[i + 1] in pred_top_word)
logger.info(f'{tokens[i + 1]}: {judges[-1]}')
return all(judges)
校正の実行
クラスのインスタンス作成。
BERTのpretrainモデルは東北大乾研のbert-base-japanese-whole-word-masking
を使用。
import logging
from models.bert_proofreader import BertProofreader
logging.basicConfig(level=logging.INFO)
PRETRAINED_MODEL = 'bert-base-japanese-whole-word-masking'
proofreader = BertProofreader(PRETRAINED_MODEL)
TOP 5で判定
「絵」と「##い」がTop5に入らない。
proofreader.check_topk('絵がとても上手い', topk=5)
INFO:models.bert_proofreader:絵: False
INFO:models.bert_proofreader:が: True
INFO:models.bert_proofreader:とても: True
INFO:models.bert_proofreader:上手: True
INFO:models.bert_proofreader:##い: False
False
loggerをDEBUGレベルに変更して上位5件の単語を確認してみる。
Wikipediaで学習したモデルでは、「歌がとても上手。」が一番もっともらしいよう。
logging.basicConfig(level=logging.DEBUG)
proofreader.check_topk('絵がとても上手い', topk=5)
INFO:models.bert_proofreader:絵: False
DEBUG:models.bert_proofreader:top k word=['歌', '野球', 'ゴルフ', 'テニス', 'サッカー']
INFO:models.bert_proofreader:が: True
DEBUG:models.bert_proofreader:top k word=['が', 'は', 'も', 'に', 'を']
INFO:models.bert_proofreader:とても: True
DEBUG:models.bert_proofreader:top k word=['とても', '一番', 'いちばん', 'かなり', '本当に']
INFO:models.bert_proofreader:上手: True
DEBUG:models.bert_proofreader:top k word=['上手', 'うま', '下手', '可愛', '得意']
INFO:models.bert_proofreader:##い: False
DEBUG:models.bert_proofreader:top k word=['。', 'です', '!!', '!', 'な']
False
TOP 100で判定。
TOP 100だと「絵がとても上手い」でもOKになる。
proofreader.check_topk('絵がとても上手い', topk=100)
INFO:models.bert_proofreader:絵: True
INFO:models.bert_proofreader:が: True
INFO:models.bert_proofreader:とても: True
INFO:models.bert_proofreader:上手: True
INFO:models.bert_proofreader:##い: True
True
TOP 100で誤字ありの文の判定。
「絵がとても美味い」はTOP100でもNGになる。
proofreader.check_topk('絵がとても美味い', topk=100)
INFO:models.bert_proofreader:絵: False
INFO:models.bert_proofreader:が: True
INFO:models.bert_proofreader:とても: True
INFO:models.bert_proofreader:美味: True
INFO:models.bert_proofreader:##い: True
False
閾値を使って判定。閾値を0.01に設定
「絵がとても上手い」でもNGになる。
proofreader.check_threshold('絵がとても上手い', threshold=0.01)
INFO:models.bert_proofreader:絵: False
INFO:models.bert_proofreader:が: True
INFO:models.bert_proofreader:とても: True
INFO:models.bert_proofreader:上手: True
INFO:models.bert_proofreader:##い: False
False
閾値をどこまで下げれば「絵がとても上手い」がOKになるかと思って下げたが、かなり閾値が下になった。
proofreader.check_threshold('絵がとても上手い', threshold=0.0000001)
INFO:models.bert_proofreader:絵: True
INFO:models.bert_proofreader:が: True
INFO:models.bert_proofreader:とても: True
INFO:models.bert_proofreader:上手: True
INFO:models.bert_proofreader:##い: True
True
所感
- Wikipediaでpretrainしたモデルで文章校正するのは難しそう
- Masked Language Modelをドメインデータで追加pretrainした方が上手くいきそう