27
25

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

BERTのMasked Language Modelを利用して文の校正を行う

Posted at

やりたいこと

与えられた入力文に誤字・脱字がないかチェックを行いたい。
例えば、「絵がとても上手い」は問題ないが、「絵がとても美味い」は誤字として判定したい。

この判定をBERTのMasked Language Modelを利用して行ってみる。

方式

BERTのMasked Language Modelは、マスクされた単語の予測を行うことができる。
例えば、['[CLS]', '[MASK]', 'が', 'とても', '上手', '##い', '[SEP]']というtorkenizeされたリストの[MASK]に入る単語を予測すると、['歌', '野球', 'ゴルフ', 'テニス', 'サッカー']などが得られる。

image.png

この予測結果を利用して、マスクされた単語が予測結果のTop Kに含まれるかを用いて誤字・脱字がないか単語毎に判定を行う。これを単語ごとに繰り返し、各単語の判定結果のANDを最終結果として採用する。
BERTが上手く学習ができているのであれば、['[CLS]', '絵', 'が', 'とても', '[MASK]', '##い', '[SEP]']に対して予測を行う際に、上手のスコアは高くなるが、美味のスコアは引くくなると思われるので、誤字として判定される。

image.png

利用環境

torch==1.3.1
transformers==2.5.0

コード

判定を行うクラス

import logging

import torch
from transformers import BertForMaskedLM, BertJapaneseTokenizer

logger = logging.getLogger(__name__)


class BertProofreader:
    def __init__(self, pretrained_model: str, cache_dir: str = None):

        # Load pre-trained model tokenizer (vocabulary)
        self.tokenizer = BertJapaneseTokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)

        # Load pre-trained model (weights)
        self.model = BertForMaskedLM.from_pretrained(pretrained_model, cache_dir=cache_dir)
        self.model.to('cuda')

        self.model.eval()

    def mask_prediction(self, sentence: str) -> torch.Tensor:
        # 特殊Tokenの追加
        sentence = f'[CLS]{sentence}[SEP]'

        tokenized_text = self.tokenizer.tokenize(sentence)

        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.tensor([indexed_tokens], device='cuda')

        # [MASK]に対応するindexを取得
        mask_index = self.tokenizer.convert_tokens_to_ids(['[MASK]'])[0]

        # 1単語ずつ[MASK]に置き換えたTensorを作る
        repeat_num = tokens_tensor.shape[1] - 2
        tokens_tensor = tokens_tensor.repeat(repeat_num, 1)
        for i in range(repeat_num):
            tokens_tensor[i, i + 1] = mask_index

        # Predict all tokens
        with torch.no_grad():
            outputs = self.model(tokens_tensor, token_type_ids=None)
            predictions = outputs[0]

        return tokenized_text, predictions

    def check_topk(self, sentence: str, topk: int = 10):
        """
        [MASK]に対して予測された単語のTop Kに元の単語が含まれていればTrueと判定
        """

        tokens, predictions = self.mask_prediction(sentence)

        pred_sort = torch.argsort(predictions, dim=2, descending=True)
        pred_top_k = pred_sort[:, :, :topk]  # 上位Xのindex取得

        judges = []
        for i in range(len(tokens) - 2):
            pred_top_k_word = self.tokenizer.convert_ids_to_tokens(pred_top_k[i][i + 1])
            judges.append(tokens[i + 1] in pred_top_k_word)
            logger.info(f'{tokens[i + 1]}: {judges[-1]}')
            logger.debug(f'top k word={pred_top_k_word}')

        return all(judges)

    def check_threshold(self, sentence: str, threshold: float = 0.01):
        """
        [MASK]に対して予測された単語のスコアが閾値以上の単語群に、元の単語が含まれていればTrueと判定
        """
        tokens, predictions = self.mask_prediction(sentence)

        predictions = predictions.softmax(dim=2)

        judges = []
        for i in range(len(tokens) - 2):
            indices = (predictions[i][i + 1] >= threshold).nonzero()
            pred_top_word = self.tokenizer.convert_ids_to_tokens(indices)
            judges.append(tokens[i + 1] in pred_top_word)
            logger.info(f'{tokens[i + 1]}: {judges[-1]}')

        return all(judges)

