13
8

More than 1 year has passed since last update.

HuggingFace(BERT)モデルをtransformers-interpretで解釈する

Last updated at Posted at 2022-11-06

0. はじめに

今回はBERTの解釈としてtransformers-interpretを試したのでメモがてら記載しておく。

  • 動作環境
    • OS : Windows10 pro
    • python: 3.9.6 
    • transformers: 4.23.1
    • Pytorch: 1.12.1 (+cu116)
    • transformers-interpret: 0.9.5
    • GPU: RTX 2060
    • jupyter notebook(vscode)

1. transformers-interpretでBERTのマルチクラス分類を解釈する

transformers-interpretとは、PyTorch用のモデル解釈ライブラリ 「Captum」を使用したTransformers専用のライブラリ。

SHAP等の解釈モデルと違い、Transformers専用にチューンされているので使いやすい

1-1. Captumとは?

CaptumはPyTorch用のモデル解釈ライブラリで、torchvision(画像)やtorchtext(言語)などのドメイン固有ライブラリで構築されたモデルを素早く解釈可能。
この記事の本質ではないため細かくは説明しませんので、以下を読んでください。

1-2. transformers-interpretで解釈してみる

pip isnstall transformers_interpretでインストールできます。

例題として前回書いた記事をそのまま流用する為、以下記事を最後まで実行していることが前提です

なお、@reluさんが下記記事をcolabでまとめてくれた為、そちらを参照してもいいです。
https://colab.research.google.com/drive/1WjHt3eByYAwg-JT7GnStIZlFp5TaQq0j?usp=sharing

※前回の記事を実行し、df_resultのデータフレームに以下のような結果が入っている状態

sentence pred_label true_label
0 もうプロポーズを待たない女たち 0 0
0 写真魂のバトンリレー!GRデジタルをバトンに若き写真家たちの駅伝写真展がスタート 1 1

まずは、学習済モデルとトークナイザ(cl-tohoku/bert-base-japanese-v)を解釈用メソッドのMultiLabelClassificationExplainerへ読み込ませて、解釈をさせる。

BERTで学習済モデルを解釈させる
#辞書型表示にはpprintがおすすめ
import pprint
#from transformers_interpret import SequenceClassificationExplainer #これは2値分類
from transformers_interpret import MultiLabelClassificationExplainer

#ここで学習済モデルとトークナイザを指定して解釈(前回記事で定義)
explainer = MultiLabelClassificationExplainer(model, tokenizer)

#解釈させたいテキストを指定する。この例では0番目の要素を指定している
n = 0
text = df_result['sentence'].to_list()[n]

#ワードのスコアを解釈させる
word_attributions = explainer(text)

print(text) #原文表示
pprint.pprint(word_attributions) #結果表示

すると、以下のようにラベルと各トークンのスコアが表示される。

実行結果
もうプロポーズを待たない女たち
{'LABEL_0': [('[CLS]', 0.0),
             ('もう', 0.2364731587674057),
             ('プロポーズ', 0.6601345432831441),
             ('を', -0.27212284601689574),
             ('待た', -0.22521582677921956),
             ('ない', -0.06555384693488837),
             ('女', 0.6103618238636512),
             ('たち', 0.08179825861601124),
             ('[SEP]', 0.0)],
 'LABEL_1': [('[CLS]', 0.0),
             ('もう', -0.232783105861903),
             ('プロポーズ', -0.6474469309652334),
             ('を', 0.2667569326983013),
             ('待た', 0.24085367296863214),
             ('ない', 0.08413956066583539),
             ('女', -0.6166673578045454),
             ('たち', -0.1004820801597012),
             ('[SEP]', 0.0)],
 'LABEL_2': [('[CLS]', 0.0),
             ('もう', -0.2579947004524142),
             ('プロポーズ', -0.7493263170843887),
             ('を', 0.31054406203828777),
             ('待た', 0.011560972069382598),
             ('ない', -0.16821512234084376),
             ('女', -0.4723670465258997),
             ('たち', 0.15475975088193888),
             ('[SEP]', 0.0)]}

で、これらを統合し可視化させるメソッドも当然ある。

可視化
html = explainer.visualize()

Attribution Labelがラベル名で、Predicted Labelが予測ラベル。
でこの例だと、Attribution Scoreが1.03で一番大きくなったラベル0が予測されたことになる。
さらに「プロポーズ」「女」という単語が予測に寄与していることも直感的に理解できる。

スクリーンショット 2022-11-06 190421.png

2. おわりに

BERTの可視化にはSHAP等もよく使われている(CaptumにもSHAPが入っているらしい)が、初手はかなり簡単に使えるtransformers-interpretを試してみても面白いと思う。
なお、今回は他クラス分類だったが2値分類はSequenceClassificationExplainerに変えるだけでいけるので、気になったら試してほしい。

それでは今回はここまで。

13
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
8