LoginSignup
78
80

More than 1 year has passed since last update.

代理モデルによる機械学習モデルの説明

Last updated at Posted at 2021-09-25

はじめに

代理モデル (surrogate model) とは複雑な機械学習モデル(e.g., DNN, GBDT)を近似する簡単なモデル(e.g., パラメタ数の少ないDNN, 単純決定木, etc)のことを指します.代理モデルは推論の高速化・機械学習モデルの説明などさまざまな用途に使われています.

この記事では代理モデルによる機械学習モデルの説明をハンズオン的に紹介します.これは非常にシンプルかつ柔軟な手法ですが,アドホックな部分が多いためかハンズオン的な解説は見当たりませんでした.Christoph Molnar による Interpretable Machine LearningGlobal Surrogate に概要は示されているので機械学習に詳しい人はこちらを読めば十分かもしれません.関連するライブラリに LIMETreeSurrogate がありますが,わたしがこの手法を使うときはアドホック感を大事にしたいときなので,あまりライブラリは使いません.

代理モデルで説明する手続き

代理モデルによる説明を得るには以下の手続きを踏みます.説明したいモデル model が与えられているところからスタートします.

  1. 適当な入力データ X を準備する
  2. X に摂動を加えて新しいデータ X_perturbed を作る
  3. X_perturbedmodel に食わせて出力 y_perturbed を作る
  4. (X_perturbed, y_perturbed) を使って代理モデル explainer を学習する
  5. explainer を目で見て解釈する

どのようなデータを準備するか (Step 1),どのような摂動を加えるか (Step 2),どのモデルで解釈するか (Step 4),どう解釈するか (Step 5),とほぼ全ステップに自由度があります.摂動部分について代表的な例を2つ紹介します.

  • model を学習したデータをそのまま X_perturbed とするのが Global Surrogate と呼ばれる手法の基本的な設定です.これは model 全体に一貫する説明が得られるのがよいのですが,問題が複雑な場合に代理モデルと model の差が大きくなってしまうので使えなくなります(∵全体を代理モデルで近似できるなら最初から代理モデルを使えば良い).Global Surrogate でよいかは代理モデルとオリジナルモデルの出力の一致度を評価すれば判定できます.

  • 1つのサンプルのデータを X とし,それをランダムに摂動する(e.g., ノイズを足す・値を置き換える)のが Local Surrogate (LIME) と呼ばれる手法の基本的な設定です.複雑なモデルでも局所的にはシンプルなモデルで近似できることが多いので広く使える手法ですが,選んだサンプル周辺でしか通用しない「場当たり的な説明」が返ってくる可能性に注意しないといけません.構築した代理モデルを他のサンプルに当てはめるとどの程度普遍的な説明が得られたかが判定できます.

ハンズオン

