0
0

BERTを用いて日本語文章の重要度を取得

Last updated at Posted at 2024-06-20

最後に今回使用したプログラムをまとめて記載しています。

Mac(Intel)環境の構築

BERT(文脈解析)だけでは日本語の理解をするには難しい場合があるため、Mecebを用いて形態素解析をし、その結果を文脈解析をすることで精度をあげることにする。

初めにmecabを実行するための環境を作るために下記を実行。

$ brew install mecab
$ brew install mecab-ipadic
$ pip install mecab-python

BERTで用いる環境を構築するために下記を実行。

$ pip install transformers
$ pip install fugashi 
$ pip install ipadic

mecabで形態素解析のプログラムを作成する。
簡単なので一気に示す。

mecab_words.py
import MeCab

print("テキストを入力してください。")
text = input()
wakati = MeCab.Tagger('-Owakati')
print(wakati.parse(text))

初めに必要なライブラリをインポート

vector_words.py
from transformers import BertJapaneseTokenizer, BertModel
import torch

BERTのトークナイザーとモデルを準備

vector_words.py
model_name = 'cl-tohoku/bert-base-japanese'
tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

分析したいテキストの入力を待機

vector_words.py
print("テキストを入力してください、")
text = input()

テキストをトークン化

vector_words.py
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)

BERTモデルで特徴ベクトルを取得

vector_words.py
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)
    attention = outputs.attentions

最初のレイヤーのAttentionを平均化して各トークンの重要度を求める

vector_words.py
avg_attention = torch.stack(tuple(attention[-1])).mean(dim=1).squeeze(0)
token_importance = avg_attention.sum(dim=0)

重要度の高い順にトークンと重要度を表示

vector_words.py
decoded_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
token_importance_list = list(zip(decoded_tokens, token_importance.tolist()))

sorted_token_importance = sorted(token_importance_list, key=lambda x: x[1], reverse=True)
for token, importance in sorted_token_importance:
    print(f"Token: {token.ljust(15)} Importance: {importance:.4f}")

まとめ

vector_words.py
from transformers import BertJapaneseTokenizer, BertModel
import torch

model_name = 'cl-tohoku/bert-base-japanese'
tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

print("テキストを入力してください、")
text = input()

inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)

with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)
    attention = outputs.attentions

avg_attention = torch.stack(tuple(attention[-1])).mean(dim=1).squeeze(0)
token_importance = avg_attention.sum(dim=0)

decoded_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
token_importance_list = list(zip(decoded_tokens, token_importance.tolist()))

sorted_token_importance = sorted(token_importance_list, key=lambda x: x[1], reverse=True)
for token, importance in sorted_token_importance:
    print(f"Token: {token.ljust(15)} Importance: {importance:.4f}")

参考文献

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0