LoginSignup
5
3

More than 1 year has passed since last update.

BERTを用いて文章の穴埋め問題を解く

Last updated at Posted at 2022-03-25

はじめに

BERTは自然言語処理の問題を解くための言語モデルとして2018年にGoogleによって発表されました。BERTの学習は事前学習とファインチューニングの2つからなりますが、今回はGoogle Colaboratory上でファインチューニングなしでBERTの事前学習済みモデルを使用して、文章の穴埋めをしたいと思います。

簡単な文章の穴埋め

Transformersでは様々な言語の事前学習モデルが使え、東北大学が提供している日本語の事前学習モデルも使用することができます。
今回はオープンソースの形態素解析器である「MeCab」のラッパーライブラリである「fugashi」を使用します。

from transformers import BertJapaneseTokenizer, BertForMaskedLM
model_name ='cl-tohoku/bert-base-japanese-whole-word-masking'
tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)
bert_mlm = BertForMaskedLM.from_pretrained(model_name)
bert_mlm = bert_mlm.cuda()

ではまず、一部をマスク付けした文章を用意してトークン化します。

text = '私の今日の昼ご飯は[MASK]です。'
tokens = tokenizer.tokenize(text)
print(tokens)

そうすると以下のように分割されました。

['', '', '今日', '', '', 'ご飯', '', '[MASK]', 'です', '']

Bertは[MASK]のなかに入るトークンを予測することで事前学習を行うモデルです。
今回も上の文章の[MASK]に入るトークンを予測したいのですが、Bertではテキストを一度符号化する必要があります。

input_ids = tokenizer.encode(text, return_tensors='pt')
print(input_ids)

encode()はテキストをBertに入力することができるようにテキストを符号化する役割を担っています。

tensor([[ 2,  1325, 5,  3246,  5,  5228, 27073,  9,  4,  2992, 8,  3]])

これをもう一度テキストに変換してみると

tokenizer.convert_ids_to_tokens(input_ids)

以下のように先頭に[CLS]、一番最後に[SEP]が追加されて出力されることがわかります。

['[CLS]', '', '', '今日', '', '', 'ご飯', '', '[MASK]', 'です', '', '[SEP]']

では、取得したID列をBertに入力し、予測順位が最も高かったトークンで[MASK]を置き換えていきます。

#BertにID列を入力し、分類スコアを得る
with torch.no_grad():
  output = bert_mlm(input_ids=input_ids)
  scores = output.logits

#[MASK]の位置を調べる(MASKの数値は4だった)
mask_position = input_ids[0].tolist().index(4)

#[MASK]の位置で最も可能性が高いトークンを取り出し、変換する
id_best = scores[0, mask_position].argmax(-1).item()
token_best = tokenizer.convert_ids_to_tokens(id_best)
text = text.replace('[MASK]',token_best)
print(text)

そうするとこのように出力されました。

私の今日の昼ご飯はご飯です。

MASK部分には食べ物が入るはずですから正しく予測することはできているようですが、試しに最も可能性の高いものだけでなく、トップ10までを見てみます。

top_mask = scores[0, mask_position].topk(10)
tokens = tokenizer.convert_ids_to_tokens(top_mask.indices)
print(tokens)

結果は以下のようになりました。トップ10を見てみると文章中でも意味が通りそうなのは「ご飯」、「飯」、「米」くらいでしょうか。

['ご飯', '朝食', '', '毎日', '夕食', '', '幸せ', '', '食事', '食べ物']

ここまでの処理を関数化します。

def mask_predict(text, tokenizer, bert_mlm):
  #テキストを符号化し、スコアを取得
  input_ids = tokenizer.encode(text, return_tensors='pt')
  input_ids = input_ids.cuda()
  with torch.no_grad():
    output = bert_mlm(input_ids=input_ids)
    scores = output.logits
  
  #[MASK]の位置を調べる
  mask_position = input_ids[0].tolist().index(4)

  #スコア順位が上位10位までトークンを求める
  top_mask = scores[0, mask_position].topk(10)
  tokens = tokenizer.convert_ids_to_tokens(top_mask.indices)

  #[MASK]を求めたトークンで置き換える
  top_text = []
  for token in tokens:
    token = token.replace('##', '')
    top_text.append(text.replace('[MASK]',token, 1))

  return top_text

