LoginSignup
4
7

More than 3 years have passed since last update.

scikit learnから解釈過程を表示するWebアプリを作ってみた

Last updated at Posted at 2019-06-15

機械学習がどうアプリに応用されていくかという事については、画像認識・音声認識等々の機能が、家電のインターフェースや業務アプリの識別機能としてモジュール化してアプリの一部になるのだろうと思っています。

例えば今関わっている開発だと、ITの契約書を登録する仕組みで、今は契約内容や期間や金額などからif-elseで契約が正しいかチェックしていますが、契約書と良い/悪いを読み込ませて機械学習させたモデルを用いてチェックしてくれるなんて事が可能になるのかな??(データ集めが難しそう。)

機械学習のチェック内容を判断するのは人間でしょうから、判断の過程を表示する事が重要になってくると思います。そこで今回は「scikit learnから解釈過程を表示するWebアプリ」に取組んでみました。

※8/17こちらにDjango版を掲載しました。
※07/02 15:00 Flask環境からDjango環境へ変更しました。おいおい変更点やソースの情報を掲載します。
※06/19 13:00 Understanding the decision tree structureから得られる情報を追加しました。
掲載ソースも入替ました。

Demo※タイムアウトした場合もう一度画面更新してみて下さい。(FREE版なので都度起動しているので重たいようです。)
gamen.jpg

Django環境へ変更

五月連休中にDjangoにトライしてみましたが、今後業務アプリとして作り込むならフルスタックのフレームワークを利用すべきだろうと考えDjango環境へ変更しました。
・DjangoとFlaskを比較した上でのDjangoのメリットは、urls.pyがルーティングを司るので、プログラム事にファイルを分ける事ができる事につきます。FlaskのBlueprintも試しましたがちょっと直観的ではなく、指定が悪いのかajaxを動かす事ができませんでした。

可視化する為のライブラリ

可視化にあたっては、以下のライブラリを試しました。

1.dtreeviz:
これが一番綺麗でわかりやすく、入力の推論の過程と根拠となる特徴量を表示するのでベストなのですが、loaclのflaskサーバでは動作しましたが、AzureWebappsだとdtreevizからIPython.utils.ioを読んだ時にエラーになりどうにも解決しませんでした。AzureWebappsでIOでサポートされていない機能があるようです。

2.graphviz:
loaclのflaskサーバではgraphviz-2.38.zipを展開してパスを切るだけで動作しました。AzureWebappsではローカルに展開ファイルをFTPでアップして、site配下にパス設定したapplicationHost.xdtを起きましたが、処理されているのか/いないのかタイムアウトしてしまいます。

3.dtreeplt:
@nekoumeiさん作成のdtreepltは、numpyとmatplotlib・scikit-learnのみで決定木を可視化するだけあってAzureWebappsでも動作しました。

