##環境の準備
!pip install -q transformers==4.9.0
!pip install -q fugashi
!pip install -q ipadic
# 必要なモジュールのインストール
import torch
import transformers
##BERT モデルと Tokenizer の準備
from transformers import BertJapaneseTokenizer, BertForMaskedLM
今回は東北大学で開発されたモデルである cl-tohoku/bert-base-japanese-whole-word-masking を使用します。
# 分かち書きをするtokenizer
bert_tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
lm_bert = BertForMaskedLM.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
##BERTの入力の準備
BertJapaneseTokenizer を使用して BERT の入力を用意します。
今回は、「彼は[MASK]として働いている。」という文章の [MASK] を予測させてみます。
# 分かち書きの実行
text = '彼は * として働いている。'
tokenized_text = bert_tokenizer(text)
print(tokenized_text)
MLM を実装するために予測したい単語(今回は *)を [MASK] トークンに置換します。
masked_idx = 2
# 予測させる単語をマスクする
tokenized_text[masked_idx] = '[MASK]'
print(tokenized_text)
[MASK] した単語を含んだ文字列を convert_tokens_to_ids() を使用して id に変換します。
indexed_tokens = bert_tokenizer.convert_tokens_to_ids(tokenized_text)
print(indexed_tokens)
BERT モデルに使用できるように id 化した文字列を Tensor 型に変換します。
tokens_tensor = torch.tensor([indexed_tokens])
print(tokens_tensor)
tensor([[1, 1, 1, 1]])
##Mask language model の実行
BertForMaskedLM を使用して Mask language model を実行します。
モデルの構造を確認してみましょう。
.eval() で推論モードに切り替えることができ、こちらのコードを実行するとモデル構造が出力されます。
lm_bert.eval()
モデルが用意できたのでMLMの推論を行います。
outputs = lm_bert(tokens_tensor)
type(outputs)
len(outputs[0][0][0])
32000
今回の[MASK]した単語の予測値を取得します。
pred = outputs[0][0, masked_idx]
pred
tensor([-4.7888, 20.6928, -3.8209, ..., -8.5616, -4.3609, -0.3407],
grad_fn=)
この予測値に .topk() を用いることで予測値の内上位 ◯ に当てはまる予測値の値を取得することが出来ます。
pred.topk(2)
torch.return_types.topk(values=tensor([20.6928, 9.1267], grad_fn=), indices=tensor([ 1, 466]))
[MASK] した単語を上位 5 を出力する関数を定義します。
def predict_mlm(model, input, masked_idx):
with torch.no_grad():
outputs = model(input)
# 予測結果の上位5件を抽出
predictions = outputs[0][0, masked_idx].topk(5)
for i, idx in enumerate(predictions.indices):
index = idx.item()
print(index)
token = bert_tokenizer.convert_ids_to_tokens([index])[0]
print(f'順位:{i + 1}\n単語:{token}')
print('----')
predict_mlm(lm_bert, tokens_tensor, masked_idx)