センター試験問題を解く

では先ほど作った関数を少し改造し、BERTにセンター試験の問題を解かせたいと思います。
Beam Search(ビームサーチ)というアルゴリズムを使って複数の空欄に入る言葉を推論します。
ビームサーチは各単語を予測する際に、スコアが高い順に候補となる単語をn個(複数個)選んでいく方法です。

今回は2020年度の日本史の第2問の問1を解かせます。
スクリーンショット 2022-03-26 011134.png

def mask_predict(text, tokenizer, bert_mlm):
  #テキストを符号化し、スコアを取得
  input_ids = tokenizer.encode(text, return_tensors='pt')
  input_ids = input_ids.cuda()
  with torch.no_grad():
    output = bert_mlm(input_ids=input_ids)
  scores = output.logits
  
  #[MASK]の位置を調べる
  mask_position = input_ids[0].tolist().index(4)

  #スコア順位が上位5位までトークンとスコアを求める
  top_mask = scores[0, mask_position].topk(5)
  tokens = tokenizer.convert_ids_to_tokens(top_mask.indices)
  top_scores = top_mask.values.cpu().numpy()

  #[MASK]を求めたトークンで置き換える
  top_text = []
  for token in tokens:
    token = token.replace('##', '')
    top_text.append(text.replace('[MASK]',token, 1))

  return top_text, top_scores

def beam_search(text, tokenizer, bert_mlm):
  top_text = [text]
  top_scores = np.array([0])
  for _ in range(text.count('[MASK]')):
    #textを追加していく
    text_candidates = []
    #スコアを追加していく
    score_candidates = []
    for text_mask , score in zip(top_text, top_scores):
      top_text_inner, top_scores_inner = mask_predict(text_mask, tokenizer, bert_mlm)
    text_candidates.extend(top_text_inner)
    score_candidates.append(score + top_scores_inner)

    #穴埋めされた文章の中から合計スコアの高いものを選ぶ
    score_candidates = np.hstack(score_candidates)
    idx_list = score_candidates.argsort()[::-1][:5]
    top_text = [text_candidates[idx] for idx in idx_list]
    top_scores = score_candidates[idx_list]
  
  return top_text

これで穴埋めを行い、スコアが高い上位5位までの文章を表示すると