4.追加の情報としてUnderstanding the decision tree structureのソースを利用しました。情報として以下をテキストボックスに表示します。
・画面入力値
・画面入力に対する予測値
・画面入力が到達する葉のID
・画面入力入力の到達する過程がスパース行列で得られるのでそれを表示
・plot_unveil_tree_structureで木構造を表示。特徴名・ターゲット名を表示するように修正。
また画面入力が到達する葉のIDの箇所がわかるように「<--------[[[[result]]]」をマーク。
・plot_unveil_tree_structureで画面入力の予測の過程を表示。これも特徴名・ターゲット名を
表示するように修正。

手順

前回のJupyter Notebookで訓練したモデルをAzure App Service上で動かしてみるの環境にライブラリを追加しています。
【再掲】
1.Jupyter Notebookでモデルを作成し、pickleを用いてsavファイルを吐き出す。
2.AzureのWebappsにPythonサーバを作成。
3.webappsダッシュボードの「開発ツール」→「拡張機能」→「追加」でPython 3.6.4 x64を追加。
※注意:モデルを作成する環境とpythonのバージョン(3.6*レベル)とbit数は合せて下さい。Scikits-Learn RandomForrest trained on 64bit python wont open on 32bit pythonにあるようにbit数が合わないとロード時にエラーになり、私はだいぶロスりました。
4.webappsダッシュボードの「コンソール」 からチェンジディレクトリーで移動し、ライブラリを追加。

python -m pip install --upgrade pip
pip install scikit-learn
 WARNING: The script f2py.exe is installed ・・」と表示されますがインストールは完了しています
pip install Flask
pip install pandas
pip install matplotlib
追加-start
pip install numpy
pip install dtreeplt
追加-end

pickleは標準ライブラリに含まれているのでインストールする必要はありません。
5.Jupyter Notebookで作成したアプリケーションをベースにWebAppsアプリを作成する。
6.1で作成したsavファイルをftpで所定の位置にアップする。

といった手順で行いました。

サーバサイドアプリ

main.py
# -*- coding: utf-8 -*-
from common import *

application = Flask(__name__)
application.config.from_object(__name__)

##################################################################
# 初期メニュー
##################################################################
@application.route('/')
def index():
    return render_template('index.html')

##################################################################
# Adult Census Income Binary Classification
##################################################################

#読み込むモデルをグローバル宣言
#fn_DecisionTree = './Data/AdultCensusIncome_model(DecisionTree).sav' 
fn_DecisionTree = './Data/AdultCensusIncome_model(WEBAPPS-depth-7).sav'
fn_RandomForest = './Data/AdultCensusIncome_model(RandomForest).sav'       

def IncomeConvert(df):

    try:
        df['workclass'] = df['workclass'].map( {
        '?': 99,
        'Federal-gov':0,
        'Local-gov':1,
        'Never-worked':2,
        'Private':3,
        'Self-emp-inc':4,
        'Self-emp-not-inc':5,
        'State-gov':6,
        'Without-pay':7,
        } ).astype(int)
    except:
        log('erro1')

    try:
        df['occupation'] = df['occupation'].map( {
        '?': 99,
        'Adm-clerical':0,
        'Armed-Forces':1,
        'Craft-repair':2,
        'Exec-managerial':3,
        'Farming-fishing':4,
        'Handlers-cleaners':5,
        'Machine-op-inspct':6,
        'Other-service':7,
        'Priv-house-serv':8,
        'Prof-specialty':9,
        'Protective-serv':10,
        'Sales':11,
        'Tech-support':12,
        'Transport-moving':13,      
        } ).astype(int)
    except:
        log('erro2')  

    try:
        df['race'] = df['race'].map( {
        'White':0,
        'Amer-Indian-Eskimo':1,
        'Asian-Pac-Islander':2,
        'Black':4,
        'Other':5,
        } ).astype(int)
    except:
        log('erro3')  

    try:
        df['sex'] = df['sex'].map( {
            'Male': 0,
            'Female': 1,
            'Other': 2
        } ).astype(int)
    except:
        log('erro4')  

    try:
        df['native-country'] = df['native-country'].map( {
            '?': 99,
            'Cambodia':0,
            'Canada':1,
            'China':2,
            'Columbia':3,
            'Cuba':4,
            'Dominican-Republic':5,
            'Ecuador':6,
            'El-Salvador':7,
            'England':8,
            'France':9,
            'Germany':10,
            'Greece':11,
            'Guatemala':12,
            'Haiti':13,
            'Holand-Netherlands':14,
            'Honduras':15,
            'Hong':16,
            'Hungary':17,
            'India':18,
            'Iran':19,
            'Ireland':20,
            'Italy':21,
            'Jamaica':22,
            'Japan':23,
            'Laos':24,
            'Mexico':25,
            'Nicaragua':26,
            'Outlying-US(Guam-USVI-etc)':27,
            'Peru':28,
            'Philippines':29,
            'Poland':30,
            'Portugal':31,
            'Puerto-Rico':32,
            'Scotland':33,
            'South':34,
            'Taiwan':35,
            'Thailand':36,
            'Trinadad&Tobago':37,
            'United-States':38,
            'Vietnam':39,
            'Yugoslavia':40,
        } ).astype(int)
    except:
        log('erro5')

    try:
        df['income'] = df['income'].map( {
            '<=50K': 0,
            '>50K': 1,
        } ).astype(int)
    except:
        print('erro6') 

    return df

#**************************************
# TREEの各情報取得 https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html
#**************************************
def decisionTreeStructure(model,train_x,feature_names,target_names):

    text = ""
    text = text + '【InputData】\n' + str(train_x.loc[0]) + '\n'

    #入力の予測
    result = model.predict(train_x)
    if result[0] == 0:
        result = '<=50k' 
    else:
        result = '>50k' 
    text = text + "【Result】 = "  + result + "\n"

    #入力の到達するIDを求める
    leave_id = model.apply(train_x)
    leave_id = leave_id[0]                            #到達するID
    text = text + "【Leave_id】 = "  + str(leave_id) + "\n"

    #入力の到達する過程がスパース行列で得られる
    decision_path = model.decision_path(train_x)  
    node_indicator = str(decision_path) 
    text = text + "【Decision Process】 \n"  +  node_indicator + "\n"

    try:    
        #tree_のプロパティを得る
        n_nodes = model.tree_.node_count
        children_left = model.tree_.children_left   #左の子ノードID(子が無い場合は「-1」がセットされている) 
        children_right = model.tree_.children_right #右の子ノードID
        feature = model.tree_.feature               #特徴量のインデックス
        threshold = model.tree_.threshold           #閾値

        #各ノードの最多数クラス  https://own-search-and-study.xyz/2016/12/25/scikit-learn%e3%81%a7%e5%ad%a6%e7%bf%92%e3%81%97%e3%81%9f%e6%b1%ba%e5%ae%9a%e6%9c%a8%e6%a7%8b%e9%80%a0%e3%81%ae%e5%8f%96%e5%be%97%e6%96%b9%e6%b3%95%e3%81%be%e3%81%a8%e3%82%81/
        nodeClasses = np.argmax(model.tree_.value.T, axis=0)

        #################################################
        # 木構造をたどり各ノードの深さや葉であるかなどのプロパティを計算し木構造を表示する
        #################################################
        node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
        is_leaves = np.zeros(shape=n_nodes, dtype=bool)
        stack = [(0, -1)]  # seed is the root node id and its parent depth

        #葉かどうか識別
        while len(stack) > 0:
            node_id, parent_depth = stack.pop()
            node_depth[node_id] = parent_depth + 1

            #子のノードが左右同じ(-1)だったら葉としてis_leavesにTrueをセット 
            if (children_left[node_id] != children_right[node_id]):
                stack.append((children_left[node_id], parent_depth + 1))
                stack.append((children_right[node_id], parent_depth + 1))
            else:
                is_leaves[node_id] = True

        text = text +  "【Binary tree structure has】= " + str(n_nodes) + " \n【Decision Tree Structure】\n"
        #表示
        for i in range(n_nodes):
            if is_leaves[i]:
                text = text +  (node_depth[i] * "\t") + "|-leaf  id=" + str(i) +  "  class= 【"+ target_names[nodeClasses[0][i]] + "】"
                if i == leave_id:
                    text = text +  "<---------------------------------[[[[*result*]]]"

                text = text  +"\n"
            else:
                text = text +  (node_depth[i] * "\t") + "|-node  id=" + str(i)
                text = text +  " if [" + feature_names[feature[i]] + "] <= [" + str(threshold[i])+  "] then node "  
                text = text +  str(children_left[i]) + " else to node " + str(children_right[i]) + "】\n"

        #################################################
        #識別過程を表示する
        #################################################
        text = text +  "【Decision Process】\n"
        node_index = decision_path.indices[decision_path.indptr[0]:
                                       decision_path.indptr[1]]
        for node_id in node_index:
            if leave_id == node_id:
                continue

            w_train_x = str(train_x.loc[0][feature[node_id]])  #型がint・float・str混在して下記のif判定でエラーになるので強引に文字型にしている。
            w_threshold = str(threshold[node_id])

            #if (train_x.loc[0][feature[node_id]] <= threshold[node_id]):
            if (w_train_x <= w_threshold):
                threshold_sign = "<="
                next_node = str(children_left[node_id])
            else:
                threshold_sign = ">"
                next_node = str(children_right[node_id])

            text = text + "decision id= " + str(node_id) + "【"
            text = text + str(feature_names[feature[node_id]]) + "】"
            text = text + w_train_x + " "
            text = text + threshold_sign + " "
            text = text + w_threshold
            text = text + " goto " + next_node +  "  class= 【"+ target_names[nodeClasses[0][i]] +"】\n"
        return text
    except:
        return "!!decisionTreeStructureError!!"

#**************************************
# Adult Census Income Binary Classification dataset
#**************************************
#html表示
@application.route('/Income', methods=['POST'])
def Income():

    #incomeの特徴名をグローバル宣言
    feature_names=['age','workclass','education-num',
                   'occupation','race','sex','capital-gain',
                   'capital-loss','hours-per-week','native-country']
    target_names=['<=50k', '>50k']

    status = ''

    try:
        dict_request = json.loads(request.form['json_data'])
        df = pd.io.json.json_normalize(dict_request)   
        df = df.drop(['alg'],axis=1)           
        train_x = IncomeConvert(df)
        #feature_name = df.columns.values 
    except:
        return render_template('Income.html',status = 'No Input(Input Error)' )    

    #feature_name = df.columns.values 
    # 保存したモデルをロードする
    try:
        if dict_request['alg'] == "1":
            filename = fn_DecisionTree
        else:
            filename = fn_RandomForest

        model = pickle.load(open(filename, mode='rb'))


    except:
        return render_template('Income.html', status = 'Load Error', text=text, file_name="")

    #Understanding the decision tree structure
    text = decisionTreeStructure(model,train_x,feature_names,target_names)

    #dtreeplt視覚化
    try: 
        dtree = dtreeplt(
            model=model,
            feature_names=feature_names ,
            target_names=target_names,
        )
        fig = dtree.view()
        fig.savefig('./static/Income-tree.png') 

    except:
        return render_template('Income.html', status = 'dtreeplt Error', text=text, file_name="Income-tree.png")

    return render_template('Income.html', status = 'Ok', text=text, file_name="Income-tree.png")

#**************************************
#  ajax送信スクリプト
#**************************************
@application.route('/IncomeCal', methods=['POST'])
def IncomeCal():
    try:
        df = pd.io.json.json_normalize(request.json)

        if request.json['alg'] == "1":
            filename = fn_DecisionTree
        else:
            filename = fn_RandomForest   

        df = df.drop(['alg'],axis=1)           
        train_x = IncomeConvert(df)
    except:
        result = "ajaxSendError"  


    # 保存したモデルをロードする
    try:
        model = pickle.load(open(filename, mode='rb'))
        result = model.predict(train_x)

        if result[0] == 0:
            result = '<=50k' 
        else:
            result = '>50k' 
        leave_id = model.apply(train_x)
        node_indicator = model.decision_path(train_x)
        result = result + '  leave_id = ' + str(leave_id[0])+ '\n  node_indicator = \n' + str(node_indicator[0])

    except:
        result = "Load Error"

    return result

#************************************** 
#  サーバ起動  マルチスレッド指定 デフォルトはTrueの動きをするようだが。 https://qiita.com/5zm/items/251be97d2800bf67b1c6
#************************************** 
if __name__ == '__main__':
    application.debug = True # デバッグ
    application.run(host='0.0.0.0', port=8000, threaded=True)
クライアントソースIncome.html
<!--*****************************************************************
*  azureMLサンプル:米国国勢調査提供の収入のサンプルの訓練結果にrestで接続
*  
*  2019/03/24
*****************************************************************-->
{% extends "base.html" %}
{% block body %}

<form name='sendform' action="{{ url_for('Income') }}" style="display: inline" method="post">
<div class="card" >
<div class="card-header" style="height:50px; font-size:1.5rem; ">
    <b>scikit-learn  webアプリサンプル</b>
</div>
<div class="card-body">
    <div class="row">
        <div class="col-md-2 mb-1">年齢</div>      <!--col-md-1:mdはPCのMIDDLEサイズ。mb-1はマージン(空白)をボトム(下)に設定-->
        <div class="col-md-2 mb-1" >
            <select class="form-control input-lg mb-1" id="age">
                <option>20</option>
                <option>25</option>
                <option>30</option>
                <option>35</option>
                <option>40</option>
                <option>45</option>             
                <option>50</option>
                <option>55</option>             
                <option>60</option>
                <option>65</option>             
                <option>70</option>
                <option>75</option>             
                <option>80</option>
            </select>
        </div>
        <div class="col-md-2 mb-1">ワーククラス</div>
        <div class="col-md-2 mb-1" >
            <select class="form-control input-lg mb-1" id="workclass">
                <option value="Federal-gov">Federal-gov</option>
                <option value="Local-gov">Local-gov</option>
                <option value="Never-worked">Never-worked</option>
                <option value="Private">Private</option>
                <option value="Self-emp-inc">Self-emp-inc</option>
                <option value="Self-emp-not-inc">Self-emp-not-inc</option>
                <option value="State-gov">State-gov</option>
                <option value="Without-pay">Without-pay</option>
                <option value="Without-pay">Without-pay</option>
            </select>
        </div>

        <div class="col-md-1 mb-1">学校</div>
        <div class="col-md-3 mb-1" >
            <select class="form-control input-lg mb-1" id="education-num">
                <option value="1">Preschool</option>
                <option value="2">1st-4th</option>
                <option value="3">5th-6th</option>
                <option value="4">7th-8th</option>
                <option value="5">9th</option>
                <option value="6">10th</option>
                <option value="7">11th</option>
                <option value="8">12th</option>
                <option value="9">HS-grad</option>
                <option value="10">Some-college</option>
                <option value="11">Assoc-voc</option>
                <option value="12">Assoc-acdm</option>
                <option value="13">Bachelors</option>
                <option value="14">Masters</option>
                <option value="15">Prof-school</option>
                <option value="16">Doctorate</option>
            </select>
        </div>

    </div>

    <div class="row">

        <div class="col-md-1 mb-1">職業</div>
        <div class="col-md-3 mb-1" >
            <select class="form-control input-lg mb-1" id="occupation">
                <option value="Adm-clerical">Adm-clerical</option>
                <option value="Armed-Forces">Armed-Forces</option>
                <option value="Craft-repair">Craft-repair</option>
                <option value="Exec-managerial">Exec-managerial</option>
                <option value="Farming-fishing">Farming-fishing</option>
                <option value="Handlers-cleaners">Handlers-cleaners</option>
                <option value="Machine-op-inspct">Machine-op-inspct</option>
                <option value="Other-service">Other-service</option>
                <option value="Priv-house-serv">Priv-house-serv</option>
                <option value="Prof-specialty">Prof-specialty</option>
                <option value="Protective-serv">Protective-serv</option>
                <option value="Sales">Sales</option>
                <option value="Tech-support">Tech-support</option>
                <option value="Transport-moving">Transport-moving</option>
            </select>
        </div>

        <div class="col-md-2 mb-1">人種</div>
        <div class="col-md-2 mb-1" >
            <select class="form-control input-lg mb-1" id="race">
                <option value="Amer-Indian-Eskimo">Amer-Indian-Eskimo</option>
                <option value="Asian-Pac-Islander">Asian-Pac-Islander</option>
                <option value="Black">Black</option>
                <option value="Other">Other</option>
                <option value="White">White</option>
            </select>
        </div>

        <div class="col-md-2 mb-1">性別</div>
        <div class="col-md-2 mb-1" >
            <select class="form-control input-lg mb-1" id="sex">
                <option value="Male">Male</option>
                <option value="Female">Female</option>
            </select>
        </div>  

    </div>

    <div class="row">


        <div class="col-md-2 mb-1">キャピタルゲイン</div>
        <div class="col-md-2 mb-1" >
            <input  type="text" class="form-control input-lg" id="capital-gain" value=0 >
        </div>

        <div class="col-md-2 mb-1">キャピタルロス</div>
        <div class="col-md-2 mb-1" >
            <input  type="text" class="form-control input-lg" id="capital-loss" value=0 >
        </div>

        <div class="col-md-2 mb-1">労働時間</div>
        <div class="col-md-2 mb-1" >
            <input  type="text" class="form-control input-lg" id="hours-per-week" value=40 >
        </div>

    </div>

    <div class="row">

        <div class="col-md-1 mb-1"></div>
        <div class="col-md-3 mb-1" >
            <select class="form-control input-lg mb-1" id="native-country">
                <option value="Cambodia">Cambodia</option>
                <option value="Canada">Canada</option>
                <option value="China">China</option>
                <option value="Columbia">Columbia</option>
                <option value="Cuba">Cuba</option>
                <option value="Dominican-Republic">Dominican-Republic</option>
                <option value="Ecuador">Ecuador</option>
                <option value="El-Salvador">El-Salvador</option>
                <option value="England">England</option>
                <option value="France">France</option>
                <option value="Germany">Germany</option>
                <option value="Japan">Japan</option>
                <option value="Laos">Laos</option>
                <option value="Mexico">Mexico</option>
                <option value="United-States">United-States</option>
                <option value="Vietnam">Vietnam</option>
                <option value="Yugoslavia">Yugoslavia</option>
            </select>
        </div>

        <div class="col-md-2 mb-1">アルゴリズム</div>
        <div class="col-md-4 mb-1" >
            <select class="form-control input-lg mb-1" id="alg">
                <option value="1">DecisionTree_model</option>
                <option value="2">RandomForest_model</option>
            </select>
        </div>  

    </div>

    <div class="row">
        <div class="col-md-2 mb-1" >
            <button class="btn btn-default" id="button1" style="width: 100px; padding: 5px;">Ajax Call</button>
        </div>
        <div class="col-md-3 mb-1" >
            <button class="btn btn-default" id="button2" style="width: 100px; padding: 5px;">Reload</button>
        </div>      
    </div>
    <div class="col-md-5 mb-1" >
            <h2>status→{{ status }}</h2>
            <textarea rows="40" cols="320">{{text}}</textarea><br>
    </div>


</div>  <!--class="card-body"の終端-->
</div>  <!--class="card"の終端-->

<div class="modal-body" style="width: 100%; overflow-x: auto;">
    <!-- モーダルは bodyからスクロールを削除して、代わりにモーダルコンテンツをスクロールする。-->
    <!-- img src="{{ url_for('static', filename = 'incom_DecisionTree.svg') }}" -->

    <img src="{{ url_for('static', filename = file_name) }}" width="4000px" height="400px"  > 
</div>

<input type="hidden" name="json_data" value="">

</form>


<script type="text/javascript">
//*************************************************
//*  form送信 
//*************************************************
$("#button2").click(function() {

    document.forms.sendform.json_data.value = JSON.stringify({
    "age":$("#age").val(),
    "workclass":$("#workclass").val(),
    "education-num":$("#education-num").val(),
    "occupation":$("#occupation").val(),
    "race":$("#race").val(),
    "sex":$("#sex").val(),
    "capital-gain":$("#capital-gain").val(),
    "capital-loss":$("#capital-loss").val(),
    "hours-per-week":$("#hours-per-week").val(),
    "native-country":$("#native-country").val(),
    "alg":$("#alg").val()
    });

    document.forms.submit();

});
//*************************************************
//*  ajax送信スクリプト
//*************************************************
"use strict";

var send_data = {};

$("#button1").click(function() {

    send_data = JSON.stringify({
    "age":$("#age").val(),
    "workclass":$("#workclass").val(),
    "education-num":$("#education-num").val(),
    "occupation":$("#occupation").val(),
    "race":$("#race").val(),
    "sex":$("#sex").val(),
    "capital-gain":$("#capital-gain").val(),
    "capital-loss":$("#capital-loss").val(),
    "hours-per-week":$("#hours-per-week").val(),
    "native-country":$("#native-country").val(),
    "alg":$("#alg").val()
    });

    $.ajax({
        type:'POST',
        url:'/IncomeCal',
        data:send_data,
        contentType:'application/json',
        success:function(data) {
            alert("結果 = " + data );},
        error: function(data) {
            alert('error!!!' + JSON.stringify(data) );
        }
    });
    return false;
});


</script>

{% endblock %}

web.config、base.html、common.pyはこちらをご覧ください。

だいぶアプリっぽくなりました。もう少し注意点など記述予定です。

plot_unveil_tree_structure備忘的まとめ。

名称 内容 入力 戻り値 使用例
predictメソッド 予測値を求める チェックしたい入力(配列またはpandasのデータフレーム形式) 予測値を示すインデックス result = model.predict(train_x)
if result[0] == 0:
result = "setosa(0)"
elif result[0] == 1:
result = "versicolor(1)"
else:
result = "virginica(2)"
applyメソッド 到達するIDを求める チェックしたい入力(配列またはpandasのデータフレーム形式) 到達するノードID leave_id = model.apply(train_x)
leave_id = leave_id[0]
decision_pathメソッド 識別する過程を求める チェックしたい入力(配列またはpandasのデータフレーム形式) スパース行列が戻る decision_path = model.decision_path(train_x) 入力i番目のデータの通ったnodeはdecision_path.indices[decision_path.indptr[i]: decision_path.indptr[i+1]]で得られる。
tree_オブジェクトのnode_countプロパティ 学習したモデルのノード数を求める ノード数 n_nodes = model.tree_.node_count
tree_オブジェクトのchildren_left/rightプロパティ 学習したモデルのあるノードIDの子のノードIDを求める ノードの配列。子が無い場合は「-1」。 children_left = model.tree_.children_left print(children_left[0])
tree_オブジェクトのfeatureプロパティ 学習したモデルのあるノードIDの特徴量を示すインデックスを求める。そのインデックスが示す特徴名はirisならiris.feature_namesにセットされている。 ノードID 特徴量を示すインデックス
tree_オブジェクトのthresholdプロパティ 学習したモデルのあるノードIDの閾値。「<=」閾値基準。 ノードID 閾値
tree_オブジェクトのvalueプロパティ 各ノードのクラス別到達データ数 ノードID np.argmax(clf.tree_.value.T, axis=0)で各ノードの最多数クラスが得られます。 こちら参照しています。
4
7
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
7