LoginSignup
15
12

More than 3 years have passed since last update.

GANを使った「統計的因果探索」モデル SAM(Structural Agnostic Modelling)をTitanicデータセットに適用してみた

Last updated at Posted at 2020-12-30

( 参考にした書籍 )

  • 小川 雄太郎(著)『つくりながら学ぶ Pythojによる因果分析 因果推論・因果探索の実践入門』(マイナビ)

RDMS(リレーショナル・データベース)に格納されている表形式のデータセットを読み込んで、次の情報を浮かび上がらせることができる手法があります。その手法は、「統計的因果探索」モデルと呼ばれています。

  1. 表中の各列(カラム)間の変数について、
  2. どの変数(どの列のデータ)とどの変数(どの列のデータ)の間に
  3. どちらが「原因」で、どちらが「結果」の「因果関係」の矢印が存在するのか

「統計的因果推定」モデルと「統計的因果探索」モデルとしては、回帰分析を用いたものや、決定木を用いて反実仮想を行うモデルなど、複数の手法がありますが、その中に、次の2つがあります。

  • LiNGAM (Linear Non-Gaussian Acyclic Model
  • SAM (Structural Agnostic Modelling)

後者のSAMモデルは、GAN(敵対的生成ネットワーク)を用いて「因果探索」を行うモデルです。

なお、深層学習(深層ニューラル・ネットワークモデル)を用いて「因果探索」を行うモデルとしては、SAMの他に、次の2つがあります。

1. Yue Yu, Jie Chen, Tian Gao, Mo Yu, DAG-GNN: DAG Structure Learning with Graph Neural Networks, ICML2019
( 実装コード )
(GitHub) fishmoon1234 / DAG-GNN

スクリーンショット 2021-01-02 1.00.11.png

スクリーンショット 2021-01-02 1.00.29.png

2. Shengyu Zhu, Ignavier Ng, Zhitang Chen, Causal Discovery with Reinforcement Learning, ICLR 2020
( 実装コード )
(GitHub) github.com/MichelDeudon/neural-combinatorial-optimization-rl-tensorflow

スクリーンショット 2021-01-02 0.59.17.png

これらの因果探索モデルについては、稿を改めて内容に立ち入り、実装コードを動かしてみたいと思います。

この記事では、SAM (Structural Agnostic Modelling)モデルを、R言語の統計解析のサンプルコードでもお馴染みの「Titanicデータ」に対して、適用してみました。

Google Colaboratoryで実行したコードは、この記事で後ほど、全行を掲載します。
@YuyaOmoriさんの記事を参考に、実行したものになります。

なお、Google Colaboratoryで上記の記事のコードを実行したところ、一部エラーが発生したところがありました。その辺りは、以下の記事に修正ポイントを記載しています。

得られた結果は、以下です。

( 最終的に得られた因果関係の係数行列 )

# ネットワーク構造(5回の平均)
final_m = sum(m_list) / len(m_list)
print(final_m)
[[0.00 0.62 0.21 0.14 0.51 0.35 0.00 0.00 0.00 0.04 0.14]
 [0.21 0.00 0.03 0.21 0.41 0.23 0.00 0.00 0.20 0.00 0.17]
 [0.44 0.58 0.00 0.59 0.80 0.40 0.00 0.00 0.00 0.00 0.16]
 [0.69 0.39 0.30 0.00 0.16 0.80 0.00 0.00 0.00 0.00 0.00]
 [0.22 0.60 0.12 0.21 0.00 0.40 0.00 0.00 0.00 0.00 0.00]
 [0.55 0.80 0.21 0.18 0.60 0.00 0.00 0.00 0.00 0.00 0.02]
 [0.00 0.31 0.49 0.37 0.52 0.28 0.00 0.91 0.00 0.00 0.18]
 [0.99 0.00 0.36 0.17 0.37 0.36 0.99 0.00 0.01 0.25 0.11]
 [0.09 0.00 0.52 0.53 0.05 0.43 0.00 0.00 0.00 0.40 0.39]
 [0.09 0.82 0.13 0.00 0.11 0.10 0.00 0.00 0.20 0.00 0.68]
 [0.07 0.32 0.72 0.38 0.63 0.14 0.00 0.00 0.60 0.07 0.00]]

( 閾値を設定して、因果関係の有無を[0, 1]に丸めて表現した因果関係の係数行列 )

#閾値設定して1,0にする
threshold = 0.6
np.array([[1 if j > threshold else 0 for j in i] for i in final_m])
array([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1],
       [0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0]])

この係数行列は、Titanicデータセットに含まれる各列(カラム)のデータを縦横に並べた時、X番目の列のデータと、Y番目の行データの間に、因果関係がどれだけ存在しているのかを数値で表現したものです。

推定結果として出力された行列は、次のように読みます。

  • $i$行目の$j$列目の数値: モデルに入力した表形式データの$i$列目のカラム変数 => モデルに入力した表形式データの$j$列目のカラム変数の方向に、因果関係が存在する確率値(区間[0, 1]。1=100%の確信度で因果関係の存在を推定している)

  • 行列の$(i, j)$番地の数値が1に近く、行列の$(j, i)$番地の数値が0に近い場合: $i$列目のカラム変数が「原因」となって、 $j$列目のカラム変数の値が「結果」として決まる関係は高確率で存在するが、逆の方向である、$j$列目のカラム変数 が原因となり、$i$列目のカラム変数が結果として決まる関係性が生じている可能性は低い。

今回、モデルに入力したTitanicデータは、各列のカラムに次の属性のデータが格納されています。

df.head()
Survived Pclass Age SibSp Parch Fare Sex_female Sex_male Embarked_C Embarked_Q Embarked_S
0 0 3 22.0 1 0 7.2500 0 1 0 0 1
1 1 1 38.0 1 0 71.2833 1 0 1 0 0
2 1 3 26.0 0 0 7.9250 1 0 0 0 1
3 1 1 35.0 1 0 53.1000 1 0 0 0 1
4 0 3 35.0 0 0 8.0500 0 1 0 0 1

列と行の同じ番号(番地)には、Titanicデータセット中の同じ「列(カラム)」の変量データが入ります。
そのため、「X番目の列のデータと、X番目の行のデータ」は、自分が自分自身に対して、どのような影響(因果関係)を持つのかを意味します。

SAMモデルは、それぞれのデータは、自分自身とは因果関係を持たないという「有向非循環グラフ」(Directed acyclic graph: DAG)を想定するため、「X番目の列のデータと、X番目の行のデータ」の値は、ゼロになります。

上記の因果関係の係数行列は、対角成分がすべてゼロ(0)になっていることが確認できます。

SAMモデルとは?

表形式のデータセットについて、どの列のデータとどの列のデータが、原因と結果の関係にあるのかを、0と1の2値の離散値で表現する$N×N$行列($N$はデータセットに含まれる列(カラム)の数)を、GAN(敵対的生成ネットワーク)モデルGeneratorモデルに生成させる手法です。