律令国家成立当初の東北地方以北や九州南部以南の地域は,いまだ中央政府の支配下に組み込まれておらず,辺境と位置づけられた。辺境の人々は東北地方では蝦夷,九州南部では隼人などとよばれ,中央政府との対立を経ながら徐々にその支配下に組み込まれていった。8世紀初め,中央政府は隼人の抵抗を抑え,九州南部に薩摩国,ついで王国を設置した。ここでは, 720年に隼人が国司を殺害するという反乱が起きたが,それが鎮圧された後は,隼人の大きな抵抗はみられなくなった。一方,東北地方に対しても,大化改新後,中央政府による支配領域拡大の動きが本格化した。東北地方の太平洋側では,改新後に陸奥国が設置されたと推測され,蝦夷支配を進めるために城柵が設けられた。その一つが724年に設置された多賀城で, ここには陸奥国府と国分寺がおかれた 8世紀から9世紀にかけて造られた城柵については, 発掘調査の成果から, 東北地方以外の国府との違いや共通点がわかってきている。
律令国家成立当初の東北地方以北や九州南部以南の地域は,いまだ中央政府の支配下に組み込まれておらず,辺境と位置づけられた。辺境の人々は東北地方では蝦夷,九州南部では隼人などとよばれ,中央政府との対立を経ながら徐々にその支配下に組み込まれていった。8世紀初め,中央政府は隼人の抵抗を抑え,九州南部に薩摩国,ついで王国を設置した。ここでは, 720年に隼人が国司を殺害するという反乱が起きたが,それが鎮圧された後は,隼人の大きな抵抗はみられなくなった。一方,東北地方に対しても,大化改新後,中央政府による支配領域拡大の動きが本格化した。東北地方の太平洋側では,改新後に陸奥国が設置されたと推測され,蝦夷支配を進めるために城柵が設けられた。その一つが724年に設置された多賀城で, ここには陸奥国府と[UNK]がおかれた 8世紀から9世紀にかけて造られた城柵については, 発掘調査の成果から, 東北地方以外の国府との違いや共通点がわかってきている。
律令国家成立当初の東北地方以北や九州南部以南の地域は,いまだ中央政府の支配下に組み込まれておらず,辺境と位置づけられた。辺境の人々は東北地方では蝦夷,九州南部では隼人などとよばれ,中央政府との対立を経ながら徐々にその支配下に組み込まれていった。8世紀初め,中央政府は隼人の抵抗を抑え,九州南部に薩摩国,ついで王国を設置した。ここでは, 720年に隼人が国司を殺害するという反乱が起きたが,それが鎮圧された後は,隼人の大きな抵抗はみられなくなった。一方,東北地方に対しても,大化改新後,中央政府による支配領域拡大の動きが本格化した。東北地方の太平洋側では,改新後に陸奥国が設置されたと推測され,蝦夷支配を進めるために城柵が設けられた。その一つが724年に設置された多賀城で, ここには陸奥国府と政庁がおかれた 8世紀から9世紀にかけて造られた城柵については, 発掘調査の成果から, 東北地方以外の国府との違いや共通点がわかってきている。
律令国家成立当初の東北地方以北や九州南部以南の地域は,いまだ中央政府の支配下に組み込まれておらず,辺境と位置づけられた。辺境の人々は東北地方では蝦夷,九州南部では隼人などとよばれ,中央政府との対立を経ながら徐々にその支配下に組み込まれていった。8世紀初め,中央政府は隼人の抵抗を抑え,九州南部に薩摩国,ついで王国を設置した。ここでは, 720年に隼人が国司を殺害するという反乱が起きたが,それが鎮圧された後は,隼人の大きな抵抗はみられなくなった。一方,東北地方に対しても,大化改新後,中央政府による支配領域拡大の動きが本格化した。東北地方の太平洋側では,改新後に陸奥国が設置されたと推測され,蝦夷支配を進めるために城柵が設けられた。その一つが724年に設置された多賀城で, ここには陸奥国府と国府がおかれた 8世紀から9世紀にかけて造られた城柵については, 発掘調査の成果から, 東北地方以外の国府との違いや共通点がわかってきている。
律令国家成立当初の東北地方以北や九州南部以南の地域は,いまだ中央政府の支配下に組み込まれておらず,辺境と位置づけられた。辺境の人々は東北地方では蝦夷,九州南部では隼人などとよばれ,中央政府との対立を経ながら徐々にその支配下に組み込まれていった。8世紀初め,中央政府は隼人の抵抗を抑え,九州南部に薩摩国,ついで王国を設置した。ここでは, 720年に隼人が国司を殺害するという反乱が起きたが,それが鎮圧された後は,隼人の大きな抵抗はみられなくなった。一方,東北地方に対しても,大化改新後,中央政府による支配領域拡大の動きが本格化した。東北地方の太平洋側では,改新後に陸奥国が設置されたと推測され,蝦夷支配を進めるために城柵が設けられた。その一つが724年に設置された多賀城で, ここには陸奥国府とそれがおかれた 8世紀から9世紀にかけて造られた城柵については, 発掘調査の成果から, 東北地方以外の国府との違いや共通点がわかってきている。

この問題の解答は4番で空欄アには大隅国、空欄イには鎮守府が入ります。
生成された文章を見てみるとどの文章も1番目の空欄には「王国」、2番目の空欄には「ここ」そして、3番目の空欄には「国分寺」、「政庁」、「国府」、「それ」などと推論されていることが分かります。
正解はできませんでしたが、どれも意味が通る文章になっています。
(参考文献)
BERTによる自然言語処理入門:Transformersを使った実践プログラミング オーム社出版

5
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
3