3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

BERTの事前学習 Mask language modelの実装

Posted at

##環境の準備

!pip install -q transformers==4.9.0
!pip install -q fugashi
!pip install -q ipadic
# 必要なモジュールのインストール
import torch
import transformers

##BERT モデルと Tokenizer の準備

from transformers import BertJapaneseTokenizer, BertForMaskedLM

今回は東北大学で開発されたモデルである cl-tohoku/bert-base-japanese-whole-word-masking を使用します。

# 分かち書きをするtokenizer
bert_tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
lm_bert = BertForMaskedLM.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

##BERTの入力の準備
BertJapaneseTokenizer を使用して BERT の入力を用意します。
今回は、「彼は[MASK]として働いている。」という文章の [MASK] を予測させてみます。

# 分かち書きの実行

text = '彼は * として働いている。'

tokenized_text = bert_tokenizer(text)
print(tokenized_text)

MLM を実装するために予測したい単語(今回は *)を [MASK] トークンに置換します。

masked_idx = 2

# 予測させる単語をマスクする
tokenized_text[masked_idx] = '[MASK]'
print(tokenized_text)

[MASK] した単語を含んだ文字列を convert_tokens_to_ids() を使用して id に変換します。

indexed_tokens = bert_tokenizer.convert_tokens_to_ids(tokenized_text)

print(indexed_tokens)

BERT モデルに使用できるように id 化した文字列を Tensor 型に変換します。

tokens_tensor = torch.tensor([indexed_tokens])

print(tokens_tensor)

tensor([[1, 1, 1, 1]])

##Mask language model の実行
BertForMaskedLM を使用して Mask language model を実行します。

モデルの構造を確認してみましょう。
.eval() で推論モードに切り替えることができ、こちらのコードを実行するとモデル構造が出力されます。

lm_bert.eval()

モデルが用意できたのでMLMの推論を行います。

outputs = lm_bert(tokens_tensor)
type(outputs)
len(outputs[0][0][0])

32000

今回の[MASK]した単語の予測値を取得します。

pred = outputs[0][0, masked_idx]
pred

tensor([-4.7888, 20.6928, -3.8209, ..., -8.5616, -4.3609, -0.3407],
grad_fn=)

この予測値に .topk() を用いることで予測値の内上位 ◯ に当てはまる予測値の値を取得することが出来ます。

pred.topk(2)

torch.return_types.topk(values=tensor([20.6928, 9.1267], grad_fn=), indices=tensor([ 1, 466]))

[MASK] した単語を上位 5 を出力する関数を定義します。

def predict_mlm(model, input, masked_idx):
    with torch.no_grad():
        outputs = model(input)
        # 予測結果の上位5件を抽出
        predictions = outputs[0][0, masked_idx].topk(5) 

    for i, idx in enumerate(predictions.indices):
        index = idx.item()
        print(index)
        token = bert_tokenizer.convert_ids_to_tokens([index])[0]
        print(f'順位:{i + 1}\n単語:{token}')
        print('----')
predict_mlm(lm_bert, tokens_tensor, masked_idx)

1
順位:1
単語:[UNK]

466
順位:2
単語:学校

20770
順位:3
単語:物事

959
順位:4
単語:諸

2633
順位:5
単語:男子

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?