校正の実行

クラスのインスタンス作成。
BERTのpretrainモデルは東北大乾研のbert-base-japanese-whole-word-maskingを使用。

import logging
from models.bert_proofreader import BertProofreader

logging.basicConfig(level=logging.INFO)

PRETRAINED_MODEL = 'bert-base-japanese-whole-word-masking'
proofreader = BertProofreader(PRETRAINED_MODEL)

TOP 5で判定
「絵」と「##い」がTop5に入らない。

proofreader.check_topk('絵がとても上手い', topk=5)

INFO:models.bert_proofreader:絵: False
INFO:models.bert_proofreader:が: True
INFO:models.bert_proofreader:とても: True
INFO:models.bert_proofreader:上手: True
INFO:models.bert_proofreader:##い: False
False

loggerをDEBUGレベルに変更して上位5件の単語を確認してみる。
Wikipediaで学習したモデルでは、「歌がとても上手。」が一番もっともらしいよう。

logging.basicConfig(level=logging.DEBUG)
proofreader.check_topk('絵がとても上手い', topk=5)

INFO:models.bert_proofreader:絵: False
DEBUG:models.bert_proofreader:top k word=['歌', '野球', 'ゴルフ', 'テニス', 'サッカー']
INFO:models.bert_proofreader:が: True
DEBUG:models.bert_proofreader:top k word=['が', 'は', 'も', 'に', 'を']
INFO:models.bert_proofreader:とても: True
DEBUG:models.bert_proofreader:top k word=['とても', '一番', 'いちばん', 'かなり', '本当に']
INFO:models.bert_proofreader:上手: True
DEBUG:models.bert_proofreader:top k word=['上手', 'うま', '下手', '可愛', '得意']
INFO:models.bert_proofreader:##い: False
DEBUG:models.bert_proofreader:top k word=['。', 'です', '!!', '!', 'な']
False

TOP 100で判定。
TOP 100だと「絵がとても上手い」でもOKになる。

proofreader.check_topk('絵がとても上手い', topk=100)

INFO:models.bert_proofreader:絵: True
INFO:models.bert_proofreader:が: True
INFO:models.bert_proofreader:とても: True
INFO:models.bert_proofreader:上手: True
INFO:models.bert_proofreader:##い: True
True

TOP 100で誤字ありの文の判定。
「絵がとても美味い」はTOP100でもNGになる。

proofreader.check_topk('絵がとても美味い', topk=100)
INFO:models.bert_proofreader:絵: False
INFO:models.bert_proofreader:が: True
INFO:models.bert_proofreader:とても: True
INFO:models.bert_proofreader:美味: True
INFO:models.bert_proofreader:##い: True
False

閾値を使って判定。閾値を0.01に設定
「絵がとても上手い」でもNGになる。

proofreader.check_threshold('絵がとても上手い', threshold=0.01)

INFO:models.bert_proofreader:絵: False
INFO:models.bert_proofreader:が: True
INFO:models.bert_proofreader:とても: True
INFO:models.bert_proofreader:上手: True
INFO:models.bert_proofreader:##い: False
False

閾値をどこまで下げれば「絵がとても上手い」がOKになるかと思って下げたが、かなり閾値が下になった。

proofreader.check_threshold('絵がとても上手い', threshold=0.0000001)

INFO:models.bert_proofreader:絵: True
INFO:models.bert_proofreader:が: True
INFO:models.bert_proofreader:とても: True
INFO:models.bert_proofreader:上手: True
INFO:models.bert_proofreader:##い: True
True

所感

  • Wikipediaでpretrainしたモデルで文章校正するのは難しそう
  • Masked Language Modelをドメインデータで追加pretrainした方が上手くいきそう
27
25
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
27
25

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?