16
28

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

ランダムフォレストの全ての決定木をSVGで可視化してみた

Last updated at Posted at 2020-11-05

#概要
dtreevizとsvgutilsを利用してRFの複数の決定木を一つのSVGファイルで出力する方法をまとめてみました。

##RFから任意のモデルを一つ選択して表示
下記のリンクをそのまま流用させて頂きとりあえず実行しました。
[RandomForest も dtreeviz してみる](https://qiita.com/go50/item![レイアウト設計.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/465379/d58f187f-3c25-82d6-e904-47dfc64147af.png)
s/38c7757b444db3867b17)

from sklearn.datasets import load_iris
from sklearn import tree
from dtreeviz.trees import dtreeviz
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
clf = RandomForestClassifier(n_estimators=100 , max_depth = 2)
clf.fit(iris.data, iris.target)

estimators = clf.estimators_
viz = dtreeviz(
    estimators[0],
    iris.data, 
    iris.target,
    target_name='variety',
    feature_names=iris.feature_names,
    class_names=[str(i) for i in iris.target_names],
) 
viz.view()

Random_forest_1.png

100個の決定木のうち先頭のモデルを可視化ができています。

##全てのモデルをSVGファイルとして保存
上記のプログラムでは一つのモデルで一つのsvgファイルが生成されます。
ループ処理により、RFに含まれるすべての決定木をSVGで出力しました。
(100個すべて表示されるのは面倒なのでviz.save()を使用)
###tqdmのインポート
処理時間を計測するために使用しております。

from tqdm import tqdm

###全モデルの保存

for estimator in tqdm(estimators):
    viz = dtreeviz(
    estimator,
    iris.data, 
    iris.target,
    target_name='variety',
    feature_names=iris.feature_names,
    class_names=[str(i) for i in iris.target_names],
    ) 
    viz.save()

###不具合発生
出力先のTempフォルダを確認すると、決定木モデルのSVGファイルが1つしか保存されていない問題が発生しました。
image.png

どうやら、出力ファイルの命名規則が実行環境のプロセスIDを含んでいるようです。毎回同じファイル名が生成され、その都度SVGファイルが更新されてしまい末尾のモデルしか保存されていないみたいです。
site-packages\dtreeviz\tree.pyの中身

 def save_svg(self):
        """Saves the current object as SVG file in the tmp directory and returns the filename"""
        tmp = tempfile.gettempdir()
        svgfilename = os.path.join(tmp, f"DTreeViz_{os.getpid()}.svg")
        self.save(svgfilename)
        return svgfilename

##不具合修正
ファイルの命名規則を実行時の時間を用いて生成するように修正しました。
site-packages\dtreeviz\tree.pyにdatatimeをインポート

from datetime import datetime

save_svg()を修正

    def save_svg(self):
        """Saves the current object as SVG file in the tmp directory and returns the filename"""
        tmp = tempfile.gettempdir()
        #svgfilename = os.path.join(tmp, f"DTreeViz_{os.getpid()}.svg")
        now = datetime.now()
        svgfilename = os.path.join(tmp, f"DTreeViz_{now:%Y%m%d_%H%M%S}.svg")
        self.save(svgfilename)
        return svgfilename

###再度実行
⇒全決定木のモデルをSVG出力に成功
image.png

#全てのsvgファイルを統合
上記のファイルをひとつずつ別々に見ていくには凄く面倒くさい。svgutilsを利用して一つのファイルに統合して出力しました。
(svgutilsの利用の際に参照させていただいたサイトが見当たらない..
再発見しだいリンクを張っておきます。)

決定木の数に応じてなるべく正方形になるように & 決定木の深さを変更してもレイアウトを直ぐに修正できるように設計しました。

レイアウト設計.png

事前に作成した100個のSVGを特定のファイルに保存して下記プログラムを実行

import svgutils.transform as sg
import glob
import math
import os

def join_svg(cell_w, cell_h):
    SVG_file_dir = "./SVG_files"
    svg_filename_list = glob.glob(SVG_file_dir + "/*.svg")

    fig_tmp = sg.SVGFigure("128cm", "108cm")
    N = len(svg_filename_list)
    n_w_cells = int(math.sqrt(N))

    i = 0
    plot_list, txt_list = [], []

    for target_svg_file in svg_filename_list:
        print("i : {}".format(i))
        pla_x = i % n_w_cells
        pla_y = int(i / n_w_cells)
        print("プロット位置:[x,y] : {},{}".format(pla_x, pla_y))
        print(target_svg_file)
        fig_target = sg.fromfile(target_svg_file)
        plot_target = fig_target.getroot()
        plot_target.moveto(cell_w * pla_x, cell_h * pla_y, scale=1)
        print("モデル座標: {},{}".format(cell_w * pla_x, cell_h * pla_y))
        plot_list.append(plot_target)
        txt_target = sg.TextElement(25 + cell_w * pla_x, 20 + cell_h * pla_y,
                                    str(i), size=12, weight="bold")
        print("テキスト座標: {},{}".format(25 + cell_w * pla_x, 20 + cell_h * pla_y))
        txt_list.append(txt_target)
        print(i)
        i += 1

    fig_tmp.append(plot_list)
    fig_tmp.append(txt_list)

    ouput_dir = SVG_file_dir + "/output"

    try :
        fig_tmp.save(ouput_dir + "/RF.svg")

    except FileNotFoundError:
        os.mkdir(ouput_dir)
        fig_tmp.save(ouput_dir + "/RF.svg")
 
join_svg(400, 300)

##出力結果
無事すべてのファイル結合に成功しました。

RF_output_svg.png

予想以上にファイルの容量が大きくなりました(10M程度)。
chromeでも表示するのに時間を要します。他のアプリを起動したままだと表示するのにメモリが足らずエラーが出るケースもあります。

16
28
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
28

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?