#概要
dtreevizとsvgutilsを利用してRFの複数の決定木を一つのSVGファイルで出力する方法をまとめてみました。
##RFから任意のモデルを一つ選択して表示
下記のリンクをそのまま流用させて頂きとりあえず実行しました。
[RandomForest も dtreeviz してみる](https://qiita.com/go50/item![レイアウト設計.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/465379/d58f187f-3c25-82d6-e904-47dfc64147af.png)
s/38c7757b444db3867b17)
from sklearn.datasets import load_iris
from sklearn import tree
from dtreeviz.trees import dtreeviz
from sklearn.ensemble import RandomForestClassifier
iris = load_iris()
clf = RandomForestClassifier(n_estimators=100 , max_depth = 2)
clf.fit(iris.data, iris.target)
estimators = clf.estimators_
viz = dtreeviz(
estimators[0],
iris.data,
iris.target,
target_name='variety',
feature_names=iris.feature_names,
class_names=[str(i) for i in iris.target_names],
)
viz.view()
100個の決定木のうち先頭のモデルを可視化ができています。
##全てのモデルをSVGファイルとして保存
上記のプログラムでは一つのモデルで一つのsvgファイルが生成されます。
ループ処理により、RFに含まれるすべての決定木をSVGで出力しました。
(100個すべて表示されるのは面倒なのでviz.save()を使用)
###tqdmのインポート
処理時間を計測するために使用しております。
from tqdm import tqdm
###全モデルの保存
for estimator in tqdm(estimators):
viz = dtreeviz(
estimator,
iris.data,
iris.target,
target_name='variety',
feature_names=iris.feature_names,
class_names=[str(i) for i in iris.target_names],
)
viz.save()
###不具合発生
出力先のTempフォルダを確認すると、決定木モデルのSVGファイルが1つしか保存されていない問題が発生しました。
どうやら、出力ファイルの命名規則が実行環境のプロセスIDを含んでいるようです。毎回同じファイル名が生成され、その都度SVGファイルが更新されてしまい末尾のモデルしか保存されていないみたいです。
site-packages\dtreeviz\tree.pyの中身
def save_svg(self):
"""Saves the current object as SVG file in the tmp directory and returns the filename"""
tmp = tempfile.gettempdir()
svgfilename = os.path.join(tmp, f"DTreeViz_{os.getpid()}.svg")
self.save(svgfilename)
return svgfilename
##不具合修正
ファイルの命名規則を実行時の時間を用いて生成するように修正しました。
site-packages\dtreeviz\tree.pyにdatatimeをインポート
from datetime import datetime
save_svg()を修正
def save_svg(self):
"""Saves the current object as SVG file in the tmp directory and returns the filename"""
tmp = tempfile.gettempdir()
#svgfilename = os.path.join(tmp, f"DTreeViz_{os.getpid()}.svg")
now = datetime.now()
svgfilename = os.path.join(tmp, f"DTreeViz_{now:%Y%m%d_%H%M%S}.svg")
self.save(svgfilename)
return svgfilename
#全てのsvgファイルを統合
上記のファイルをひとつずつ別々に見ていくには凄く面倒くさい。svgutilsを利用して一つのファイルに統合して出力しました。
(svgutilsの利用の際に参照させていただいたサイトが見当たらない..
再発見しだいリンクを張っておきます。)
決定木の数に応じてなるべく正方形になるように & 決定木の深さを変更してもレイアウトを直ぐに修正できるように設計しました。
事前に作成した100個のSVGを特定のファイルに保存して下記プログラムを実行
import svgutils.transform as sg
import glob
import math
import os
def join_svg(cell_w, cell_h):
SVG_file_dir = "./SVG_files"
svg_filename_list = glob.glob(SVG_file_dir + "/*.svg")
fig_tmp = sg.SVGFigure("128cm", "108cm")
N = len(svg_filename_list)
n_w_cells = int(math.sqrt(N))
i = 0
plot_list, txt_list = [], []
for target_svg_file in svg_filename_list:
print("i : {}".format(i))
pla_x = i % n_w_cells
pla_y = int(i / n_w_cells)
print("プロット位置:[x,y] : {},{}".format(pla_x, pla_y))
print(target_svg_file)
fig_target = sg.fromfile(target_svg_file)
plot_target = fig_target.getroot()
plot_target.moveto(cell_w * pla_x, cell_h * pla_y, scale=1)
print("モデル座標: {},{}".format(cell_w * pla_x, cell_h * pla_y))
plot_list.append(plot_target)
txt_target = sg.TextElement(25 + cell_w * pla_x, 20 + cell_h * pla_y,
str(i), size=12, weight="bold")
print("テキスト座標: {},{}".format(25 + cell_w * pla_x, 20 + cell_h * pla_y))
txt_list.append(txt_target)
print(i)
i += 1
fig_tmp.append(plot_list)
fig_tmp.append(txt_list)
ouput_dir = SVG_file_dir + "/output"
try :
fig_tmp.save(ouput_dir + "/RF.svg")
except FileNotFoundError:
os.mkdir(ouput_dir)
fig_tmp.save(ouput_dir + "/RF.svg")
join_svg(400, 300)
##出力結果
無事すべてのファイル結合に成功しました。
予想以上にファイルの容量が大きくなりました(10M程度)。
chromeでも表示するのに時間を要します。他のアプリを起動したままだと表示するのにメモリが足らずエラーが出るケースもあります。