はじめに
よく因果推論を行って、その結果をsemopyのplotに通して可視化させているんですが、プレゼンなどで見せるときに、どこを見ればよいのかわかりにくくなるんですよね。
今までは枠で囲ってみたりしていたんですが、それだとめんどくさいし、余計な構造まで示してしまいます。
エッジとノードに色を付けられると、見やすいと思っていたのでそれを作ってみます。
とりあえずイメージしやすいように、例を以下に示します。
これ↓が
こうなります↓
テスト用因果構造を推定
普段、SAMを利用しているのでSAMを使ってbreast_cancerデータを因果推論してみます。
from sklearn.datasets import load_breast_cancer
from cdt.causality.graph import SAM
import pandas as pd
import semopy as sm
import networkx as nx
# データセットの読み込み
data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names)
# semopyの受け付ける形にカラム名を変換
# 間が空白だとエラーでるので、"_"にします
_columns = [col.replace(" ", "_") for col in X.columns]
X.columns = _columns
# パラメータは環境に合わせて設定
sam_matrix = SAM(gpus=, nruns=, njobs=).predict(X)
# 隣接行列を作成し、2値化させ、構造方程式を定義
adjacency_matrix = pd.DataFrame(nx.to_numpy_array(sam_matrix), columns=X.columns, index=X.columns)
adjacency_matrix_01 = adjacency_matrix.applymap(lambda x: 1 if x > 0.9 else 0)
adjacency_matrix_01
equations = []
for row in adjacency_matrix_01.index:
equation = f"{row} ~ "
parents = adjacency_matrix_01.columns[adjacency_matrix_01.loc[row] == 1]
equation += " + ".join(parents)
if parents.size != 0:
equations.append(equation)
# semopyのモデルの作成
mod = sm.Model("\n".join(equations))
res = mod.fit(X)
# 通常のプロット
sm.semplot(mod, "breast_canser.png")
参考用なので構造を複雑化しないように2値化の際に、閾値をかなり上げています。
この部分のみを注目していきましょう。
実際に改造してみる
以下のコマンドでsite-packagesのディレクトリパスを確認します。
pip show semopy
site-packages以下のsemopy/plot.pyを開くと、実際に動いているプログラムを表示できます。
そこの110~120行目くらいに観測変数のノードを追加する部分がありますので、そこで色付けを行いましょう。
まず、以下の形式でnode_colorsという引数にて色付け指示を受け取るように決めます。
{
"ノード名1": "色1",
"ノード名2": "色2"
}
そのうえで以下の処理を追加します。
if obs in node_colors.keys():
g.node(obs, label=obs, color=node_colors[obs], style="bold")
else:
g.node(obs, label=obs, style="bold")
これでノードへの色付けができました。
つぎにエッジへの色付けをします。これは新しく引数を作る必要はなく、node_colorsを用いて実装します。
if ((rval in node_colors.keys()) and lval in (node_colors.keys())) and \
(node_colors[rval] == node_colors[lval]):
g.edge(rval, lval, label=label, color=node_colors[rval],
arrowsize=str(arrowsize), style="bold")
else:
g.edge(rval, lval, label=label,
arrowsize=str(arrowsize), style="bold")
処理の内容としては以下の条件を満たす場合にノードを同じ色を設定しています。
- 指定されたノードとの間にエッジが存在する
- ノード同士で同じ色を指定している
これで色付け部分の実装は終わりです。実際にnode_colorsを指定して呼び出すと鹿のようになります。
sm.semplot(mod, "breast_canser.png", node_colors={"worst_fractal_dimension": "red",
"worst_compactness": "red"})
ついでに矢印の大きさも変えてみる
気づいた人もいると思いますが、先ほどのプログラムでarrow_sizeという引数も追加しています。
ここの値を設定すると、グラフ内の矢印の大きさも変更できます。
sm.semplot(mod, "breast_canser.png", node_colors={"worst_fractal_dimension": "red",
"worst_compactness": "red"},
arrowsize=2)
ただ、大きさは矢印の頭が大きくなるだけですので、逆に見づらいかもしれません。





