0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

semopyのplotを魔改造してエッジとノードに色付けできるようにする

0
Posted at

はじめに

よく因果推論を行って、その結果をsemopyのplotに通して可視化させているんですが、プレゼンなどで見せるときに、どこを見ればよいのかわかりにくくなるんですよね。
今までは枠で囲ってみたりしていたんですが、それだとめんどくさいし、余計な構造まで示してしまいます。

エッジとノードに色を付けられると、見やすいと思っていたのでそれを作ってみます。

とりあえずイメージしやすいように、例を以下に示します。

これ↓が

image.png

こうなります↓

image.png

テスト用因果構造を推定

普段、SAMを利用しているのでSAMを使ってbreast_cancerデータを因果推論してみます。

from sklearn.datasets import load_breast_cancer
from cdt.causality.graph import SAM
import pandas as pd
import semopy as sm
import networkx as nx

# データセットの読み込み
data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names)

# semopyの受け付ける形にカラム名を変換
# 間が空白だとエラーでるので、"_"にします
_columns = [col.replace(" ", "_") for col in X.columns]
X.columns = _columns

# パラメータは環境に合わせて設定
sam_matrix = SAM(gpus=, nruns=, njobs=).predict(X)

# 隣接行列を作成し、2値化させ、構造方程式を定義
adjacency_matrix = pd.DataFrame(nx.to_numpy_array(sam_matrix), columns=X.columns, index=X.columns)
adjacency_matrix_01 = adjacency_matrix.applymap(lambda x: 1 if x > 0.9 else 0)
adjacency_matrix_01
equations = []
for row in adjacency_matrix_01.index:
    equation = f"{row} ~ "
    parents = adjacency_matrix_01.columns[adjacency_matrix_01.loc[row] == 1]
    equation += " + ".join(parents)
    if parents.size != 0:
        equations.append(equation)


# semopyのモデルの作成
mod = sm.Model("\n".join(equations))
res = mod.fit(X)

# 通常のプロット
sm.semplot(mod, "breast_canser.png")

image.png

参考用なので構造を複雑化しないように2値化の際に、閾値をかなり上げています。

image.png

この部分のみを注目していきましょう。

実際に改造してみる

以下のコマンドでsite-packagesのディレクトリパスを確認します。

pip show semopy

site-packages以下のsemopy/plot.pyを開くと、実際に動いているプログラムを表示できます。

そこの110~120行目くらいに観測変数のノードを追加する部分がありますので、そこで色付けを行いましょう。
まず、以下の形式でnode_colorsという引数にて色付け指示を受け取るように決めます。

{
    "ノード名1": "色1",
    "ノード名2": "色2"
}

そのうえで以下の処理を追加します。

if obs in node_colors.keys():
    g.node(obs, label=obs, color=node_colors[obs], style="bold")
else:
    g.node(obs, label=obs, style="bold")

これでノードへの色付けができました。

つぎにエッジへの色付けをします。これは新しく引数を作る必要はなく、node_colorsを用いて実装します。

if ((rval in node_colors.keys()) and lval in (node_colors.keys())) and \
    (node_colors[rval] == node_colors[lval]):
    g.edge(rval, lval, label=label, color=node_colors[rval],
           arrowsize=str(arrowsize), style="bold")
else:
    g.edge(rval, lval, label=label,
           arrowsize=str(arrowsize), style="bold")

処理の内容としては以下の条件を満たす場合にノードを同じ色を設定しています。

  • 指定されたノードとの間にエッジが存在する
  • ノード同士で同じ色を指定している

これで色付け部分の実装は終わりです。実際にnode_colorsを指定して呼び出すと鹿のようになります。

sm.semplot(mod, "breast_canser.png", node_colors={"worst_fractal_dimension": "red",
                                                  "worst_compactness": "red"})

image.png

ついでに矢印の大きさも変えてみる

気づいた人もいると思いますが、先ほどのプログラムでarrow_sizeという引数も追加しています。
ここの値を設定すると、グラフ内の矢印の大きさも変更できます。

sm.semplot(mod, "breast_canser.png", node_colors={"worst_fractal_dimension": "red",
                                                  "worst_compactness": "red"},
           arrowsize=2)

image.png

ただ、大きさは矢印の頭が大きくなるだけですので、逆に見づらいかもしれません。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?