※3日遅れての投稿(´・ω・`) 家の事が信じられないぐらい忙しかったので、許しください。(これが 師走)
こんにちは、
かなり昔から話題になっている「XAI」について、大枠は理解しているつもりだが、
システム開発の案件に携わる事が多くて(設計とか)、実装にトライすることがおざなりになっていたので実装にチャレンジしてみました。
XAIの雑な説明
いわゆる 「AIのブラックボックス問題」というのがあり、
ヒトが見て解釈しやすいようなモデル(決定木、統計モデリング)ではなくて、
アンサンブルや、ディープラーニングのような複雑なやつは学習過程も複雑なので、それがゆえに
その推論や、認識、予測がブラックボックスとなってしまい、AIが何故このような判断をしたのか確認できると良いよね
ってのがXAIです。
→ まぁ確かに、クレジットカード審査とか、画像分類とかで、理由が知りたいよねぇ
XAIによるAIの説明について
実は「大局説明」という話と「局所説明」の2つに分類されます。
全体的な概要なのか、とある入力データに対する予測結果に対しての細かい説明かどうかの違いです。
大局説明とは
目的
「AIモデルの全体的な振る舞い」を理解する事
一言で言うと
モデルにとって重要な特徴量を明示する方法
具体例
売上予測などを考えた時に、このモデルは「天気」をめちゃくちゃ重要としていますよ!
天気が重要なので、天気予報していきましょう!! モデルの大きな特徴を説明することができるんですね。
※天気は雑すぎるww
局所説明とは
目的
「個々の予測結果の判断根拠を理解すること」
一言で言うと
モデルの出力結果に寄与した特徴量を明示する方法
(特定の入力に対する“予測結果の説明”を提示することで説明とする方法。)
具体例
例えば、クレジットカードの審査で「不可」と判断された時に、顧客に「***が**以下だったので・・すいません!!」と説明ができるんですね。
今回の実装
はじめに
XAIについてとてもキレイにまとめられた本があったので、この本を参考に実装しました。
XAI〈説明可能なAI〉 そのとき人工知能はどう考えたのか? )
→ 久しぶりに体系整理されているなと感じた本で、知りたかった事が詳細にかかれており、とても良い本です。
詳細を知りたい方は本購入をオススメします。
今回の実装でやること
自然言語処理でおなじみの、「Livedoorコーパス]を使ってトピック分類を行います。
何故、このモデルがこのトピックに分類したのか。重要と判定したWordは何だったのか。 
これを可視化してみましょう。
手順
- ライブラリの準備
- Livedoorコーパスのデータを入手
- データの読み込み&前処理
- BERTモデル/ okenizerの読み込み
- 学習データの生成
- XAI
上記の手順でやっていきます。
1. ライブラリの準備
おなじみのやつです。 pipしましょう。
pip install -U pip
pip install torch==1.7.0 fugashi==1.0.5 ipadic==1.0.0 \
    transformers==4.0.0 lime==0.2.0.1 captum==0.3.0 \
    scikit-learn==0.22.1 numpy==1.19.4
2. Livedoorコーパスのデータを入手
次に、Livedoorコーパスを入手します。
株式会社ロンウィットさん のサイトより入手可能です。※ありがとうございます。
3. データの読み込み&前処理
Livedoorデータの読み込み部分の 'ダウンロード/text/' は適時変更してください。
import os
import re
import sys
from captum.attr import visualization as viz
from captum.attr import LayerIntegratedGradients
from lime.lime_text import LimeTextExplainer
import numpy as np
import torch
from torch.nn import functional as F
import transformers
from sklearn.model_selection import train_test_split
class LivedoorData(object):
    def __init__(self, data_dir='ダウンロード/text/'):
        self.data_dir = data_dir
    def _get_texts(self, category):
        category_dir = os.path.join(self.data_dir, category)
        filenames = [
            filename for filename in os.listdir(category_dir)
            if re.search(r'^.*\.txt$', filename)
        ]
        texts = []
        for filename in filenames:
            fpath = os.path.join(category_dir, filename)
            with open(fpath, 'r', encoding='utf-8') as fp:
                fp.readline()
                fp.readline()
                buf = []
                for line in fp:
                    buf.append(line.strip().replace(r'\s', ''))
            texts.append(' '.join(buf))
        return texts
    def get_categories(self):
        categories = [
            category for category in os.listdir(self.data_dir)
            if os.path.isdir(os.path.join(self.data_dir, category))
               and not re.search(r'^\.', category)
        ]
        return categories
    
    def read(self, categories=None):
        all_categories = self.get_categories()
        if categories:
            categories = [
                category for category in categories
                if category in all_categories
            ]
        else:
            categories = all_categories
        
        self.X, self.y = [], []
        for category in categories:
            texts = self._get_texts(category)
            self.X.extend(texts)
            self.y.extend([category] * len(texts))
        
    def get_data(self):
        if not self.X and not self.y:
            self.read()
        return self.X, self.y
# ターゲットカテゴリ
target_categories = ['sports-watch', 'kaden-channel', 'movie-enter']
n_category = len(target_categories)
# データ読み込み
livedoor = LivedoorData()
livedoor.read(categories=target_categories)
X, y = livedoor.get_data()
# カテゴリ名をINDEX化
y_indices = np.array([target_categories.index(v) for v in y])
4. BERTモデル/ okenizerの読み込み
おなじみのBERTを使います。
model_name = 'cl-tohoku/bert-base-japanese-whole-word-masking'
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForSequenceClassification \
    .from_pretrained(model_name, num_labels=len(target_categories),
                     output_attentions=True)
model = model.to(device)
5. 学習データの制定
max_length = 256
tokenized = tokenizer.batch_encode_plus(
    X, padding=True, truncation=True,
    max_length=max_length, return_tensors='pt')
tensor_X = tokenized['input_ids']
tensor_y = torch.tensor(y_indices)
tensor_mask = tokenized['attention_mask']
X_train, X_valid, y_train, y_valid, mask_train, mask_valid = train_test_split(
    tensor_X, tensor_y, tensor_mask, train_size=0.75, random_state=42
)
6. XAIする
さて、やっとでXAIできるところまで来ました。
トピック分類の判断で重要度の高いWordを可視化しましょう!!
n_epoch = 3
batch_size = 16
def predict_proba(texts, model=model, tokenizer=tokenizer,
                  max_length=max_length, batch_size=batch_size):
    tokenized = tokenizer.batch_encode_plus(
        texts, padding=True, truncation=True,
        max_length=max_length, return_tensors='pt')
    ids, masks = [tokenized[key]
                  for key in ['input_ids', 'attention_mask']]    
    n_batch = ids.shape[0] // batch_size + 1    
    list_prob = []
    for i_batch in range(n_batch):
        idx_from = i_batch * batch_size
        idx_to = (i_batch + 1) * batch_size
        ids_batch = ids[idx_from:idx_to].to(device)
        mask_batch = masks[idx_from:idx_to].to(device)
        logits = model(ids_batch, attention_mask=mask_batch)['logits']
        prob = F.softmax(logits, dim=1).cpu().detach().numpy()
        list_prob.append(prob)
    return np.vstack(list_prob)
word_tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name, do_subword_tokenize=False)
explainer = LimeTextExplainer(
    class_names=target_categories,
    split_expression=word_tokenizer.tokenize,
    mask_string=tokenizer.pad_token,
    random_state=0)
sample_text = 'たった3回のタッチで優勝。主役であり、最高の補強'
exp_result = explainer.explain_instance(sample_text, predict_proba, num_features=20,
                                        labels=[target_categories.index('sports-watch')])
## 私の息子がサッカー部のレギュラーになれますように
exp_result.show_in_notebook(text=True)
可視化の結果
きゃー素敵だわ!!!
最後に
今回はXAIのLIMEを利用して、判断根拠の可視化っぽい事をやってみました。
この後、 BERTモデルをファインチューニングした後で、同じように可視化をするとより、一層判断根拠がそれっぽく値がでてくるらしいので、引き続きトライしたいなと
面白いと感じました。
XAIはまだまだ研究段階の分野であるので、引き続きウォッチしたいと思います!
追記
ファインチューニングした結果、更にそれっぽい結果が可視化できました。 すごい
参考リンク
lime
https://pypi.org/project/lime/0.2.0.1/
livedoorニュースコーパス
https://www.rondhuit.com/download/ldcc-20140209.tar.gz

