最後に今回使用したプログラムをまとめて記載しています。
Mac(Intel)環境の構築
BERT(文脈解析)だけでは日本語の理解をするには難しい場合があるため、Mecebを用いて形態素解析をし、その結果を文脈解析をすることで精度をあげることにする。
初めにmecabを実行するための環境を作るために下記を実行。
$ brew install mecab
$ brew install mecab-ipadic
$ pip install mecab-python
BERTで用いる環境を構築するために下記を実行。
$ pip install transformers
$ pip install fugashi
$ pip install ipadic
mecabで形態素解析のプログラムを作成する。
簡単なので一気に示す。
mecab_words.py
import MeCab
print("テキストを入力してください。")
text = input()
wakati = MeCab.Tagger('-Owakati')
print(wakati.parse(text))
初めに必要なライブラリをインポート
vector_words.py
from transformers import BertJapaneseTokenizer, BertModel
import torch
BERTのトークナイザーとモデルを準備
vector_words.py
model_name = 'cl-tohoku/bert-base-japanese'
tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)
分析したいテキストの入力を待機
vector_words.py
print("テキストを入力してください、")
text = input()
テキストをトークン化
vector_words.py
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
BERTモデルで特徴ベクトルを取得
vector_words.py
with torch.no_grad():
outputs = model(**inputs, output_attentions=True)
attention = outputs.attentions
最初のレイヤーのAttentionを平均化して各トークンの重要度を求める
vector_words.py
avg_attention = torch.stack(tuple(attention[-1])).mean(dim=1).squeeze(0)
token_importance = avg_attention.sum(dim=0)
重要度の高い順にトークンと重要度を表示
vector_words.py
decoded_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
token_importance_list = list(zip(decoded_tokens, token_importance.tolist()))
sorted_token_importance = sorted(token_importance_list, key=lambda x: x[1], reverse=True)
for token, importance in sorted_token_importance:
print(f"Token: {token.ljust(15)} Importance: {importance:.4f}")
まとめ
vector_words.py
from transformers import BertJapaneseTokenizer, BertModel
import torch
model_name = 'cl-tohoku/bert-base-japanese'
tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)
print("テキストを入力してください、")
text = input()
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
with torch.no_grad():
outputs = model(**inputs, output_attentions=True)
attention = outputs.attentions
avg_attention = torch.stack(tuple(attention[-1])).mean(dim=1).squeeze(0)
token_importance = avg_attention.sum(dim=0)
decoded_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
token_importance_list = list(zip(decoded_tokens, token_importance.tolist()))
sorted_token_importance = sorted(token_importance_list, key=lambda x: x[1], reverse=True)
for token, importance in sorted_token_importance:
print(f"Token: {token.ljust(15)} Importance: {importance:.4f}")
参考文献