0. はじめに
今回はBERTの解釈としてtransformers-interpretを試したのでメモがてら記載しておく。
- 動作環境
- OS : Windows10 pro
- python: 3.9.6
- transformers: 4.23.1
- Pytorch: 1.12.1 (+cu116)
- transformers-interpret: 0.9.5
- GPU: RTX 2060
- jupyter notebook(vscode)
1. transformers-interpretでBERTのマルチクラス分類を解釈する
transformers-interpretとは、PyTorch用のモデル解釈ライブラリ 「Captum」を使用したTransformers専用のライブラリ。
SHAP等の解釈モデルと違い、Transformers専用にチューンされているので使いやすい
1-1. Captumとは?
CaptumはPyTorch用のモデル解釈ライブラリで、torchvision(画像)やtorchtext(言語)などのドメイン固有ライブラリで構築されたモデルを素早く解釈可能。
この記事の本質ではないため細かくは説明しませんので、以下を読んでください。
1-2. transformers-interpretで解釈してみる
pip isnstall transformers_interpret
でインストールできます。
例題として前回書いた記事をそのまま流用する為、以下記事を最後まで実行していることが前提です
なお、@reluさんが下記記事をcolabでまとめてくれた為、そちらを参照してもいいです。
https://colab.research.google.com/drive/1WjHt3eByYAwg-JT7GnStIZlFp5TaQq0j?usp=sharing
※前回の記事を実行し、df_resultのデータフレームに以下のような結果が入っている状態
sentence | pred_label | true_label | |
---|---|---|---|
0 | もうプロポーズを待たない女たち | 0 | 0 |
0 | 写真魂のバトンリレー!GRデジタルをバトンに若き写真家たちの駅伝写真展がスタート | 1 | 1 |
まずは、学習済モデルとトークナイザ(cl-tohoku/bert-base-japanese-v
)を解釈用メソッドのMultiLabelClassificationExplainer
へ読み込ませて、解釈をさせる。
#辞書型表示にはpprintがおすすめ
import pprint
#from transformers_interpret import SequenceClassificationExplainer #これは2値分類
from transformers_interpret import MultiLabelClassificationExplainer
#ここで学習済モデルとトークナイザを指定して解釈(前回記事で定義)
explainer = MultiLabelClassificationExplainer(model, tokenizer)
#解釈させたいテキストを指定する。この例では0番目の要素を指定している
n = 0
text = df_result['sentence'].to_list()[n]
#ワードのスコアを解釈させる
word_attributions = explainer(text)
print(text) #原文表示
pprint.pprint(word_attributions) #結果表示
すると、以下のようにラベルと各トークンのスコアが表示される。
もうプロポーズを待たない女たち
{'LABEL_0': [('[CLS]', 0.0),
('もう', 0.2364731587674057),
('プロポーズ', 0.6601345432831441),
('を', -0.27212284601689574),
('待た', -0.22521582677921956),
('ない', -0.06555384693488837),
('女', 0.6103618238636512),
('たち', 0.08179825861601124),
('[SEP]', 0.0)],
'LABEL_1': [('[CLS]', 0.0),
('もう', -0.232783105861903),
('プロポーズ', -0.6474469309652334),
('を', 0.2667569326983013),
('待た', 0.24085367296863214),
('ない', 0.08413956066583539),
('女', -0.6166673578045454),
('たち', -0.1004820801597012),
('[SEP]', 0.0)],
'LABEL_2': [('[CLS]', 0.0),
('もう', -0.2579947004524142),
('プロポーズ', -0.7493263170843887),
('を', 0.31054406203828777),
('待た', 0.011560972069382598),
('ない', -0.16821512234084376),
('女', -0.4723670465258997),
('たち', 0.15475975088193888),
('[SEP]', 0.0)]}
で、これらを統合し可視化させるメソッドも当然ある。
html = explainer.visualize()
Attribution Labelがラベル名で、Predicted Labelが予測ラベル。
でこの例だと、Attribution Scoreが1.03で一番大きくなったラベル0が予測されたことになる。
さらに「プロポーズ」「女」という単語が予測に寄与していることも直感的に理解できる。
2. おわりに
BERTの可視化にはSHAP等もよく使われている(CaptumにもSHAPが入っているらしい)が、初手はかなり簡単に使えるtransformers-interpretを試してみても面白いと思う。
なお、今回は他クラス分類だったが2値分類はSequenceClassificationExplainerに変えるだけでいけるので、気になったら試してほしい。
それでは今回はここまで。