2020年10月26日にArxiv.に提出された次の論文で提案されたものです。

GANモデルの枠組みを採用していますので、識別役(Discriminator)のネットワークモデルが、G役のネットワークモデルが生成した因果関係行列が、入力されたデータセットの特徴をとらえた新しい(データセットに含まれていない)データを生成しているのか、データセットに含まれているデータを吐き出しているのか__を見分けます。

Gは、自分(G)が作った(データセットに含まれる実データに似せた)架空のデータが、Dに(それと)見抜かれないほど高い精巧さで生成できるようになることが期待されます。

このとき、Gは、データセットのデータ生成過程(Data generating process)を確率過程として、かなり高い精度で捉えられている状態が期待されます。

( 生成器Gに因果行列が埋め込まれている )

ところで、Gは、内部に、$N×N$行列($N$はデータセットに含まれる列(カラム)の数)を埋め込まれています。

Gは、GとNとのせめぎ合いというGANのネットワーク構造で学習を繰り返す中で、「データセットのデータ生成過程」(Data generating process)を正確に認識していくのと歩調を合わせて、データセットに内在する各データ間の「因果関係行列」も正しい[0, 1]の数値を埋め込むことに成功していくことが期待されます。

生成器 Gの内部に、$N×N$行列が埋め込まれている点が、一般的なGANのネットワーク構造と大きく異なる点です。

スクリーンショット 2021-01-02 0.33.33.png

因果行列は、論文中では、structural gateと表現されています。

スクリーンショット 2021-01-02 0.37.02.png

( 損失関数の特徴 )

SAMモデルは、あるカラムのデータ変数が、自分が原因となって、自分自身の値が(結果として)変わる「循環構造」(ループ・フィードバック)が含まれていない因果関係を想定しています。このような因果構造を、有向非循環グラフ(Directed Acyclic Graphと呼びます。

ここで重要なことは、DAGで表現される因果関係を持つ多変量データは、各カラムの変数の並び順序を適切に配置すると、次の因果行列で表現されます。

  • 対角要素がゼロである(これは、自分が原因となって、自分自身の値が(結果として)変わる「循環構造」が存在しないという制約を満たすため)

  • 下三角行列になる(対角要素より上側の要素がすべてゼロになる)

生成器Gが生み出す因果行列が、上記の制約を満たす行列になるように導くために、NO TEARS(Non-combinatorial Optimization via Trace Exponential and Augmented lagRangian for Structure learning)と呼ばれる損失関数を採用します。

その結果、SAMモデルは、GANモデルとして具備すべき損失関数の他に、NO TEARSという別の損失関数も備えていたネットワーク構造で、モデル学習を行う設計になっています。

スクリーンショット 2021-01-02 0.33.49.png

詳しくは、以下の論文を参照してください。

Abstract

Estimating the structure of directed acyclic graphs (DAGs, also known as Bayesian networks) is a challenging problem since the search space of DAGs is combinatorial and scales superexponentially with the number of nodes. Existing approaches rely on various local heuristics for enforcing the acyclicity constraint.

In this paper, we introduce a fundamentally different strategy: we formulate the structure learning problem as a purely continuous optimization problem over real matrices that avoids this combinatorial constraint entirely.

This is achieved by a novel characterization of acyclicity that is not only smooth but also exact.

The resulting problem can be efficiently solved by standard numerical algorithms, which also makes implementation effortless. The proposed method outperforms existing
ones, without imposing any structural assumptions on the graph such as bounded treewidth or in-degree.


因果探索モデルについて

この2つの手法を含めて、統計的因果推論と因果探索を行ういくつかの代表的な方法は、以下の書籍で解説されています。

この本では、LiNGAMSAMの他に、「ある施策や出来事が、別の変量(変数)に何らかの影響を及ぼしたかどうか?」(ある施策や出来事が「原因」となって、別の変量(変数)の値が「結果」として、変化したであろうか)を推定する際に、選択することのできるいくつかの手法が、取り上げられています。

  • 線形回帰モデル
  • 決定木モデルとRandomForestモデル

ここで、「もしある施策や出来事が存在しなかったならば、発生条件下で観測されたあるデータセットの値は、どうなっていたであろうか?」という反実仮想のシナリオ下でのデータ比較を可能にするために、「線形回帰モデル」と「決定木/ RandomForestモデル」では、次の方法を取り入れています。

【 「ダミー変数」の導入 】

「ある施策」や「ある出来事」が発生した場合のデータセットと、発生しなかった場合のデーセットを用意して、「ある施策/ 出来事」の発生有無を[0, 1]の2値で表現した「ダミー変数」を作って、モデルに入力する。この「ダミー変数」は、計量経済学の時系列データ解析では、お馴染みの手法です。

その他、この本のp.116以降では、Doubly Robust Learningという手法が解説されています。
その手法を(RandomForeestモデルで)Pythonで動かしているp.117以降に掲載されているコードでは、以下の工夫が取り入れられています。

1) 「ある施策」や「ある出来事」が発生した下で観測されたデータで学習させた決定木モデルと、発生しなかった場合に観測されたデータセットで学習させた決定木モデルを、それぞれ別々に学習する。
2) その後、発生ありで学習させたモデルに「発生なしのデータ」を入力し、さらに、発生なしで学習させたモデルに「発生ありのデータ」を入力し、両者を比較する。

以下のコードの最後の次の部分です。

df_1["ITE"] = Y_1 - M_0.predict(df_1[["X"]])

( ・・・省略・・・ )

df_0["ITE"] = M_1.predict(df_0[["X"]]) - Y_0

なお、ITEとは、Individual Treatment Effectを意味すると、同書p.117で解説されています。

対象となる各個人が、「ある施策や出来事」を経験した場合と、しなかった場合で、ある値についてどれだけ差が生じていたのかを示す「ある施策や出来事」の「処置効果」を意味します。

現実に観測値としてデータがあるのは、ある個人が「ある施策や出来事」を経験した場合か、経験していない場合か、いずれか片方のシナリオのもとで観測されたデータです。

そのため、上記のコードで、「反実仮想」を行うことで、この「処置効果」を推定しています。

同書p.117より転載
# ランダムフォレストモデルを作成
from sklearn.ensemble import RandomForestRegressor

# 集団を2つに分ける
df_0 = df[df.Z == 0.0] # 介入を受けていない集団
df_1 = df[df.Z == 1.0] # 介入を受けた集団

# 介入を受けていないモデル
M_0 = RandomForestRegressor(max_depth=3)
M_0.fit(df_0[["X"]], df_0[["Y"]])

