import torch
from transformers import BertJapaneseTokenizer, BertForNextSentencePrediction

bert_tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
nsp_bert = BertForNextSentencePrediction.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
nsp_bert.eval()
prompt = '私の家族は5人家族です。'
next_sentence = '家族は、父、母、兄、私、妹です。'
input_tensor = bert_tokenizer(prompt, next_sentence, return_tensors='pt')
print(input_tensor)
{'input_ids': tensor([[   2, 1325,    5, 2283,    9,   76,   53, 2283, 2992,    8,    3, 2283,
            9,    6,  800,    6,  968,    6, 1456,    6, 1325,    6, 4522, 2992,
            8,    3]]),
 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1]])}
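The token_type_ids mark which sentence each token belongs to: 0 for the first sentence up to and including the first [SEP], and 1 for the second sentence. As a quick sanity check of my own (using only the tokenizer output above), you can line each token up with its segment id:

# Pair every token with its segment id: 0 = first sentence, 1 = second sentence
tokens = bert_tokenizer.convert_ids_to_tokens(input_tensor['input_ids'][0].tolist())
for token, segment in zip(tokens, input_tensor['token_type_ids'][0].tolist()):
    print(segment, token)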
list(bert_tokenizer.get_vocab().items())[:5]
[('[PAD]', 0), ('[UNK]', 1), ('[CLS]', 2), ('[SEP]', 3), ('[MASK]', 4)]
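These first five entries explain the 2 and 3 that appear in input_ids: they are the [CLS] and [SEP] special tokens. The tokenizer also exposes their ids directly:

print(bert_tokenizer.cls_token_id, bert_tokenizer.sep_token_id)
2 3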
output = nsp_bert(**input_tensor)
print(output)
NextSentencePredictorOutput(loss=None, logits=tensor([[14.3159,  2.9107]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)
torch.argmax(output.logits)
tensor(0)
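In BertForNextSentencePrediction, index 0 means "the second sentence follows the first" and index 1 means "it does not", so an argmax of 0 is the correct answer here. As a small sketch of my own (the unrelated sentence below is not from the walkthrough), you can convert the logits to probabilities and try a pair that should not be judged as consecutive:

probs = torch.softmax(output.logits, dim=-1)
print(probs)  # probabilities for [is next sentence, is not next sentence]

# Hypothetical counter-example: an unrelated second sentence.
# You would expect the logits to shift toward index 1, though the model
# gives no guarantee for any particular pair.
unrelated = '東京は日本の首都です。'
with torch.no_grad():
    neg_output = nsp_bert(**bert_tokenizer(prompt, unrelated, return_tensors='pt'))
print(torch.argmax(neg_output.logits))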
The model predicted this correctly.
That wraps up a simple implementation of Next sentence prediction.