ここでは Adult Dataset (https://archive.ics.uci.edu/ml/datasets/adult) に対する機械学習モデルを決定木を代理モデルにして説明してみます.このデータセットはアメリカの国勢調査のデータから作られており,幾つかの特徴量から収入が50k以下 or 50k超過を当てるのがタスクです.

前処理

まずはデータを取得します.

!wget https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data

適当に必要なもの・必要になるものをインポート.

import category_encoders as ce
import matplotlib.pyplot
import numpy as np
import pandas as pd
import random
import sklearn
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier

matplotlib.pyplot.rcParams['figure.figsize'] = (15.0, 15.0)

データを読み込んで列名をつけて予測先の列を分離しておきます.

X = pd.read_csv("adult.data", header=None)

feature_names = ["age", "workclass", "fnlwgt", "education", "education-num",
                 "marital-status", "occupation", "relationship", "race", 
                 "sex", "capital-gain", "capital-loss", "hours-per-week",
                 "native-country", "income"]
X.columns = feature_names

y = X["income"]
X = X.drop(columns=["income"])

X_train, X_test, y_train, y_test = train_test_split(X, y)

以降では2種類のモデルを学習することになるので,使い回しやすいようにデータ前処理をクラスにまとめておきます.ここではカテゴリカルな列を one-hot encoding するだけの簡単なものを使います.

class AdultPreprocessor(BaseEstimator, TransformerMixin):
    def __init__(self):
        super().__init__()
        categorical_features = ["workclass", "education", "marital-status", "occupation", 
                                "relationship", "race", "sex", "native-country"]
        self.encoder = ce.OneHotEncoder(
            cols=categorical_features, use_cat_names=True)

    def fit(self, X, y=None):
        X = self.encoder.fit_transform(X, y)
        self.feature_names = X.columns
        return self

    def transform(self, X):
        X = self.encoder.transform(X)
        return X

オリジナルモデルの準備

解釈したいオリジナルのモデルを作ります.ここでは簡単なロジスティック回帰モデルとしましたが,実用的にはもっと複雑なモデルになるはずです.

model = Pipeline([("preprocess", AdultPreprocessor()), ("classification", LogisticRegression())])

model.fit(X_train, y_train)

代理モデルの構築

さて,ここからが本題です.データを準備して摂動を加えて新しいデータを作ります.Global Surrogateの設定(元のデータをそのまま使う)でもよいのですが,そうでない例を示すためにここでは元のデータの各列をランダムシャッフルします.これは各特徴量を独立とみなして分布をモデル化したことに相当するので,大雑把には「データ全体の中央値的な点」に対する振る舞いを見ることになります.この摂動はカテゴリカルなデータにも使える・欠損を生じさせない,実装が楽などの特徴があって使いやすいため,わたしは好んで使っています.特徴量間に強い相関がある場合にはナンセンスな設定になることに注意してください.

X_perturbed = X.copy()
for c in X_perturbed.columns:
    X_perturbed[c] = X_perturbed[c].sample(frac=1)
y_perturbed = model.predict(X_perturbed)

生成したデータに対して解釈しやすいモデルを当てはめます.ここでは決定木です.パラメタはデフォルトのままにしていますが本来はクロスバリデーションで決めるほうがよいです.いくらでも摂動によって学習データを生成できるので気兼ねなくデータをクロスバリデーションに回せます.

preprocessor = model.named_steps["preprocess"]
X_perturbed_ = preprocessor.transform(X_perturbed)
explainer = DecisionTreeClassifier()
explainer.fit(X_perturbed_, y_perturbed)

代理モデルがオリジナルのモデルをうまく近似しているかをチェックします.

print(classification_report(explainer.predict(X_perturbed_), y_perturbed))
              precision    recall  f1-score   support

       <=50K       1.00      1.00      1.00     29689
        >50K       1.00      1.00      1.00      2872

    accuracy                           1.00     32561
   macro avg       1.00      1.00      1.00     32561
weighted avg       1.00      1.00      1.00     32561

100% 当たっていますね.これは過学習ですが,決定木はロジスティック回帰よりも表現力が高いことのでこうなってしまいます.気になる場合はクロスバリデーションしてください.もしここで精度が低かったらデータを特定のインスタンス周辺に限定するなどして近似精度を担保します.

代理モデルの解釈

決定木から具体的な情報を取り出していきます.とりあえず最初に思いつくのは決定木をプロットしてみることでしょう.やってみると以下のような図が得られます.

feature_names = preprocessor.feature_names
sklearn.tree.plot_tree(explainer, feature_names=feature_names)

output_20_1.png

なんだかよくわからないですね.わたしの経験上,ある程度複雑な問題に対して構築した決定木は全体を見てもよくわからないことが多いです.頻出する特徴量を見ることでどういう特徴量が重視されているかがわかりますが,それなら単に特徴量の重要度を見るだけで済む話です.

決定木で解釈する利点は決定パスを見ることにあると思っています.https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html を参考に,各サンプルに対する決定パスを表示してみます.

X_ = preprocessor.transform(X)
node_indicator = explainer.decision_path(X_)
leaf_id = explainer.apply(X_)

n_nodes = explainer.tree_.node_count
children_left = explainer.tree_.children_left
children_right = explainer.tree_.children_right
feature = explainer.tree_.feature
threshold = explainer.tree_.threshold
label = explainer.tree_.value

num_samples, _ = X_.shape

for _ in range(2):
    sample_id = random.randint(0, num_samples)

    # obtain ids of the nodes `sample_id` goes through, i.e., row `sample_id`
    node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                        node_indicator.indptr[sample_id + 1]]

    print(f"Rules used to predict sample {sample_id}")
    for node_id in node_index:
        # continue to the next node if it is a leaf node
        if leaf_id[sample_id] == node_id:
            print(f"leaf node {node_id}; {label[node_id]}")
        else:
            if (X_.iloc[sample_id, feature[node_id]] <= threshold[node_id]):
                threshold_sign = "<="
            else:
                threshold_sign = ">"
            print("decision node {node}; {label} : (data[{sample}, {feature}] = {value}) "
                "{inequality} {threshold})".format(
                    node=node_id,
                    label=label[node_id],
                    sample=sample_id,
                    feature=feature_names[feature[node_id]],
                    value=X_.iloc[sample_id, feature[node_id]],
                    inequality=threshold_sign,
                    threshold=threshold[node_id]))
    print()