# 介入を受けたモデル
M_1 = RandomForestRegressor(max_depth=3)
M_1.fit(df_1[["X"]], df_1[["Y"]]

( ・・・省略・・・ )

# 処置群
Y_1 = M_1.predict(df_1[["X"]]) + (df_1["Y"] - M_1.predict(df_1[["X"]]))/ g_x.predict_proba(df_1[["X"]])[:, 1] # [:, 1]はZ=1側の確率
df_1["ITE"] = Y_1 - M_0.predict(df_1[["X"]])

# 非処置群
Y_0 ~ M_0.predict(df_0[["X"]]) + (df_0["Y"] - M_0.predict(df_0[["X"]])) / g_x.predict_proba(df_0[["X"]])[:, 0] # [:, 0]はZ=0側の確率
df_0["ITE"] = M_1.predict(df_0[["X"]]) - Y_0

上記で、「介入」とは、ある人の上司が、ある研修を受講しているか、していないか(Z==0: 受講しなかった , Z ==1: 受講した)を意味します。

なお、この本は、冒頭で記載されている通り、表形式のデータセットに対する統計的因果推論・因果探索を行う手法を紹介していますが、時系列データセットを取り扱う手法は取り上げていません。

時系列の多変量データセットについて、「どの時点」の「どの変量」のデータが、別の「どの時点」の別の「どの変量」のデータに対して、タイムラグを伴った影響関係(因果関係)を持っているのかは、VAR(ベクトル自己回帰モデル)や、Grangerの因果テストがあります。

また、ある時間期間における時系列変量間の関係構造が、ある時点を境に、変化したのか、しないのかを統計学的に推定するChow-testと呼ばれる手法が、知られています。(Chow-testについては、坂野 「Chow 検定統計量と制約付き最小二乗推定量」 早稲田商学第 413・414 合併号 2007年12月が分かりやすい)

異常値検知の分野では、深層学習モデルを用いて、ある時点以降、入力(観測)されるデータの「内部構造」=「特徴量」が変化したという「異変」を捉える試みが広く行われており、深層学習モデルが、高い精度を達成する能力を秘めていることが知られつつあります。

時系列因果推定モデルとしては、この後に言及する「時系列データを非線形力学系として捉え、変数間の作用や近未来予測を実行する手法」@kumalphaさんとして、Empirical Dynamic Modeling (EDM)があります。

( 参考 )

時系列因果分析の領域でも、GANを採用したSAMモデルのように、深層学習を取り入れた手法が大きな効力を発揮するのかどうかは、また別の記事を立てて、取り上げてみたいと思います。ここでは、以下の論文があることについてだけ、添えておきます。

LiNGAMとSAM:両手法の限界(制約)

LiNGAMとSAM(Structural Agnostic Modelling)は、変数間の因果関係として、線形関数を想定するモデルです。

そのため、LiNGAMとSAMは、変数間にもしも何らかの因果関係が存在する場合、その関係が必ず「線形」関係(線形関数でモデル化される)にあるという制約を持ちます。

両手法の限界(制約)を乗り越える手法:変量間の「非線形」の因果関係を推定可能

変量間の「非線形」の因果関係を推定できる手法として、変数間に、非線形の力学モデルで記述できる関係の存在を想定するCross-mapping algorithmという方法があります。
(非線形力学モデルということで、カオス力学モデルが登場します)

Qiitaには、この手法について次の記事があります。

Convergent Cross Mapping (CCM)

Empirical Dynamic Modeling (EDM)の一種。EDMは時系列データを非線形力学系として捉え、変数間の作用や近未来予測を実行する手法。数式を仮定せず、観測値から系の状態空間を再構成してモデルフリーに解析をおこなう。Simplex projectionやS-mapの考え方を複数の変数の交互の予測に拡張した手法で、相互作用のある非線形力学系ならばTakens' theoremにより、お互いの時間遅れ埋め込みの情報を使って相手の時間遅れ埋め込みを再構成できる(自分の時系列のみから相手の時系列を予測できる)という特徴を利用している。

EDMは変数間の因果性を明らかにするためにも使用される.

因果性検定: Cross-mapping algorithm
単変数から再構成された系のマニフォールドはもともとの系への写像を得る,という考えに基づく手法
変数 M1,M2M1,M2 について,Cross-mappingの結果が "収束" すれば同じ系に属する (因果的に結びついている) と捉えられる (Cross-mapping Convergent: CCM) .ここで言う > "収束" とは,ライブラリ長の増加に伴ってcross-mapping能力 (ρρ) が増加することを指す.

代表的な方法としては,下記2つの検定がともに有意だった場合,収束が有意であると判定される

  1. Kenddallの$ττ$検定
  2. Fisherの$ΔρZΔρZ$検定

(留意点) データの定常性の確認

なお、線形回帰分析に始まり、時系列データを取り扱う際は、「データが定常過程であること」を前提に置いている手法が多くあります。「定常過程」とは、Wikipediaにある通り、時間が進行していく中で、時間とともに、そのデータ値の平均値と分散値が変化しないデータであるという性質です。

時間とともにデータの平均と分散が変化してしまう時系列データは非定常なデータですので、そのまま、他の時系列データを相手方に選んで、相関分析や回帰分析を行ってしまうと、デタラメな結果が出力されます。

R言語やPython言語に実装されているADF検定などを使うことで、データの定常性(非定常性)を統計学的に推定することができます。データ非定常だった場合は、そのデータを1期前のデータと(自己)差分をとったり、対数($log$)をとることで、非定常なデータを定常過程のデータに変換することができます。そのような方法で、データを定常データに変換した上で、相関分析や回帰分析などを行うのが、時系列データ解析のお作法になります。

この辺りは、時系列データ解析や実証計量経済分析学の教科書を開くと、必ず数式を含めて丁寧に解説がされています。

LiNGAMとSAM:両モデルの解説記事

LinGAMモデルに関しては、Qiitaにタグがあります。

次のQiita記事があります。

両モデルは、それぞれ、2006年と2020年に論文の形で、提案されました。

( 論文 )

( LiNGAMモデル )

2006年3月に、初稿論文がJournal of Machine Learning Researchに掲載された、比較的新しい論文です。

Submitted 3/06; Revised 7/06; Published 10/06

Abstract

In recent years, several methods have been proposed for the discovery of causal structure from
non-experimental data. Such methods make various assumptions on the data generating process
to facilitate its identification from purely observational data. Continuing this line of research, we
show how to discover the complete causal structure of continuous-valued data, under the assumptions that (a) the data generating process is linear, (b) there are no unobserved confounders, and (c)
disturbance variables have non-Gaussian distributions of non-zero variances. The solution relies on
the use of the statistical method known as independent component analysis, and does not require
any pre-specified time-ordering of the variables. We provide a complete Matlab package for performing this LiNGAM analysis (short for Linear Non-Gaussian Acyclic Model), and demonstrate
the effectiveness of the method using artificially generated data and real-world data.

( SAMモデル )

2020年10月26日にArxiv.に提出された新しい論文です。

Abstract

A new causal discovery method, Structural Agnostic Modeling (SAM), is presented in this paper. Leveraging both conditional independencies and distributional asymmetries in the data, SAM aims to find the underlying causal structure from observational data. The approach is based on a game between different players estimating each variable distribution conditionally to the others as a neural net, and an adversary aimed at discriminating the overall joint conditional distribution, and that of the original data. A learning criterion combining distribution estimation, sparsity and acyclicity
constraints is used to enforce the end-to-end optimization of the graph structure and parameters through stochastic gradient descent. Besides a theoretical analysis of the approach in the large sample limit, SAM is extensively experimentally validated on synthetic and real data.

Google ColaboratoryのGPU環境で、Titanicデータセットに適用してみる。

今回、@YuyaOmoriさんの記事を参考に、Google Colaboratoryの無料GPUを使って、上記のjupyter notebookファイルを実行してみました。

なお、Google Colaboratoryで上記の記事のコードを実行したところ、一部エラーが発生したところがありました。その辺りは、以下の記事に修正ポイントを記載しています。

( GitHubリポジトリ )

( jupyter notebookファイル )

プログラム実行前の設定など

!nvidia-smi
Wed Dec 30 05:37:55 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.27.04    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   51C    P8    10W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
from google.colab import drive
drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
import os
!cd("./drive/My Drive/causal_inference/")
/bin/bash: -c: line 0: syntax error near unexpected token `"./drive/My Drive/causal_inference/"'
/bin/bash: -c: line 0: `cd("./drive/My Drive/causal_inference/")'
# PyTorchのバージョンを下げる
!pip install torch==1.4.0+cu92 torchvision==0.5.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.4.0+cu92
[?25l  Downloading https://download.pytorch.org/whl/cu92/torch-1.4.0%2Bcu92-cp36-cp36m-linux_x86_64.whl (640.5MB)
[K     |████████████████████████████████| 640.6MB 27kB/s 
[?25hCollecting torchvision==0.5.0+cu92
[?25l  Downloading https://download.pytorch.org/whl/cu92/torchvision-0.5.0%2Bcu92-cp36-cp36m-linux_x86_64.whl (3.9MB)
[K     |████████████████████████████████| 4.0MB 59.7MB/s 
[?25hRequirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision==0.5.0+cu92) (1.15.0)
Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision==0.5.0+cu92) (7.0.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchvision==0.5.0+cu92) (1.19.4)
Installing collected packages: torch, torchvision
  Found existing installation: torch 1.7.0+cu101
    Uninstalling torch-1.7.0+cu101:
      Successfully uninstalled torch-1.7.0+cu101
  Found existing installation: torchvision 0.8.1+cu101
    Uninstalling torchvision-0.8.1+cu101:
      Successfully uninstalled torchvision-0.8.1+cu101
Successfully installed torch-1.4.0+cu92 torchvision-0.5.0+cu92
import torch 
print(torch.__version__)  # 元は1.5.0+cu101、versionを1.4に下げた
1.4.0+cu92
# 乱数のシードを設定
import random
import numpy as np

np.random.seed(1234)
random.seed(1234)

# 使用するパッケージ(ライブラリと関数)を定義
# 標準正規分布の生成用
from numpy.random import *

# グラフの描画用
import matplotlib.pyplot as plt

# その他
import pandas as pd

# シグモイド関数をimport
from scipy.special import expit

データの作成

#タイタニックデータ読み込み
df = pd.read_csv('titanic.csv')
#不要な列・欠損が多い列を削除し、カテゴリを数値に変換
df = df.drop('PassengerId',axis=1).drop('Name',axis=1).drop('Ticket',axis=1).drop('Cabin',axis=1)
df = pd.get_dummies(df)
df.head()
Survived Pclass Age SibSp Parch Fare Sex_female Sex_male Embarked_C Embarked_Q Embarked_S
0 0 3 22.0 1 0 7.2500 0 1 0 0 1
1 1 1 38.0 1 0 71.2833 1 0 1 0 0
2 1 3 26.0 0 0 7.9250 1 0 0 0 1
3 1 1 35.0 1 0 53.1000 1 0 0 0 1
4 0 3 35.0 0 0 8.0500 0 1 0 0 1
df.shape
(891, 11)
df.isnull().sum()
Survived        0
Pclass          0
Age           177
SibSp           0
Parch           0
Fare            0
Sex_female      0
Sex_male        0
Embarked_C      0
Embarked_Q      0
Embarked_S      0
dtype: int64
df['Age'] = df['Age'].fillna(df['Age'].mean())
df.isnull().sum()
Survived      0
Pclass        0
Age           0
SibSp         0
Parch         0
Fare          0
Sex_female    0
Sex_male      0
Embarked_C    0
Embarked_Q    0
Embarked_S    0
dtype: int64

SAMによる推論を実施

!pip install cdt==0.5.18
Collecting cdt==0.5.18
[?25l  Downloading https://files.pythonhosted.org/packages/97/29/144be44af187c8a2af63ceb205c38ca11787589f532cdf76517333d92d90/cdt-0.5.18-py3-none-any.whl (917kB)
[K     |████████████████████████████████| 921kB 20.2MB/s 
[?25hCollecting skrebate
  Downloading https://files.pythonhosted.org/packages/d3/8a/969e619753c299b4d3943808ef5f7eb6587d3cb78c93dcbcc3e4ce269f89/skrebate-0.61.tar.gz
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (1.19.4)
Requirement already satisfied: networkx in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (2.5)
Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (1.0.0)
Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (1.1.5)
Requirement already satisfied: statsmodels in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (0.10.2)
Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (1.4.1)
Collecting GPUtil
  Downloading https://files.pythonhosted.org/packages/ed/0e/5c61eedde9f6c87713e89d794f01e378cfd9565847d4576fa627d758c554/GPUtil-1.4.0.tar.gz
Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (2.23.0)
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (4.41.1)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from cdt==0.5.18) (0.22.2.post1)
Requirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.6/dist-packages (from networkx->cdt==0.5.18) (4.4.2)
Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->cdt==0.5.18) (2018.9)
Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.6/dist-packages (from pandas->cdt==0.5.18) (2.8.1)
Requirement already satisfied: patsy>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from statsmodels->cdt==0.5.18) (0.5.1)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->cdt==0.5.18) (2.10)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->cdt==0.5.18) (1.24.3)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->cdt==0.5.18) (3.0.4)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->cdt==0.5.18) (2020.12.5)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.7.3->pandas->cdt==0.5.18) (1.15.0)
Building wheels for collected packages: skrebate, GPUtil
  Building wheel for skrebate (setup.py) ... [?25l[?25hdone
  Created wheel for skrebate: filename=skrebate-0.61-cp36-none-any.whl size=29257 sha256=057c8301654fe97bb1bf1a71efa83c6e87a7ce15ec21f253601c49d68b02d162
  Stored in directory: /root/.cache/pip/wheels/ae/d8/ae/9b51d487e9d02219996d6c260255a216ef07d905b0a0b00ce3
  Building wheel for GPUtil (setup.py) ... [?25l[?25hdone
  Created wheel for GPUtil: filename=GPUtil-1.4.0-cp36-none-any.whl size=7411 sha256=28a25af00983887e1302add9caecddb3408457ed795b4d79bf6a34ab23312950
  Stored in directory: /root/.cache/pip/wheels/3d/77/07/80562de4bb0786e5ea186911a2c831fdd0018bda69beab71fd
Successfully built skrebate GPUtil
Installing collected packages: skrebate, GPUtil, cdt
Successfully installed GPUtil-1.4.0 cdt-0.5.18 skrebate-0.61

SAMの識別器Dの実装

# PyTorchから使用するものをimport
import torch
import torch.nn as nn


class SAMDiscriminator(nn.Module):
    """SAMのDiscriminatorのニューラルネットワーク
    """

    def __init__(self, nfeatures, dnh, hlayers):
        super(SAMDiscriminator, self).__init__()

        # ----------------------------------
        # ネットワークの用意
        # ----------------------------------
        self.nfeatures = nfeatures  # 入力変数の数

        layers = []
        layers.append(nn.Linear(nfeatures, dnh))
        layers.append(nn.BatchNorm1d(dnh))
        layers.append(nn.LeakyReLU(.2))

        for i in range(hlayers-1):
            layers.append(nn.Linear(dnh, dnh))
            layers.append(nn.BatchNorm1d(dnh))
            layers.append(nn.LeakyReLU(.2))

        layers.append(nn.Linear(dnh, 1))  # 最終出力

        self.layers = nn.Sequential(*layers)

        # ----------------------------------
        # maskの用意(対角成分のみ1で、他は0の行列)
        # ----------------------------------
        mask = torch.eye(nfeatures, nfeatures)  # 変数の数×変数の数の単位行列
        self.register_buffer("mask", mask.unsqueeze(0))  # 単位行列maskを保存しておく

        # 注意:register_bufferはmodelのパラメータではないが、その後forwardで使う変数を登録するPyTorchのメソッドです
        # self.変数名で、以降も使用可能になります
        # https://pytorch.org/docs/stable/nn.html?highlight=register_buffer#torch.nn.Module.register_buffer

    def forward(self, input, obs_data=None):
        """ 順伝搬の計算
        Args:
            input (torch.Size([データ数, 観測変数の種類数])): 観測したデータ、もしくは生成されたデータ
            obs_data (torch.Size([データ数, 観測変数の種類数])):観測したデータ
        Returns:
            torch.Tensor: 観測したデータか、それとも生成されたデータかの判定結果
        """

        if obs_data is not None:
          # 生成データを識別器に入力する場合
            return [self.layers(i) for i in torch.unbind(obs_data.unsqueeze(1) * (1 - self.mask)
                                                         + input.unsqueeze(1) * self.mask, 1)]
            # 対角成分のみ生成したデータ、その他は観測データに
            # データを各変数ごとに、生成したもの、その他観測したもので混ぜて、1変数ずつ生成したものを放り込む
            # torch.unbind(x,1)はxの1次元目でテンソルをタプルに展開する
            # minibatch数が2000、観測データの変数が6種類の場合、
            # [2000,6]→[2000,6,6]→([2000,6], [2000,6], [2000,6], [2000,6], [2000,6], [2000,6])→([2000,1], [2000,1], [2000,1], [2000,1], [2000,1], [2000,1])
            # returnは[torch.Size([2000, 1]),torch.Size([2000, 1]),torch.Size([2000, 1], torch.Size([2000, 1]),torch.Size([2000, 1]),torch.Size([2000, 1])]

            # 注:生成した変数全種類を用いた判定はしない。
            # すなわち、生成した変数1種類と、元の観測データたちをまとめて1つにし、それが観測結果か、生成結果を判定させる

        else:
            # 観測データを識別器に入力する場合

            return self.layers(input)
            # returnは[torch.Size([2000, 1])]


    def reset_parameters(self):
        """識別器Dの重みパラメータの初期化を実施"""
        for layer in self.layers:
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()

SAMの生成器Gの実装

from cdt.utils.torch import ChannelBatchNorm1d, MatrixSampler, Linear3D


class SAMGenerator(nn.Module):
    """SAMのGeneratorのニューラルネットワーク
    """

    def __init__(self, data_shape, nh):
        """初期化"""
        super(SAMGenerator, self).__init__()

        # ----------------------------------
        # 対角成分のみ0で、残りは1のmaskとなる変数skeletonを作成
        # ※最後の行は、全部1です
        # ----------------------------------
        nb_vars = data_shape[1]  # 変数の数
        skeleton = 1 - torch.eye(nb_vars + 1, nb_vars)

        self.register_buffer('skeleton', skeleton)

        # 注意:register_bufferはmodelのパラメータではないが、その後forwardで使う変数を登録するPyTorchのメソッドです
        # self.変数名で、以降も使用可能になります
        # https://pytorch.org/docs/stable/nn.html?highlight=register_buffer#torch.nn.Module.register_buffer

        # ----------------------------------
        # ネットワークの用意
        # ----------------------------------
        # 入力層(SAMの形での全結合層) 
        self.input_layer = Linear3D(
            (nb_vars, nb_vars + 1, nh))  # nhは中間層のニューロン数
        # https://github.com/FenTechSolutions/CausalDiscoveryToolbox/blob/32200779ab9b63762be3a24a2147cff09ba2bb72/cdt/utils/torch.py#L289

        # 中間層
        layers = []
        # 2次元を1次元に変換してバッチノーマライゼーションするモジュール
        layers.append(ChannelBatchNorm1d(nb_vars, nh))
        layers.append(nn.Tanh())
        self.layers = nn.Sequential(*layers)

        # ChannelBatchNorm1d
        # https://github.com/FenTechSolutions/CausalDiscoveryToolbox/blob/32200779ab9b63762be3a24a2147cff09ba2bb72/cdt/utils/torch.py#L130

        # 出力層(再度、SAMの形での全結合層)
        self.output_layer = Linear3D((nb_vars, nh, 1))

    def forward(self, data, noise, adj_matrix, drawn_neurons=None):
        """ 順伝搬の計算
        Args:
            data (torch.Tensor): 観測データ
            noise (torch.Tensor): データ生成用のノイズ
            adj_matrix (torch.Tensor): 因果関係を示す因果構造マトリクスM
            drawn_neurons (torch.Tensor): Linear3Dの複雑さを制御する複雑さマトリクスZ
        Returns:
            torch.Tensor: 生成されたデータ
        """

        # 入力層
        x = self.input_layer(data, noise, adj_matrix *
                             self.skeleton)  # Linear3D

        # 中間層(バッチノーマライゼーションとTanh)
        x = self.layers(x)

        # 出力層
        output = self.output_layer(
            x, noise=None, adj_matrix=drawn_neurons)  # Linear3D

        return output.squeeze(2)

    def reset_parameters(self):
        """重みパラメータの初期化を実施"""

        self.input_layer.reset_parameters()
        self.output_layer.reset_parameters()

        for layer in self.layers:
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()

Detecting 1 CUDA device(s).

SAMの誤差関数

# ネットワークを示す因果構造マトリクスMがDAG(有向非循環グラフ)になるように加える損失

def notears_constr(adj_m, max_pow=None):
    """No Tears constraint for binary adjacency matrixes. 
    Args:
        adj_m (array-like): Adjacency matrix of the graph
        max_pow (int): maximum value to which the infinite sum is to be computed.
           defaults to the shape of the adjacency_matrix
    Returns:
        np.ndarray or torch.Tensor: Scalar value of the loss with the type
            depending on the input.
    参考:https://github.com/FenTechSolutions/CausalDiscoveryToolbox/blob/32200779ab9b63762be3a24a2147cff09ba2bb72/cdt/utils/loss.py#L215
    """
    m_exp = [adj_m]
    if max_pow is None:
        max_pow = adj_m.shape[1]
    while(m_exp[-1].sum() > 0 and len(m_exp) < max_pow):
        m_exp.append(m_exp[-1] @ adj_m/len(m_exp))

    return sum([i.diag().sum() for idx, i in enumerate(m_exp)])

SAMの学習を実施する関数

from sklearn.preprocessing import scale
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm


def run_SAM(in_data, lr_gen, lr_disc, lambda1, lambda2, hlayers, nh, dnh, train_epochs, test_epochs, device):
    '''SAMの学習を実行する関数'''

    # ---------------------------------------------------
    # 入力データの前処理
    # ---------------------------------------------------
    list_nodes = list(in_data.columns)  # 入力データの列名のリスト
    data = scale(in_data[list_nodes].values)  # 入力データの正規化
    nb_var = len(list_nodes)  # 入力データの数 = d
    data = data.astype('float32')  # 入力データをfloat32型に
    data = torch.from_numpy(data).to(device)  # 入力データをPyTorchのテンソルに
    rows, cols = data.size()  # rowsはデータ数、colsは変数の数

    # ---------------------------------------------------
    # DataLoaderの作成(バッチサイズは全データ)
    # ---------------------------------------------------
    batch_size = rows  # 入力データ全てを使用したミニバッチ学習とする
    data_iterator = DataLoader(data, batch_size=batch_size,
                               shuffle=True, drop_last=True)
    # 注意:引数のdrop_lastはdataをbatch_sizeで取り出していったときに最後に余ったものは使用しない設定

    # ---------------------------------------------------
    # 【Generator】ネットワークの生成とパラメータの初期化
    # cols:入力変数の数、nhは中間ニューロンの数、hlayersは中間層の数
    # neuron_samplerは、Functional gatesの変数zを学習するネットワーク
    # graph_samplerは、Structual gatesの変数aを学習するネットワーク
    # ---------------------------------------------------
    sam = SAMGenerator((batch_size, cols), nh).to(device)  # 生成器G
    graph_sampler = MatrixSampler(nb_var, mask=None, gumbel=False).to(
        device)  # 因果構造マトリクスMを作るネットワーク
    neuron_sampler = MatrixSampler((nh, nb_var), mask=False, gumbel=True).to(
        device)  # 複雑さマトリクスZを作るネットワーク

    # 注意:MatrixSamplerはGumbel-Softmaxを使用し、0か1を出力させるニューラルネットワーク
    # SAMの著者らの実装モジュール、MatrixSamplerを使用
    # https://github.com/FenTechSolutions/CausalDiscoveryToolbox/blob/32200779ab9b63762be3a24a2147cff09ba2bb72/cdt/utils/torch.py#L212

    # 重みパラメータの初期化
    sam.reset_parameters()
    graph_sampler.weights.data.fill_(2)

    # ---------------------------------------------------
    # 【Discriminator】ネットワークの生成とパラメータの初期化
    # cols:入力変数の数、dnhは中間ニューロンの数、hlayersは中間層の数。
    # ---------------------------------------------------
    discriminator = SAMDiscriminator(cols, dnh, hlayers).to(device)
    discriminator.reset_parameters()  # 重みパラメータの初期化

    # ---------------------------------------------------
    # 最適化の設定
    # ---------------------------------------------------
    # 生成器

    g_optimizer = optim.Adam(sam.parameters(), lr=lr_gen)
    graph_optimizer = optim.Adam(graph_sampler.parameters(), lr=lr_gen)
    neuron_optimizer = optim.Adam(neuron_sampler.parameters(), lr=lr_gen)

    # 識別器
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr_disc)

    # 損失関数
    criterion = nn.BCEWithLogitsLoss()
    # nn.BCEWithLogitsLoss()は、binary cross entropy with Logistic function
    # https://pytorch.org/docs/stable/nn.html#bcewithlogitsloss

    # 損失関数のDAGに関する制約の設定パラメータ
    dagstart = 0.5
    dagpenalization_increase = 0.001*10

    # ---------------------------------------------------
    # forward計算、および損失関数の計算に使用する変数を用意
    # ---------------------------------------------------
    _true = torch.ones(1).to(device)
    _false = torch.zeros(1).to(device)

    noise = torch.randn(batch_size, nb_var).to(device)  # 生成器Gで使用する生成ノイズ
    noise_row = torch.ones(1, nb_var).to(device)

    output = torch.zeros(nb_var, nb_var).to(device)  # 求まった隣接行列
    output_loss = torch.zeros(1, 1).to(device)

    # ---------------------------------------------------
    # forwardの計算で、ネットワークを学習させる
    # ---------------------------------------------------
    pbar = tqdm(range(train_epochs + test_epochs),mininterval=1)  # 進捗(progressive bar)の表示

    for epoch in pbar:
        for i_batch, batch in enumerate(data_iterator):

            # 最適化を初期化
            g_optimizer.zero_grad()
            graph_optimizer.zero_grad()
            neuron_optimizer.zero_grad()
            d_optimizer.zero_grad()

            # 因果構造マトリクスM(drawn_graph)と複雑さマトリクスZ(drawn_neurons)をMatrixSamplerから取得
            drawn_graph = graph_sampler()
            drawn_neurons = neuron_sampler()
            # (drawn_graph)のサイズは、torch.Size([nb_var, nb_var])。 出力値は0か1
            # (drawn_neurons)のサイズは、torch.Size([nh, nb_var])。 出力値は0か1

            # ノイズをリセットし、生成器Gで疑似データを生成
            noise.normal_()
            generated_variables = sam(data=batch, noise=noise,
                                      adj_matrix=torch.cat(
                                          [drawn_graph, noise_row], 0),
                                      drawn_neurons=drawn_neurons)

            # 識別器Dで判定
            # 観測変数のリスト[]で、各torch.Size([data数, 1])が求まる
            disc_vars_d = discriminator(generated_variables.detach(), batch)
            # 観測変数のリスト[] で、各torch.Size([data数, 1])が求まる
            disc_vars_g = discriminator(generated_variables, batch)
            true_vars_disc = discriminator(batch)  # torch.Size([data数, 1])が求まる

            # 損失関数の計算(DCGAN)
            disc_loss = sum([criterion(gen, _false.expand_as(gen)) for gen in disc_vars_d]) / nb_var \
                + criterion(true_vars_disc, _true.expand_as(true_vars_disc))

            gen_loss = sum([criterion(gen,
                                      _true.expand_as(gen))
                            for gen in disc_vars_g])

            # 損失の計算(SAM論文のオリジナルのfgan)
            #disc_loss = sum([torch.mean(torch.exp(gen - 1)) for gen in disc_vars_d]) / nb_var - torch.mean(true_vars_disc)
            #gen_loss = -sum([torch.mean(torch.exp(gen - 1)) for gen in disc_vars_g])

            # 識別器Dのバックプロパゲーションとパラメータの更新
            if epoch < train_epochs:
                disc_loss.backward()
                d_optimizer.step()

            # 生成器のGの損失の計算の残り(マトリクスの複雑さとDAGのNO TEAR)
            struc_loss = lambda1 / batch_size*drawn_graph.sum()     # Mのloss
            func_loss = lambda2 / batch_size*drawn_neurons.sum()   # Aのloss

            regul_loss = struc_loss + func_loss

            if epoch <= train_epochs * dagstart:
                # epochが基準前のときは、DAGになるようにMへのNO TEARSの制限はかけない
                loss = gen_loss + regul_loss

            else:
                # epochが基準後のときは、DAGになるようにNO TEARSの制限をかける
                filters = graph_sampler.get_proba()  # マトリクスMの要素を取得(ただし、0,1ではなく、1の確率)
                dag_constraint = notears_constr(filters*filters)  # NO TERARの計算

                # 徐々に線形にDAGの正則を強くする
                loss = gen_loss + regul_loss + \
                    ((epoch - train_epochs * dagstart) *
                     dagpenalization_increase) * dag_constraint

            if epoch >= train_epochs:
                # testのepochの場合、結果を取得
                output.add_(filters.data)
                output_loss.add_(gen_loss.data)
            else:
                # trainのepochの場合、生成器Gのバックプロパゲーションと更新
                # retain_graph=Trueにすることで、以降3つのstep()が実行できる
                loss.backward(retain_graph=True)
                g_optimizer.step()
                graph_optimizer.step()
                neuron_optimizer.step()

            # 進捗の表示
            if epoch % 50 == 0:
                pbar.set_postfix(gen=gen_loss.item()/cols,
                                 disc=disc_loss.item(),
                                 regul_loss=regul_loss.item(),
                                 tot=loss.item())

    return output.cpu().numpy()/test_epochs, output_loss.cpu().numpy()/test_epochs/cols  # Mと損失を出力

GPUの使用可能を確認

画面上部のメニュー ランタイム > ランタイムのタイプを変更 で、 ノートブックの設定 を開く

ハードウェアアクセラレータに GPU を選択し、 保存 する

# GPUの使用確認:True or False
torch.cuda.is_available()

True

SAMの学習を実施

# numpyの出力を小数点2桁に
np.set_printoptions(precision=2, floatmode='fixed', suppress=True)

# 因果探索の結果を格納するリスト
m_list = []
loss_list = []

for i in range(5):
    m, loss = run_SAM(in_data=df, lr_gen=0.01*0.5,
                      lr_disc=0.01*0.5*2,
                      #lambda1=0.01, lambda2=1e-05,
                      lambda1=5.0*20, lambda2=0.005*20,
                      hlayers=2,
                      nh=200, dnh=200,
                      train_epochs=10000,
                      test_epochs=1000,
                      device='cuda:0')

    print(loss)
    print(m)

    m_list.append(m)
    loss_list.append(loss)

100%|██████████| 11000/11000 [06:04<00:00, 30.21it/s, disc=0.565, gen=6.27, regul_loss=2.84, tot=208]
  0%|          | 0/11000 [00:00<?, ?it/s, disc=1.4, gen=0.687, regul_loss=12.1, tot=19.7]

[[6.23]]
[[0.00 0.01 0.00 0.64 0.01 0.02 0.00 0.00 0.00 0.00 0.00]
 [0.94 0.00 0.00 0.99 0.01 0.06 0.00 0.00 0.99 0.00 0.00]
 [0.82 0.97 0.00 0.92 0.99 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.01 0.01 0.00 0.00 0.00 0.02 0.00 0.00 0.00 0.00 0.00]
 [0.81 1.00 0.00 0.99 0.00 0.99 0.00 0.00 0.00 0.00 0.00]
 [0.82 0.98 0.04 0.85 0.00 0.00 0.00 0.00 0.00 0.00 0.02]
 [0.00 0.00 0.00 0.00 0.91 0.00 0.00 0.99 0.00 0.00 0.02]
 [0.97 0.00 0.85 0.02 0.00 0.85 1.00 0.00 0.00 0.00 0.00]
 [0.00 0.00 0.00 0.00 0.00 0.01 0.00 0.00 0.00 0.00 0.01]
 [0.00 0.44 0.00 0.00 0.00 0.00 0.00 0.00 0.07 0.00 0.99]
 [0.00 0.00 0.96 0.99 0.98 0.00 0.00 0.00 0.99 0.00 0.00]]


100%|██████████| 11000/11000 [06:02<00:00, 30.32it/s, disc=0.617, gen=6.7, regul_loss=2.85, tot=213]
  0%|          | 0/11000 [00:00<?, ?it/s, disc=1.43, gen=0.618, regul_loss=12.4, tot=19.1]

[[7.02]]
[[0.00 0.57 0.05 0.04 0.99 0.72 0.00 0.00 0.00 0.18 0.01]
 [0.07 0.00 0.15 0.02 0.99 0.05 0.00 0.00 0.00 0.01 0.00]
 [0.11 0.13 0.00 0.06 1.00 0.04 0.00 0.00 0.00 0.00 0.00]
 [0.54 0.96 0.48 0.00 0.04 0.99 0.00 0.00 0.00 0.00 0.00]
 [0.00 0.01 0.01 0.03 0.00 0.01 0.00 0.00 0.00 0.00 0.00]
 [0.01 1.00 0.00 0.03 0.99 0.00 0.00 0.00 0.00 0.00 0.01]
 [0.00 0.99 0.78 0.00 0.76 0.69 0.00 1.00 0.00 0.00 0.00]
 [0.99 0.00 0.00 0.00 0.00 0.01 0.99 0.00 0.02 0.00 0.54]
 [0.39 0.00 0.84 0.95 0.26 0.60 0.00 0.00 0.00 0.99 0.01]
 [0.00 0.81 0.00 0.00 0.55 0.49 0.00 0.00 0.00 0.00 0.04]
 [0.32 0.60 0.90 0.00 0.00 0.00 0.00 0.00 0.90 0.17 0.00]]


100%|██████████| 11000/11000 [06:03<00:00, 30.29it/s, disc=0.488, gen=7.03, regul_loss=3.29, tot=219]
  0%|          | 0/11000 [00:00<?, ?it/s, disc=1.47, gen=0.922, regul_loss=12.4, tot=22.5]

[[6.97]]
[[0.00 0.58 0.00 0.01 0.59 0.01 0.00 0.00 0.00 0.01 0.71]
 [0.03 0.00 0.01 0.01 1.00 0.06 0.00 0.00 0.00 0.00 0.83]
 [0.82 0.90 0.00 1.00 1.00 0.95 0.00 0.00 0.00 0.00 0.78]
 [0.98 0.97 0.04 0.00 0.77 1.00 0.00 0.00 0.00 0.00 0.01]
 [0.01 0.03 0.00 0.01 0.00 0.01 0.00 0.00 0.00 0.00 0.00]
 [0.99 0.99 0.03 0.01 1.00 0.00 0.00 0.00 0.00 0.01 0.06]
 [0.00 0.57 0.93 0.00 0.00 0.02 0.00 1.00 0.01 0.00 0.86]
 [0.99 0.00 0.00 0.84 0.99 0.01 0.99 0.00 0.00 0.57 0.00]
 [0.00 0.00 0.87 0.86 0.00 0.00 0.00 0.00 0.00 0.00 0.94]
 [0.00 0.97 0.00 0.00 0.00 0.00 0.00 0.00 0.89 0.00 0.99]
 [0.00 0.00 0.01 0.00 0.18 0.00 0.00 0.00 0.01 0.00 0.00]]


100%|██████████| 11000/11000 [06:09<00:00, 29.78it/s, disc=0.837, gen=3.75, regul_loss=3.41, tot=86.7]
  0%|          | 0/11000 [00:00<?, ?it/s, disc=1.5, gen=0.99, regul_loss=12.2, tot=23.1]

[[7.31]]
[[0.00 0.98 0.99 0.01 0.00 0.97 0.00 0.00 0.01 0.00 0.00]
 [0.00 0.00 0.00 0.02 0.01 0.98 0.00 0.00 0.00 0.00 0.00]
 [0.00 0.02 0.00 0.01 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
 [0.99 0.01 0.96 0.00 0.00 0.98 0.00 0.01 0.00 0.00 0.00]
 [0.25 0.99 0.59 0.00 0.00 0.99 0.00 0.01 0.01 0.00 0.00]
 [0.01 0.02 1.00 0.01 0.01 0.00 0.00 0.01 0.00 0.00 0.00]
 [0.00 0.00 0.00 0.94 0.94 0.69 0.00 0.56 0.01 0.00 0.00]
 [0.99 0.00 0.95 0.00 0.00 0.00 1.00 0.00 0.03 0.66 0.00]
 [0.00 0.00 0.00 0.00 0.00 0.79 0.00 0.00 0.00 1.00 1.00]
 [0.28 0.98 0.66 0.00 0.00 0.00 0.00 0.00 0.04 0.00 0.41]
 [0.01 0.97 0.84 0.91 0.98 0.72 0.00 0.00 0.11 0.06 0.00]]


100%|██████████| 11000/11000 [06:07<00:00, 29.91it/s, disc=0.455, gen=5.35, regul_loss=3.19, tot=203]

[[6.55]]
[[0.00 0.95 0.00 0.01 0.97 0.02 0.00 0.00 0.00 0.03 0.00]
 [0.01 0.00 0.00 0.02 0.02 0.02 0.00 0.00 0.00 0.00 0.00]
 [0.47 0.90 0.00 0.98 1.00 0.99 0.00 0.00 0.00 0.00 0.00]
 [0.93 0.00 0.01 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00]
 [0.01 0.98 0.00 0.01 0.00 0.01 0.00 0.00 0.00 0.00 0.00]
 [0.94 0.99 0.01 0.01 0.97 0.00 0.00 0.00 0.00 0.01 0.00]
 [0.00 0.00 0.75 0.92 0.00 0.00 0.00 1.00 0.01 0.01 0.00]
 [0.98 0.00 0.00 0.00 0.88 0.92 1.00 0.00 0.00 0.00 0.00]
 [0.05 0.00 0.90 0.82 0.00 0.76 0.00 0.00 0.00 0.01 0.02]
 [0.15 0.88 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.99]
 [0.00 0.00 0.91 0.00 0.99 0.00 0.00 0.00 0.99 0.09 0.00]]
# ネットワーク構造(5回の平均)
final_m = sum(m_list) / len(m_list)
print(final_m)
[[0.00 0.62 0.21 0.14 0.51 0.35 0.00 0.00 0.00 0.04 0.14]
 [0.21 0.00 0.03 0.21 0.41 0.23 0.00 0.00 0.20 0.00 0.17]
 [0.44 0.58 0.00 0.59 0.80 0.40 0.00 0.00 0.00 0.00 0.16]
 [0.69 0.39 0.30 0.00 0.16 0.80 0.00 0.00 0.00 0.00 0.00]
 [0.22 0.60 0.12 0.21 0.00 0.40 0.00 0.00 0.00 0.00 0.00]
 [0.55 0.80 0.21 0.18 0.60 0.00 0.00 0.00 0.00 0.00 0.02]
 [0.00 0.31 0.49 0.37 0.52 0.28 0.00 0.91 0.00 0.00 0.18]
 [0.99 0.00 0.36 0.17 0.37 0.36 0.99 0.00 0.01 0.25 0.11]
 [0.09 0.00 0.52 0.53 0.05 0.43 0.00 0.00 0.00 0.40 0.39]
 [0.09 0.82 0.13 0.00 0.11 0.10 0.00 0.00 0.20 0.00 0.68]
 [0.07 0.32 0.72 0.38 0.63 0.14 0.00 0.00 0.60 0.07 0.00]]
#閾値設定して1,0にする
threshold = 0.6
np.array([[1 if j > threshold else 0 for j in i] for i in final_m])
array([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1],
       [0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0]])
df.head()
Survived Pclass Age SibSp Parch Fare Sex_female Sex_male Embarked_C Embarked_Q Embarked_S
0 0 3 22.0 1 0 7.2500 0 1 0 0 1
1 1 1 38.0 1 0 71.2833 1 0 1 0 0
2 1 3 26.0 0 0 7.9250 1 0 0 0 1
3 1 1 35.0 1 0 53.1000 1 0 0 0 1
4 0 3 35.0 0 0 8.0500 0 1 0 0 1
15
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
12