Jupyterで、dtreeviz(決定木の結果を可視化するライブラリ)を実行すると、表示がおかしくなってしまった時の解決の一例を紹介します。
このような事象です。
例えば、タイタニックデータを分析しようとして、決定木の結果をdtreevizを用いて、Jupyter上で可視化しようとします。
import pandas as pd
df = pd.read_csv('train.csv') # Kaggleでダウンロード
label = df["Survived"].tolist()
feature = pd.get_dummies(df[df.columns[df.columns != 'Survived']])
feature = feature.fillna(0)
from sklearn import tree
clf = tree.DecisionTreeClassifier(max_depth=3)
clf = clf.fit(feature, label)
from dtreeviz.trees import dtreeviz # pipでインストールする必要あり
import numpy as np
viz = dtreeviz(clf, feature, np.array(label), target_name='status', feature_names=list(feature.columns), class_names=["Dead", "Survived"])
display(viz) # 可視化用
このスクリプトを実行すると、以下のようになります。
このように、表示がおかしくなることがあります。
解決方法
色々試してみましたが、一番シンプルで楽なのは、以下の解決策でした。
import pandas as pd
df = pd.read_csv('train.csv') # Kaggleでダウンロード
label = df["Survived"].tolist()
feature = pd.get_dummies(df[df.columns[df.columns != 'Survived']])
feature = feature.fillna(0)
from sklearn import tree
clf = tree.DecisionTreeClassifier(max_depth=3)
clf = clf.fit(feature, label)
from dtreeviz.trees import dtreeviz # pipでインストールする必要あり
import numpy as np
viz = dtreeviz(clf, feature, np.array(label), target_name='status', feature_names=list(feature.columns), class_names=["Dead", "Survived"])
viz.view() # 修正:SVGで出力
この修正を加えると、別タブでSVGファイルが出力されます。
notebook上に図は残せないですが、個人的には、この解決策で特に問題なかったので採用しました。
最後に
私自身、些細な事で、つまずくケースは結構多いので、今回の件も情報共有しました。
dtreevizを利用する際の参考にして頂けると嬉しいです。