Rules used to predict sample 6987
decision node 0; [[29689.  2872.]] : (data[6987, capital-gain] = 0) <= 3897.5)
decision node 1; [[29561.  1058.]] : (data[6987, capital-loss] = 0) <= 1568.5)
decision node 2; [[29131.   149.]] : (data[6987, capital-gain] = 0) <= 2616.0)
decision node 3; [[28820.    55.]] : (data[6987, capital-loss] = 0) <= 1360.0)
decision node 4; [[2.8731e+04 2.0000e+01]] : (data[6987, capital-gain] = 0) <= 2139.5)
decision node 5; [[2.8557e+04 2.0000e+00]] : (data[6987, capital-loss] = 0) <= 927.0)
leaf node 6; [[28537.     0.]]

Rules used to predict sample 16176
decision node 0; [[29689.  2872.]] : (data[16176, capital-gain] = 0) <= 3897.5)
decision node 1; [[29561.  1058.]] : (data[16176, capital-loss] = 1977) > 1568.5)
decision node 97; [[430. 909.]] : (data[16176, fnlwgt] = 202027) <= 208603.5)
decision node 98; [[ 87. 833.]] : (data[16176, fnlwgt] = 202027) > 176860.0)
decision node 134; [[ 67. 157.]] : (data[16176, capital-loss] = 1977) > 1846.0)
decision node 150; [[ 24. 140.]] : (data[16176, hours-per-week] = 50) <= 52.5)
decision node 151; [[  9. 121.]] : (data[16176, age] = 39) <= 55.5)
decision node 152; [[  1. 107.]] : (data[16176, capital-loss] = 1977) > 1881.5)
leaf node 156; [[ 0. 99.]]

得られた結果はサンプル番号 6987, 14741 に対する決定パスです(数回繰り返し実行して面白そうな結果が出たものを選びました).詳しく見てみます.

サンプル 6987 では capital-gain と capital-loss が繰り返し比較されています.このサンプルはこれらの値が小さく,そのようなサンプルは全体として 28537 個あって, そのうち全てが収入 50k 以下となるようです.データ総数から見ても capital-gain <= 2616, capital-loss <= 1360 で条件づけるだけでかなり確度高く 50k 以下を分類できるようです.

サンプル 16176 では 2 段目で上述のサンプル 6987 と分岐します;capital-loss がそれなりに大きいという条件をつけると 60% 程度の人が収入 50k を超えるようです.その後の fnlwgt の 2 回の比較ではあまり絞り込めておらず,最終的には hours-per-week が少ないことおよび age が若いことが決め手になって収入 50k 超過と判定されたようです.

まとめと注意

このように,代理モデルを使った説明はアイデアがわかりやすく実装も簡単という特徴があります.実際 Molner の Interpretable Machine Learning/Global Surrogate にも次の記述があります.

I would argue that the approach is very intuitive and straightforward. This means it is easy to implement, but also easy to explain to people not familiar with data science or machine learning.

ステークホルダーに通じやすいというのはわたしが「説明」に求める最も重要な条件なので,わたしは好んでこの手法を採用しています.もちろん運用上気をつけないといけないことはいっぱいあります.特に大事なのは以下の2つです.

  • 得られた説明は局所的であること.代理モデルは元モデルを局所的に近似したものです.つまりその説明はデータの一部にしか通用しない可能性があります.具体的には,特徴量 A, B があり,元のモデルが A を重要視しているとしても,A があまり変動しない摂動をとってしまうと代理モデルは「B で決定している」と返すことがあります.ここで説明が局所的であることを忘れると A が重要でないような誤解が生じます.データと摂動の範囲をそれなりに広くとると局所的すぎる説明は避けられるのですが,今度は代理モデルが元モデルを近似できなくなって説明の精度が低下します.設定した局所性が実用上十分かどうかは常に検討する必要があります.

  • 得られた説明は代理モデルの説明であること.代理モデルを使って説明できたからといって元のモデルが本当にそういう判断をしている証拠にはなりません.元のモデルの決定はこういう方法でも説明できるくらいに思っておくのがよいです.具体的には,2つの相関する特徴量 A, B があり,元のモデルが A を見て決定していたとしても,代理モデルは「B で決定している」と返すことがあります.これは代理モデルによる手法の本質的な限界です.元モデルの特徴量重要度との関係を見たり,複数の代理モデルを作ると多少緩和されますが,どこかで諦める必要があります.

78
80
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
78
80